| """ |
| PXDesign + Q_theta Classifier Guidance. |
| |
| Monkey-patches PXDesign's diffusion sampling loop to inject Q_theta selectivity |
| gradient after each denoising step. This steers the diffusion trajectory toward |
| binder backbones that are conformationally selective. |
| |
| The patched diffusion loop: |
| x_denoised = denoise_net(x_noisy, t_hat, ...) |
| grad = ∇_{x_denoised}[Q(holo,Y) - Q(apo,Y)] # <-- INJECTED |
| x_denoised = x_denoised + scale(t) * grad # <-- INJECTED |
| delta = (x_noisy - x_denoised) / t_hat |
| x_l = x_noisy + eta * dt * delta |
| |
| Usage: |
| python code/scripts/pxdesign_guidance/guided_pxdesign.py \ |
| --input experiments/pxdesign_cam/output/cam_binder.json \ |
| --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \ |
| --ref_holo data/pdbs/cam_holo/3CLN.pdb \ |
| --ref_apo data/pdbs/cam_apo/1CFD.pdb \ |
| --guidance_scale 1.0 \ |
| --N_sample 50 --N_step 400 \ |
| --gpu 0 |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import json |
| import logging |
| import time |
| import shutil |
| from typing import Callable, Optional, Union |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# Repository layout (see the usage line in the module docstring): this script
# lives in <root>/code/scripts/pxdesign_guidance/, so two levels up is
# <root>/code and one more level is the repository root.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
# Optional path to a PXDesign source checkout; empty string when the
# PXDESIGN_DIR environment variable is unset.
_PXDESIGN_DIR = os.environ.get('PXDESIGN_DIR', '')

# Prepend local source trees so they take precedence on the import path.
# An empty PXDESIGN_DIR is skipped on purpose: inserting '' would put the
# current working directory at the front of sys.path.
if _ALLO_CODE_DIR and _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
if _PXDESIGN_DIR and _PXDESIGN_DIR not in sys.path:
    sys.path.insert(0, _PXDESIGN_DIR)
|
|
|
|
# Constant eta = 1.5; used when the caller passes no ``step_scale_eta``.
_DEFAULT_STEP_SCALE_ETA = {"type": "const", "min": 1.5, "max": 1.5}

# Schedule fraction at which each piecewise eta schedule switches from
# ``min`` to ``max``.
_PIECEWISE_ETA_SWITCH = {"piecewise": 0.5, "piecewise_65": 0.65, "piecewise_70": 0.70}

# How many samples per chunk get their own Q_theta gradient by default; the
# remaining samples reuse the mean of those gradients, bounding guidance cost.
_MAX_GUIDED_SAMPLES = 4


def _resolve_step_scale_eta(step_scale_eta, step_t, n_schedule):
    """Return the step-scale ``eta`` for denoising step ``step_t``.

    Args:
        step_scale_eta: a plain number (used as-is) or a dict with keys
            ``type`` (``const``/``linear``/``poly``/``cos``/``piecewise``/
            ``piecewise_65``/``piecewise_70``), ``min`` and ``max``.
        step_t: zero-based index of the current denoising step.
        n_schedule: length of the noise schedule; the schedule fraction is
            ``step_t / n_schedule``.

    Raises:
        ValueError: unknown schedule type, or a ``const`` schedule whose
            ``min`` and ``max`` differ.
    """
    if isinstance(step_scale_eta, (int, float)):
        return step_scale_eta
    kind = step_scale_eta["type"]
    eta_min, eta_max = step_scale_eta["min"], step_scale_eta["max"]
    if kind == "const":
        if eta_min != eta_max:
            raise ValueError(
                f"const eta schedule requires min == max, got {eta_min} != {eta_max}")
        return eta_min
    frac = step_t / n_schedule
    if kind == "linear":
        return eta_min + (eta_max - eta_min) * frac
    if kind == "poly":
        return eta_min + (eta_max - eta_min) * frac ** 2
    if kind == "cos":
        return eta_min + 0.5 * (eta_max - eta_min) * (1 - np.cos(np.pi * frac))
    if kind in _PIECEWISE_ETA_SWITCH:
        return eta_min if frac < _PIECEWISE_ETA_SWITCH[kind] else eta_max
    raise ValueError("Unsupported eta schedule!")


def _qtheta_guidance_direction(
    guidance_module, x_denoised, input_feature_dict, t_hat,
    chunk_n_sample, max_guided_samples,
):
    """Query Q_theta and return ``(direction, avg_norm, margin)``.

    ``guidance_module.compute_guidance_gradient`` is called once for each of
    the first ``min(chunk_n_sample, max_guided_samples)`` samples. Each call
    keeps that sample's slice of the returned gradient, or slice 0 if the
    gradient holds fewer samples. All other samples reuse the mean of the
    computed gradients.

    Returns:
        direction: the gradient divided by its norm over the last dim. Norms
            are clamped at 1e-8. Assuming coordinates are ``(..., N_atom, 3)``,
            this is a unit vector per atom. Its shape is that of
            ``x_denoised``, with a leading dim squeezed when
            ``x_denoised.dim() > 3``.
        avg_norm: mean of those per-row gradient norms (Python float).
        margin: margin returned by the last gradient call, or NaN if no
            call was made.
    """
    x_for_grad = x_denoised.squeeze(0) if x_denoised.dim() > 3 else x_denoised
    n_guide = min(chunk_n_sample, max_guided_samples)
    grad_accum = torch.zeros_like(x_for_grad)
    margin = float("nan")

    for si in range(n_guide):
        grad, margin = guidance_module.compute_guidance_gradient(
            x_for_grad, input_feature_dict, t_hat=t_hat, sample_idx=si
        )
        grad_accum[si] = grad[si] if grad.shape[0] > si else grad[0]

    # Unguided samples share the mean gradient of the guided ones.
    if 0 < n_guide < chunk_n_sample:
        avg_grad = grad_accum[:n_guide].mean(dim=0, keepdim=True)
        grad_accum[n_guide:] = avg_grad.expand(chunk_n_sample - n_guide, -1, -1)

    grad_norm = grad_accum.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    return grad_accum / grad_norm, grad_norm.mean().item(), margin


def guided_sample_diffusion(
    denoise_net: Callable,
    input_feature_dict: dict,
    s_inputs: torch.Tensor,
    s_trunk: torch.Tensor,
    z_trunk: torch.Tensor,
    noise_schedule: torch.Tensor,
    N_sample: int = 1,
    gamma0: float = 0.8,
    gamma_min: float = 1.0,
    noise_scale_lambda: float = 1.003,
    step_scale_eta: Optional[Union[float, dict]] = None,
    diffusion_chunk_size: Optional[int] = None,
    inplace_safe: bool = False,
    attn_chunk_size: Optional[int] = None,
    # Q_theta guidance options (not part of the upstream sampler signature).
    guidance_module=None,
    guidance_scale: float = 1.0,
    guidance_start: float = 0.8,
    guidance_end: float = 0.1,
    guidance_max_samples: int = _MAX_GUIDED_SAMPLES,
) -> torch.Tensor:
    """Run PXDesign diffusion sampling with Q_theta classifier guidance.

    Drop-in replacement for PXDesign's ``generator.sample_diffusion``. The
    denoising loop is the same. Inside the guidance window, the denoised
    coordinates are nudged along the Q_theta gradient before the Euler step.
    Per the module docstring, this is presumably the gradient of
    Q(holo) - Q(apo).

    Guidance schedule: with ``T = len(noise_schedule)`` and
    ``noise_fraction = 1 - step_t / (T - 1)``, guidance applies when
    ``guidance_end <= noise_fraction <= guidance_start``. The update is
    ``x_denoised += guidance_scale * noise_fraction * avg_norm * direction``.
    Here ``direction`` is the per-atom unit gradient and ``avg_norm`` its mean
    norm, so every atom moves by the same magnitude. Guidance is stronger at
    high noise (early steps).

    Args:
        denoise_net, input_feature_dict, s_inputs, s_trunk, z_trunk,
        noise_schedule, N_sample, gamma0, gamma_min, noise_scale_lambda,
        diffusion_chunk_size, inplace_safe, attn_chunk_size: as in
            PXDesign's sampler.
        step_scale_eta: a number, or a dict with ``type``/``min``/``max``.
            ``None`` means a constant eta of 1.5.
        guidance_module: object exposing ``compute_guidance_gradient``.
            ``None`` disables guidance.
        guidance_scale: overall guidance strength.
        guidance_start: upper noise-fraction bound of the guidance window.
        guidance_end: lower noise-fraction bound of the guidance window.
        guidance_max_samples: samples per chunk that get their own gradient.
            Remaining samples reuse the mean gradient.

    Returns:
        Final coordinates, with chunks concatenated along dim -3. The shape is
        presumably ``(*batch_shape, N_sample, N_atom, 3)``, matching the
        initial noise.

    Raises:
        ValueError: unsupported or inconsistent ``step_scale_eta``.

    Failures inside the guidance module never abort sampling. The first one is
    logged at WARNING with a traceback, later ones at DEBUG, and a total is
    reported at the end. The affected step proceeds unguided.
    """
    from protenix.model.utils import centre_random_augmentation

    if step_scale_eta is None:
        step_scale_eta = _DEFAULT_STEP_SCALE_ETA

    N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
    batch_shape = s_inputs.shape[:-2]
    device = s_inputs.device
    dtype = s_inputs.dtype

    logger.info(f"Guided sampling: scale={guidance_scale}, "
                f"window=[{guidance_end:.1f}, {guidance_start:.1f}]")

    guidance_failures = 0

    def _chunk_sample_diffusion_guided(chunk_n_sample, inplace_safe):
        """Run the full guided denoising loop for ``chunk_n_sample`` samples."""
        nonlocal guidance_failures
        x_l = noise_schedule[0] * torch.randn(
            size=(*batch_shape, chunk_n_sample, N_atom, 3),
            device=device, dtype=dtype,
        )
        T = len(noise_schedule)

        for step_t, (c_tau_last, c_tau) in enumerate(
            zip(noise_schedule[:-1], noise_schedule[1:])
        ):
            # protenix helper; presumably re-centres and randomly rigid-
            # transforms the coordinates (one augmentation per sample).
            x_l = (
                centre_random_augmentation(x_input_coords=x_l, N_sample=1)
                .squeeze(dim=-3)
                .to(dtype)
            )

            # While c_tau > gamma_min, raise the noise level from c_tau_last
            # to t_hat by adding fresh Gaussian noise.
            gamma = float(gamma0) if c_tau > gamma_min else 0
            t_hat = c_tau_last * (gamma + 1)
            delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
            x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
                size=x_l.shape, device=device, dtype=dtype
            )

            # Broadcast the scalar noise level to one value per sample.
            t_hat_tensor = (
                t_hat.reshape((1,) * (len(batch_shape) + 1))
                .expand(*batch_shape, chunk_n_sample)
                .to(dtype)
            )

            x_denoised = denoise_net(
                x_noisy=x_noisy,
                t_hat_noise_level=t_hat_tensor,
                input_feature_dict=input_feature_dict,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                chunk_size=attn_chunk_size,
                inplace_safe=inplace_safe,
            )

            # ---- Q_theta guidance injection ----
            if guidance_module is not None:
                progress = step_t / (T - 1) if T > 1 else 1.0
                noise_fraction = 1.0 - progress
                if guidance_end <= noise_fraction <= guidance_start:
                    scale = guidance_scale * noise_fraction
                    try:
                        direction, avg_norm, margin = _qtheta_guidance_direction(
                            guidance_module, x_denoised, input_feature_dict,
                            t_hat, chunk_n_sample, guidance_max_samples,
                        )
                        if avg_norm > 1e-6:
                            # Cast so guidance does not change the dtype the
                            # unguided sampler would carry forward.
                            x_denoised = x_denoised + (
                                scale * avg_norm * direction
                            ).to(x_denoised.dtype)
                        if step_t % 50 == 0:
                            logger.info(
                                f"  Step {step_t}/{T}: margin={margin:.3f}, "
                                f"grad_norm={avg_norm:.4f}, scale={scale:.3f}")
                    except Exception:
                        # Best-effort guidance: keep sampling, but make the
                        # failure visible instead of hiding it at DEBUG level.
                        guidance_failures += 1
                        if guidance_failures == 1:
                            logger.warning(
                                "Step %d: Q_theta guidance failed; step proceeds "
                                "unguided", step_t, exc_info=True)
                        else:
                            logger.debug("Step %d: Q_theta guidance failed",
                                         step_t, exc_info=True)

            # ---- Euler step from t_hat towards c_tau ----
            delta = (x_noisy - x_denoised) / t_hat_tensor[..., None, None]
            dt = c_tau - t_hat_tensor
            eta = _resolve_step_scale_eta(step_scale_eta, step_t, T)
            x_l = x_noisy + eta * dt[..., None, None] * delta

        return x_l

    if diffusion_chunk_size is None:
        x_l = _chunk_sample_diffusion_guided(N_sample, inplace_safe=inplace_safe)
    else:
        chunks = []
        for start in range(0, N_sample, diffusion_chunk_size):
            chunk_n_sample = min(diffusion_chunk_size, N_sample - start)
            chunks.append(_chunk_sample_diffusion_guided(
                chunk_n_sample, inplace_safe=inplace_safe))
        x_l = torch.cat(chunks, -3)

    if guidance_failures:
        logger.warning(
            "Q_theta guidance failed on %d step(s); those steps ran unguided.",
            guidance_failures)

    return x_l
|
|
|
|
def _resolve_repo_path(path):
    """Return ``path`` unchanged if absolute, else joined onto ``_ALLO_ROOT``."""
    return path if os.path.isabs(path) else os.path.join(_ALLO_ROOT, path)


def _score_designs(guidance, outdir):
    """Score every generated sample structure under ``outdir`` with Q_theta.

    Structures are found recursively. They are ``.pdb``/``.cif`` files whose
    basename contains ``sample`` (case-insensitive). Designs for which
    ``guidance.score_design`` returns ``None`` are skipped.

    Returns:
        The result dicts from ``score_design``, in sorted path order. Each is
        read for ``q_holo``/``q_apo``/``margin`` and gains ``design_id`` and
        ``pdb_path``.
    """
    from glob import glob

    candidates = []
    for pattern in ('*.pdb', '*.cif'):
        candidates.extend(glob(os.path.join(outdir, '**', pattern), recursive=True))
    pdbs = sorted(p for p in candidates if 'sample' in os.path.basename(p).lower())

    results = []
    for i, pdb_path in enumerate(pdbs, start=1):
        result = guidance.score_design(pdb_path)
        if result is None:
            continue
        # splitext strips only the real extension; str.replace would also
        # remove ".pdb"/".cif" appearing elsewhere in the file name.
        design_id = os.path.splitext(os.path.basename(pdb_path))[0]
        result['design_id'] = design_id
        result['pdb_path'] = pdb_path
        results.append(result)
        logger.info(
            f"[{i}/{len(pdbs)}] {design_id}: "
            f"Q+={result['q_holo']:.3f} Q-={result['q_apo']:.3f} "
            f"S={result['margin']:+.3f}")
    return results


def _write_summary(results, args, outdir):
    """Rank ``results`` by margin, write score/summary JSON, and log stats.

    ``results`` is sorted in place, best margin first. Writes
    ``guided_scores.json`` and ``guided_summary.json`` into ``outdir``. If
    nothing was scored, a warning is logged instead.
    """
    if not results:
        logger.warning(
            f"No designs could be scored under {outdir}; "
            "guided_scores.json / guided_summary.json were not written.")
        return

    results.sort(key=lambda r: r['margin'], reverse=True)
    margins = np.array([r['margin'] for r in results])

    summary = {
        'method': 'PXDesign + Classifier Guidance',
        'n_designs': len(results),
        'guidance_scale': args.guidance_scale,
        'guidance_window': [args.guidance_end, args.guidance_start],
        'margin_mean': float(margins.mean()),
        'margin_std': float(margins.std()),
        'frac_positive': float((margins > 0).mean()),
        'q_holo_mean': float(np.mean([r['q_holo'] for r in results])),
        'q_apo_mean': float(np.mean([r['q_apo'] for r in results])),
    }

    with open(os.path.join(outdir, 'guided_scores.json'), 'w') as f:
        json.dump(results, f, indent=2)
    with open(os.path.join(outdir, 'guided_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + Classifier Guidance Results ({len(results)} designs)")
    logger.info(f"  Margin: {margins.mean():.3f} ± {margins.std():.3f}")
    logger.info(f"  Fraction S > 0: {(margins > 0).mean():.1%}")
    logger.info(f"  Q(holo) mean: {summary['q_holo_mean']:.3f}")
    logger.info(f"{'='*60}")


def run_guided_pxdesign(args):
    """Run PXDesign with Q_theta classifier guidance, then score the designs.

    Steps:
      1. Configure PXDesign from ``args``.
      2. Build the Q_theta guidance module.
      3. Monkey-patch PXDesign's ``sample_diffusion`` with
         :func:`guided_sample_diffusion`.
      4. Run inference once per seed.
      5. Score every generated sample and write JSON summaries to the output
         directory.

    Relative ``outdir``, checkpoint and reference paths are resolved against
    the repository root. A relative ``--input`` is used as given when it
    exists, and otherwise is also resolved against the repository root.

    Args:
        args: ``argparse.Namespace`` produced by :func:`main`.
    """
    # Must be set before CUDA is initialised to take effect.
    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    else:
        logger.info(
            f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} already set; "
            f"ignoring --gpu {args.gpu}")

    # Heavy / optional dependencies are imported lazily so `--help` works
    # without them and so CUDA_VISIBLE_DEVICES is set first.
    from pxdesign.runner.inference import InferenceRunner
    from pxdesign.utils.infer import (
        get_configs, convert_to_bioassembly_dict, download_inference_cache, derive_seed,
    )
    from pxdesign.utils.inputs import process_input_file
    from protenix.config import save_config
    from protenix.utils.seed import seed_everything

    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = _resolve_repo_path(args.outdir)
    os.makedirs(outdir, exist_ok=True)
    input_path = args.input if os.path.exists(args.input) else _resolve_repo_path(args.input)

    # ---- PXDesign configuration ----
    pxdesign_argv = [
        '--dump_dir', outdir,
        '--input', input_path,
        '--dtype', 'bf16',
        '--N_sample', str(args.N_sample),
        '--N_step', str(args.N_step),
    ]
    configs = get_configs(pxdesign_argv)
    configs.input_json_path = process_input_file(
        configs.input_json_path, out_dir=outdir)
    download_inference_cache(configs)

    save_config(configs, os.path.join(outdir, "config.yaml"))
    with open(configs.input_json_path, "r") as f:
        orig_inputs = json.load(f)
    for task in orig_inputs:
        convert_to_bioassembly_dict(task, outdir)
    # Re-dump the (converted) tasks and point PXDesign at the new file.
    configs.input_json_path = os.path.join(outdir, "input_tasks.json")
    with open(configs.input_json_path, "w") as f:
        json.dump(orig_inputs, f, indent=4)

    runner = InferenceRunner(configs)

    # ---- Q_theta guidance module ----
    guidance = QThetaPXDesignGuidance(
        checkpoint=_resolve_repo_path(args.qtheta_checkpoint),
        ref_holo=_resolve_repo_path(args.ref_holo),
        ref_apo=_resolve_repo_path(args.ref_apo),
        ref_chain=args.ref_chain,
        # CUDA_VISIBLE_DEVICES remaps the selected GPU to index 0.
        device='cuda:0',
        esm_target=args.esm_target,
    )

    # ---- Monkey-patch the diffusion sampler ----
    from pxdesign.model import generator as pxdesign_generator
    import pxdesign.model.pxdesign as pxdesign_model

    guided_fn = partial(
        guided_sample_diffusion,
        guidance_module=guidance,
        guidance_scale=args.guidance_scale,
        guidance_start=args.guidance_start,
        guidance_end=args.guidance_end,
    )
    pxdesign_generator.sample_diffusion = guided_fn
    # pxdesign.model.pxdesign presumably binds its own module-level reference
    # (e.g. via `from ... import sample_diffusion`), so patch it there too.
    pxdesign_model.sample_diffusion = guided_fn
    logger.info("PXDesign diffusion loop patched with Q_theta guidance")

    # ---- Inference ----
    seeds = configs.seeds or [derive_seed(time.time_ns())]
    for seed in seeds:
        logger.info(f"Running guided inference with seed {seed}")
        seed_everything(seed=seed, deterministic=False)
        runner._inference(seed)

    # ---- Scoring ----
    logger.info("Scoring generated designs...")
    results = _score_designs(guidance, outdir)
    _write_summary(results, args, outdir)
|
|
|
|
def main():
    """Parse command-line arguments and run guided PXDesign sampling + scoring.

    Exits with status 2 (via ``argparse``) on an inverted guidance window or
    non-positive ``--N_sample`` / ``--N_step``.
    """
    parser = argparse.ArgumentParser(description='PXDesign + Q_theta Classifier Guidance')
    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
                        help='PXDesign input JSON')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt',
                        help='Q_theta checkpoint')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb',
                        help='Reference holo structure')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb',
                        help='Reference apo structure')
    parser.add_argument('--ref_chain', default='A',
                        help='Chain ID used from the reference structures')
    parser.add_argument('--guidance_scale', type=float, default=1.0,
                        help='Guidance gradient scale')
    parser.add_argument('--guidance_start', type=float, default=0.8,
                        help='Start guidance at this noise fraction (high noise)')
    parser.add_argument('--guidance_end', type=float, default=0.1,
                        help='Stop guidance at this noise fraction (low noise)')
    parser.add_argument('--N_sample', type=int, default=50)
    parser.add_argument('--N_step', type=int, default=400)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_guided')
    parser.add_argument('--esm_target', default='cam',
                        help='Subdir under data/esm2_embeddings (e.g., adk, cam)')
    args = parser.parse_args()

    # Guidance applies only while guidance_end <= noise_fraction <=
    # guidance_start, so an inverted window would silently disable it.
    if args.guidance_end > args.guidance_start:
        parser.error(
            f"--guidance_end ({args.guidance_end}) must not exceed "
            f"--guidance_start ({args.guidance_start})")
    if args.N_sample < 1 or args.N_step < 1:
        parser.error("--N_sample and --N_step must be positive integers")

    run_guided_pxdesign(args)
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this module does not start a run.
    main()
|
|