"""
PXDesign + SMC Reranking.

Post-hoc Sequential Monte Carlo: generate multiple batches of vanilla PXDesign
binders, score with Q_theta, and rank by selectivity margin. No modification
to the PXDesign diffusion process — pure generate-score-rank pipeline.

This is the simplest Q_theta integration strategy: generate a large pool of
candidates and select the best ones by selectivity score.

Usage:
    python code/scripts/pxdesign_guidance/smc_pxdesign.py \
        --input experiments/pxdesign_cam/output/cam_binder.json \
        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
        --n_particles 16 --n_rounds 4 \
        --gpu 0
"""
|
|
| import os |
| import sys |
| import argparse |
| import json |
| import logging |
| import shutil |
| import subprocess |
| from glob import glob |
|
|
| import numpy as np |
| import torch |
|
|
# Script-wide logging: timestamped INFO-level records on the root logger,
# plus a module-level logger used by every function below.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory layout derived from this file's location:
#   _SCRIPT_DIR    -> directory containing this script
#   _ALLO_CODE_DIR -> two levels above the script directory
#   _ALLO_ROOT     -> parent of _ALLO_CODE_DIR
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, os.pardir))

# Put the code directory at the front of the import path (once) so that
# project-local modules can be imported later.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
|
|
|
|
def run_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
    """Run vanilla PXDesign via CLI subprocess.

    The ``pxdesign.runner.cli infer`` entry point is tried first. If it exits
    with a non-zero status, the same request is retried through
    ``pxdesign.runner.inference``.

    Args:
        input_json: Path to the PXDesign input JSON.
        outdir: Directory handed to PXDesign as its output/dump directory.
        n_sample: Value forwarded as ``--N_sample``.
        n_step: Value forwarded as ``--N_step``.
        gpu: Index of the CUDA device (as numbered inside this process) the
            child should use, or ``None`` to leave device selection untouched.

    Returns:
        True if one of the two invocations exited with status 0; False if
        both failed, if an invocation timed out, or if the interpreter could
        not be launched.
    """
    pxdesign_python = 'python'
    timeout_s = 7200

    # Options shared by both entry points.
    sampling_opts = [
        '--dtype', 'bf16',
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
    ]
    cmd = [
        pxdesign_python, '-m', 'pxdesign.runner.cli', 'infer',
        '-o', outdir,
        '-i', input_json,
    ] + sampling_opts
    cmd_alt = [
        pxdesign_python, '-m', 'pxdesign.runner.inference',
        '--dump_dir', outdir,
        '--input', input_json,
    ] + sampling_opts

    env = _pxdesign_env(gpu)

    logger.info(f"Running PXDesign: {n_sample} samples -> {outdir}")
    result = _run_pxdesign_command(cmd, env, timeout_s)
    if result is None:
        # Timeout / launch failure: the alternate entry point would hit the
        # same problem, so do not spend another timeout window on it.
        return False
    if result.returncode == 0:
        return True

    logger.warning(
        f"PXDesign CLI entry point exited with status {result.returncode}; "
        f"retrying via pxdesign.runner.inference.\nstderr: {result.stderr[-1000:]}"
    )
    result = _run_pxdesign_command(cmd_alt, env, timeout_s)
    if result is None:
        return False
    if result.returncode != 0:
        logger.error(f"PXDesign failed:\nstdout: {result.stdout[-1000:]}\nstderr: {result.stderr[-1000:]}")
        return False
    return True


def _pxdesign_env(gpu):
    """Return a copy of the environment that pins the child to device ``gpu``."""
    env = os.environ.copy()
    if gpu is None:
        return env

    index = int(gpu)
    current = env.get('CUDA_VISIBLE_DEVICES')
    if current is None:
        env['CUDA_VISIBLE_DEVICES'] = str(index)
        return env

    # A mask is already in place, so ``gpu`` indexes into that mask (this is
    # how ``cuda:<gpu>`` is numbered in the parent process as well).
    visible = [tok.strip() for tok in current.split(',') if tok.strip()]
    if 0 <= index < len(visible):
        env['CUDA_VISIBLE_DEVICES'] = visible[index]
    else:
        logger.warning(
            f"gpu={index} is out of range for CUDA_VISIBLE_DEVICES={current!r}; "
            f"leaving the environment unchanged"
        )
    return env


def _run_pxdesign_command(command, env, timeout_s):
    """Run one PXDesign command; return the CompletedProcess or None on timeout/launch error."""
    try:
        return subprocess.run(command, capture_output=True, text=True, env=env,
                              timeout=timeout_s)
    except subprocess.TimeoutExpired:
        logger.error(f"PXDesign timed out after {timeout_s}s: {' '.join(command)}")
    except OSError as exc:
        logger.error(f"Could not launch PXDesign ({' '.join(command)}): {exc}")
    return None
|
|
|
|
def collect_pdbs(outdir):
    """Collect PDB/CIF files from PXDesign output.

    Searches ``outdir`` recursively for ``*.pdb`` and ``*.cif`` files. If any
    of the hits have ``sample``, ``design`` or ``rank`` in their file name
    (case-insensitive), only those are returned; otherwise every hit is
    returned.

    Args:
        outdir: Directory to search. It is treated as a literal path, so glob
            metacharacters in it (``[``, ``]``, ``*``, ``?``) are escaped.

    Returns:
        A sorted list of matching file paths (possibly empty).
    """
    # Local import: the file only imports ``glob.glob`` at module level.
    from glob import escape

    # Without escaping, a directory such as ``results/run[1]`` would be read
    # as a character class and silently match nothing.
    root = escape(outdir)

    found = []
    for ext in ('*.pdb', '*.cif'):
        found.extend(glob(os.path.join(root, '**/' + ext), recursive=True))
    found.sort()

    keywords = ('sample', 'design', 'rank')
    preferred = [
        path for path in found
        if any(word in os.path.basename(path).lower() for word in keywords)
    ]
    return preferred or found
|
|
|
|
def smc_particle_filter(args):
    """Run SMC reranking with PXDesign.

    For each of ``args.n_rounds`` rounds a batch of ``args.n_particles``
    structures is generated with :func:`run_pxdesign_batch`, every structure
    is scored with ``QThetaPXDesignGuidance.score_design`` and the scored
    designs are pooled. The pool is then sorted by ``margin`` (descending) and
    written to ``<outdir>/smc_scores.json`` and ``<outdir>/smc_summary.json``;
    the top 20 structures are copied to ``<outdir>/best_designs``.

    If generation fails for a round, structures found under
    ``args.designs_dir`` (when present and non-empty) are scored instead. A
    structure file is never scored twice, so repeated fallbacks do not inflate
    the pool with duplicates.

    Args:
        args: Namespace providing ``outdir``, ``qtheta_checkpoint``,
            ``ref_holo``, ``ref_apo``, ``ref_chain``, ``gpu``, ``n_rounds``,
            ``n_particles``, ``N_step``, ``input`` and, optionally,
            ``designs_dir``.

    Returns:
        None. All results are written below ``args.outdir``.
    """
    # The process works from the repository root from here on, so relative
    # paths in ``args`` are resolved against it.
    os.chdir(_ALLO_ROOT)

    # Project-local import, deferred until the function is actually called.
    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    guidance = QThetaPXDesignGuidance(
        checkpoint=args.qtheta_checkpoint,
        ref_holo=args.ref_holo,
        ref_apo=args.ref_apo,
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    guidance._lazy_init()

    all_designs = []
    round_summaries = []
    scored_paths = set()  # absolute paths already handed to the scorer

    for round_idx in range(args.n_rounds):
        round_dir = os.path.join(outdir, f'round_{round_idx}')
        os.makedirs(round_dir, exist_ok=True)

        logger.info(f"\n{'='*60}")
        logger.info(f"SMC Round {round_idx + 1}/{args.n_rounds}")
        logger.info(f"{'='*60}")

        gen_dir = os.path.join(round_dir, 'generated')
        success = run_pxdesign_batch(
            input_json=args.input,
            outdir=gen_dir,
            n_sample=args.n_particles,
            n_step=args.N_step,
            gpu=args.gpu,
        )

        if success:
            pdbs = collect_pdbs(gen_dir)
        else:
            logger.warning(f"Round {round_idx} generation failed. "
                           f"Checking for existing outputs...")
            fallback_dir = getattr(args, 'designs_dir', None)
            pdbs = collect_pdbs(fallback_dir) if fallback_dir else []

        if not pdbs:
            logger.warning(f"No PDBs found in round {round_idx}")
            continue

        # The fallback directory is the same every round; skip anything that
        # was already scored so statistics are not computed over duplicates.
        new_pdbs = [p for p in pdbs if os.path.abspath(p) not in scored_paths]
        if not new_pdbs:
            logger.warning(f"Round {round_idx}: all {len(pdbs)} structures "
                           f"were already scored; skipping")
            continue

        logger.info(f"Scoring {len(new_pdbs)} particles...")
        round_results = []
        for pdb_path in new_pdbs:
            scored_paths.add(os.path.abspath(pdb_path))
            result = guidance.score_design(pdb_path)
            if result is None:
                continue
            result['pdb_path'] = pdb_path
            # splitext strips only the trailing extension, unlike a chained
            # str.replace which would also eat '.pdb'/'.cif' inside the name.
            result['design_id'] = os.path.splitext(os.path.basename(pdb_path))[0]
            result['round'] = round_idx
            round_results.append(result)

        if not round_results:
            continue

        margins = np.array([r['margin'] for r in round_results])
        round_summary = {
            'round': round_idx,
            'n_particles': len(round_results),
            'margin_mean': float(margins.mean()),
            'margin_std': float(margins.std()),
            'margin_max': float(margins.max()),
            'frac_positive': float((margins > 0).mean()),
        }
        round_summaries.append(round_summary)

        logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
                    f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}")

        all_designs.extend(round_results)

    if not all_designs:
        logger.warning("No designs were scored; nothing to rank or write.")
        return

    all_designs.sort(key=lambda d: d['margin'], reverse=True)
    all_margins = np.array([d['margin'] for d in all_designs])
    holo_scores = np.array([d['q_holo'] for d in all_designs])

    bok = _best_of_k_success(all_margins, (1, 2, 5, 10))

    summary = {
        'method': 'PXDesign + SMC',
        'n_rounds': args.n_rounds,
        'n_particles_per_round': args.n_particles,
        'total_designs': len(all_designs),
        'margin_mean': float(all_margins.mean()),
        'margin_std': float(all_margins.std()),
        'margin_max': float(all_margins.max()),
        'frac_positive': float((all_margins > 0).mean()),
        'q_holo_mean': float(holo_scores.mean()),
        'q_apo_mean': float(np.mean([d['q_apo'] for d in all_designs])),
        'best_of_k': {str(k): v for k, v in bok.items()},
        'round_summaries': round_summaries,
        'top5': all_designs[:5],
    }

    # NOTE(review): the value types returned by score_design are not visible
    # here; the default hook keeps numpy/torch scalars from aborting the dump
    # after all generation and scoring work has been done.
    with open(os.path.join(outdir, 'smc_scores.json'), 'w') as f:
        json.dump(all_designs, f, indent=2, default=_json_default)
    with open(os.path.join(outdir, 'smc_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    best_dir = os.path.join(outdir, 'best_designs')
    os.makedirs(best_dir, exist_ok=True)
    for rank, design in enumerate(all_designs[:20]):
        src = design['pdb_path']
        if not os.path.exists(src):
            continue
        # Keep the source extension so a CIF file is not mislabeled as PDB.
        ext = os.path.splitext(src)[1] or '.pdb'
        dest = os.path.join(best_dir, f'rank_{rank:02d}_{design["design_id"]}{ext}')
        shutil.copy2(src, dest)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + SMC Results ({len(all_designs)} total designs)")
    logger.info(f"  Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
    logger.info(f"  Max margin: {all_margins.max():.3f}")
    logger.info(f"  Fraction S > 0: {(all_margins > 0).mean():.1%}")
    logger.info(f"  Q(holo) mean: {holo_scores.mean():.3f}")
    logger.info(f"  Best-of-K:")
    for k, v in sorted(bok.items()):
        logger.info(f"    K={k:3d}: {v:.3f}")
    logger.info(f"{'='*60}")


def _best_of_k_success(margins, ks):
    """Return {K: P(max margin > 0 among K designs drawn without replacement)}.

    This is the exact value that repeated random subsampling would estimate:
    one minus the probability that all ``min(K, n)`` draws are non-positive.
    """
    n_total = len(margins)
    n_nonpos = n_total - int((margins > 0).sum())
    probabilities = {}
    for k in ks:
        draws = min(k, n_total)
        p_all_nonpos = 1.0
        for i in range(draws):
            p_all_nonpos *= max(n_nonpos - i, 0) / (n_total - i)
        probabilities[k] = 1.0 - p_all_nonpos
    return probabilities


def _json_default(obj):
    """Convert numpy/torch values to plain Python for ``json.dump``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
|
|
|
def main(argv=None):
    """Parse command-line options and run the SMC reranking pipeline.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the behaviour when the
            file is run as a script.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments,
            including ``--n_particles`` / ``--n_rounds`` below 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Zero or negative counts would make every round a no-op; reject early.
    if args.n_particles < 1:
        parser.error('--n_particles must be >= 1')
    if args.n_rounds < 1:
        parser.error('--n_rounds must be >= 1')

    smc_particle_filter(args)


def _build_arg_parser():
    """Build the argument parser for the PXDesign + SMC reranking script."""
    parser = argparse.ArgumentParser(description='PXDesign + SMC Reranking')
    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
                        help='PXDesign input JSON')
    parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/',
                        help='Existing PXDesign outputs (fallback if generation fails)')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt',
                        help='Q_theta checkpoint file')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb',
                        help='Reference holo structure')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb',
                        help='Reference apo structure')
    parser.add_argument('--ref_chain', default='A',
                        help='Chain ID used in the reference structures')
    parser.add_argument('--n_particles', type=int, default=16,
                        help='Particles per round')
    parser.add_argument('--n_rounds', type=int, default=4,
                        help='Number of SMC rounds')
    parser.add_argument('--N_step', type=int, default=400,
                        help='Value forwarded to PXDesign as --N_step')
    parser.add_argument('--gpu', type=int, default=0,
                        help='CUDA device index')
    parser.add_argument('--outdir', default='results/pxdesign_smc',
                        help='Directory for scores, summary and best designs')
    return parser


if __name__ == '__main__':
    main()
|
|