""" PXDesign + SMC Reranking. Post-hoc Sequential Monte Carlo: generate multiple batches of vanilla PXDesign binders, score with Q_theta, and rank by selectivity margin. No modification to the PXDesign diffusion process — pure generate-score-rank pipeline. This is the simplest Q_theta integration strategy: generate a large pool of candidates and select the best ones by selectivity score. Usage: python code/scripts/pxdesign_guidance/smc_pxdesign.py \ --input experiments/pxdesign_cam/output/cam_binder.json \ --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \ --ref_holo data/pdbs/cam_holo/3CLN.pdb \ --ref_apo data/pdbs/cam_apo/1CFD.pdb \ --n_particles 16 --n_rounds 4 \ --gpu 0 """ import os import sys import argparse import json import logging import shutil import subprocess from glob import glob import numpy as np import torch logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') logger = logging.getLogger(__name__) _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) _ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..')) _ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..')) if _ALLO_CODE_DIR not in sys.path: sys.path.insert(0, _ALLO_CODE_DIR) def run_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu): """Run vanilla PXDesign via CLI subprocess.""" pxdesign_python = 'python' # Use pxdesign CLI cmd = [ pxdesign_python, '-m', 'pxdesign.runner.cli', 'infer', '-o', outdir, '-i', input_json, '--dtype', 'bf16', '--N_sample', str(n_sample), '--N_step', str(n_step), ] env = os.environ.copy() # Inherit CUDA_VISIBLE_DEVICES from parent logger.info(f"Running PXDesign: {n_sample} samples -> {outdir}") result = subprocess.run(cmd, capture_output=True, text=True, env=env, timeout=7200) if result.returncode != 0: # Try alternative invocation via module cmd_alt = [ pxdesign_python, '-m', 'pxdesign.runner.inference', '--dump_dir', outdir, '--input', input_json, '--dtype', 'bf16', '--N_sample', str(n_sample), '--N_step', str(n_step), ] result = subprocess.run(cmd_alt, capture_output=True, text=True, env=env, timeout=7200) if result.returncode != 0: logger.error(f"PXDesign failed:\nstdout: {result.stdout[-1000:]}\nstderr: {result.stderr[-1000:]}") return False return True def collect_pdbs(outdir): """Collect PDB/CIF files from PXDesign output.""" pdbs = [] for ext in ('*.pdb', '*.cif'): pdbs.extend(glob(os.path.join(outdir, '**/' + ext), recursive=True)) pdbs = sorted(pdbs) filtered = [p for p in pdbs if 'sample' in os.path.basename(p).lower() or 'design' in os.path.basename(p).lower() or 'rank' in os.path.basename(p).lower()] return filtered if filtered else pdbs def smc_particle_filter(args): """Run SMC reranking with PXDesign.""" os.chdir(_ALLO_ROOT) from qtheta_pxdesign import QThetaPXDesignGuidance outdir = args.outdir os.makedirs(outdir, exist_ok=True) # Initialize scorer guidance = QThetaPXDesignGuidance( checkpoint=args.qtheta_checkpoint, ref_holo=args.ref_holo, ref_apo=args.ref_apo, ref_chain=args.ref_chain, device=f'cuda:{args.gpu}', ) guidance._lazy_init() all_designs = [] round_summaries = [] for round_idx in range(args.n_rounds): round_dir = os.path.join(outdir, f'round_{round_idx}') os.makedirs(round_dir, exist_ok=True) logger.info(f"\n{'='*60}") logger.info(f"SMC Round {round_idx + 1}/{args.n_rounds}") logger.info(f"{'='*60}") # Generate particles via vanilla PXDesign gen_dir = os.path.join(round_dir, 'generated') success = run_pxdesign_batch( input_json=args.input, outdir=gen_dir, n_sample=args.n_particles, n_step=args.N_step, gpu=args.gpu, ) if not success: # If subprocess fails, try using existing PXDesign outputs logger.warning(f"Round {round_idx} generation failed. " f"Checking for existing outputs...") pdbs = collect_pdbs(args.designs_dir) if hasattr(args, 'designs_dir') else [] if not pdbs: continue else: pdbs = collect_pdbs(gen_dir) if not pdbs: logger.warning(f"No PDBs found in round {round_idx}") continue # Score all particles logger.info(f"Scoring {len(pdbs)} particles...") round_results = [] for pdb_path in pdbs: result = guidance.score_design(pdb_path) if result is not None: result['pdb_path'] = pdb_path result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '') result['round'] = round_idx round_results.append(result) if not round_results: continue margins = np.array([r['margin'] for r in round_results]) round_summary = { 'round': round_idx, 'n_particles': len(round_results), 'margin_mean': float(margins.mean()), 'margin_std': float(margins.std()), 'margin_max': float(margins.max()), 'frac_positive': float((margins > 0).mean()), } round_summaries.append(round_summary) logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, " f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}") all_designs.extend(round_results) # Final ranking and summary if all_designs: all_designs.sort(key=lambda x: x['margin'], reverse=True) all_margins = np.array([d['margin'] for d in all_designs]) holo_scores = np.array([d['q_holo'] for d in all_designs]) # Best-of-K bok = {} for K in [1, 2, 5, 10]: n_trials = 2000 n_avail = len(all_margins) successes = sum( 1 for _ in range(n_trials) if all_margins[np.random.choice(n_avail, min(K, n_avail), replace=False)].max() > 0 ) bok[K] = successes / n_trials summary = { 'method': 'PXDesign + SMC', 'n_rounds': args.n_rounds, 'n_particles_per_round': args.n_particles, 'total_designs': len(all_designs), 'margin_mean': float(all_margins.mean()), 'margin_std': float(all_margins.std()), 'margin_max': float(all_margins.max()), 'frac_positive': float((all_margins > 0).mean()), 'q_holo_mean': float(holo_scores.mean()), 'q_apo_mean': float(np.mean([d['q_apo'] for d in all_designs])), 'best_of_k': {str(k): v for k, v in bok.items()}, 'round_summaries': round_summaries, 'top5': all_designs[:5], } with open(os.path.join(outdir, 'smc_scores.json'), 'w') as f: json.dump(all_designs, f, indent=2) with open(os.path.join(outdir, 'smc_summary.json'), 'w') as f: json.dump(summary, f, indent=2) # Copy best designs best_dir = os.path.join(outdir, 'best_designs') os.makedirs(best_dir, exist_ok=True) for i, d in enumerate(all_designs[:20]): if os.path.exists(d['pdb_path']): dest = os.path.join(best_dir, f'rank_{i:02d}_{d["design_id"]}.pdb') shutil.copy2(d['pdb_path'], dest) logger.info(f"\n{'='*60}") logger.info(f"PXDesign + SMC Results ({len(all_designs)} total designs)") logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}") logger.info(f" Max margin: {all_margins.max():.3f}") logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}") logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}") logger.info(f" Best-of-K:") for k, v in sorted(bok.items()): logger.info(f" K={k:3d}: {v:.3f}") logger.info(f"{'='*60}") def main(): parser = argparse.ArgumentParser(description='PXDesign + SMC Reranking') parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json', help='PXDesign input JSON') parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/', help='Existing PXDesign outputs (fallback if generation fails)') parser.add_argument('--qtheta_checkpoint', default='results/checkpoints_cam_v3/best_phase2.pt') parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb') parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb') parser.add_argument('--ref_chain', default='A') parser.add_argument('--n_particles', type=int, default=16, help='Particles per round') parser.add_argument('--n_rounds', type=int, default=4, help='Number of SMC rounds') parser.add_argument('--N_step', type=int, default=400) parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--outdir', default='results/pxdesign_smc') args = parser.parse_args() smc_particle_filter(args) if __name__ == '__main__': main()