r"""PXDesign + Twisted Diffusion Sampling (TDS).

Multi-round particle filtering with guided PXDesign. Each round r:

  1. Generate N particles with PXDesign under Q_theta classifier guidance
     (the "twisted proposal"); generation runs as a subprocess.
  2. Score each particle with the Q_theta selectivity margin.
  3. Compute importance log-weights ``margin_i / temperature`` and report the
     effective sample size (ESS) of the round as a diagnostic.
  4. Add the scored particles to a pool that is ranked by margin at the end.

PXDesign cannot easily re-denoise a perturbed particle, so there is no literal
resample-and-perturb step: every round draws fresh samples, and the guidance
scale is ramped with the round index,
``scale_r = guidance_scale * (1 + 0.2 * r)``. The final designs are the
highest-margin particles pooled across all rounds.

Usage:
    python code/scripts/pxdesign_guidance/tds_pxdesign.py \
        --input experiments/pxdesign_cam/output/cam_binder.json \
        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
        --n_particles 16 --n_rounds 4 \
        --guidance_scale 0.5 \
        --gpu 0
"""

import os
import sys
import argparse
import json
import logging
import shutil
import subprocess
from glob import glob

import numpy as np
import torch

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)

# Wall-clock limit for one PXDesign subprocess (seconds).
_SUBPROCESS_TIMEOUT_S = 7200
# Relative guidance-scale increase per round: scale_r = base * (1 + 0.2 * r).
_GUIDANCE_RAMP_PER_ROUND = 0.2
# Best-of-K Monte-Carlo settings for the final summary.
_BEST_OF_K_VALUES = (1, 2, 5, 10)
_BEST_OF_K_TRIALS = 2000
# How many top-ranked structures are copied into best_designs/.
_N_BEST_TO_COPY = 20
# Substrings that mark a PXDesign output file as a design structure.
_DESIGN_NAME_KEYWORDS = ('sample', 'design', 'rank')


def compute_ess(log_weights):
    """Compute the effective sample size (ESS) from unnormalized log-weights.

    Args:
        log_weights: 1-D array-like of log-weights (need not be normalized).

    Returns:
        ``1 / sum(w_i ** 2)`` for the normalized weights ``w``; this lies in
        ``[1, len(log_weights)]``.
    """
    log_weights = np.asarray(log_weights, dtype=float)
    # Shift by the max before exponentiating to avoid overflow.
    weights = np.exp(log_weights - log_weights.max())
    weights = weights / weights.sum()
    return float(1.0 / np.square(weights).sum())


def _resolve(path):
    """Resolve *path* against the project root (absolute paths pass through)."""
    return os.path.join(_ALLO_ROOT, path)


def _run_pxdesign_subprocess(cmd, label, timeout=_SUBPROCESS_TIMEOUT_S):
    """Run a PXDesign command line.

    Args:
        cmd: argument list passed to ``subprocess.run`` (no shell).
        label: human-readable name used in error messages.
        timeout: seconds before the child is killed.

    Returns:
        True if the command exited with status 0; False if it failed, timed
        out, or could not be started (errors are logged, not raised).
    """
    # Copy of the parent environment so CUDA_VISIBLE_DEVICES etc. are inherited.
    env = os.environ.copy()
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, env=env,
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        logger.error(f"{label} timed out after {timeout} s")
        return False
    except OSError as err:
        logger.error(f"{label} could not be started: {err}")
        return False
    if result.returncode != 0:
        logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
        return False
    return True


def run_guided_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu, guidance_args):
    """Run guided PXDesign (``guided_pxdesign.py``) as a subprocess.

    Args:
        input_json: PXDesign input JSON path.
        outdir: directory the subprocess writes structures into.
        n_sample: number of samples to generate.
        n_step: number of diffusion steps.
        gpu: GPU index forwarded as ``--gpu``.
        guidance_args: dict with keys ``checkpoint``, ``ref_holo``,
            ``ref_apo``, ``ref_chain``, ``guidance_scale`` and optionally
            ``guidance_start`` (default 0.8) / ``guidance_end`` (default 0.1).
            Paths should be absolute: the child inherits this process's cwd.

    Returns:
        True on success, False on failure or timeout.
    """
    pxdesign_python = 'python'
    cmd = [
        pxdesign_python, os.path.join(_SCRIPT_DIR, 'guided_pxdesign.py'),
        '--input', input_json,
        '--qtheta_checkpoint', guidance_args['checkpoint'],
        '--ref_holo', guidance_args['ref_holo'],
        '--ref_apo', guidance_args['ref_apo'],
        '--ref_chain', guidance_args['ref_chain'],
        '--guidance_scale', str(guidance_args['guidance_scale']),
        '--guidance_start', str(guidance_args.get('guidance_start', 0.8)),
        '--guidance_end', str(guidance_args.get('guidance_end', 0.1)),
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
        '--gpu', str(gpu),
        '--outdir', outdir,
    ]
    logger.info(f"Running guided PXDesign: {n_sample} samples -> {outdir}")
    return _run_pxdesign_subprocess(cmd, 'Guided PXDesign')


def run_vanilla_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
    """Run vanilla PXDesign (no guidance) as a subprocess.

    ``gpu`` is accepted for signature parity with the guided runner but is
    not forwarded; device selection comes from the inherited environment.

    Returns:
        True on success, False on failure or timeout.
    """
    pxdesign_env = 'python'
    cmd = [
        pxdesign_env, '-m', 'pxdesign.runner.inference',
        '--dump_dir', outdir,
        '--input', input_json,
        '--dtype', 'bf16',
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
    ]
    logger.info(f"Running vanilla PXDesign: {n_sample} samples -> {outdir}")
    return _run_pxdesign_subprocess(cmd, 'Vanilla PXDesign')


def collect_pdbs(outdir):
    """Collect PDB/CIF paths (recursively, sorted) from a PXDesign output dir.

    Files whose base name contains ``sample``, ``design`` or ``rank``
    (case-insensitive) are preferred; if none match, every structure file
    found is returned.
    """
    pdbs = []
    for ext in ('*.pdb', '*.cif'):
        pdbs.extend(glob(os.path.join(outdir, '**', ext), recursive=True))
    pdbs.sort()
    filtered = [
        p for p in pdbs
        if any(key in os.path.basename(p).lower() for key in _DESIGN_NAME_KEYWORDS)
    ]
    return filtered if filtered else pdbs


def _design_id(path):
    """Return the file name of *path* without directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _guidance_scale_for_round(base_scale, round_idx):
    """Guidance scale for a round: ``base * (1 + 0.2 * round_idx)``."""
    return base_scale * (1.0 + _GUIDANCE_RAMP_PER_ROUND * round_idx)


def _build_guidance_args(args):
    """Build the guidance-argument dict with project-root-resolved paths."""
    return {
        'checkpoint': _resolve(args.qtheta_checkpoint),
        'ref_holo': _resolve(args.ref_holo),
        'ref_apo': _resolve(args.ref_apo),
        'ref_chain': args.ref_chain,
        'guidance_scale': args.guidance_scale,
        'guidance_start': args.guidance_start,
        'guidance_end': args.guidance_end,
    }


def _score_particles(guidance, pdbs, round_idx):
    """Score each structure; annotate results with path, id and round.

    Structures for which ``guidance.score_design`` returns None are skipped.
    """
    results = []
    for pdb_path in pdbs:
        result = guidance.score_design(pdb_path)
        if result is None:
            continue
        result['pdb_path'] = pdb_path
        result['design_id'] = _design_id(pdb_path)
        result['round'] = round_idx
        results.append(result)
    return results


def _estimate_best_of_k(margins, ks=_BEST_OF_K_VALUES, n_trials=_BEST_OF_K_TRIALS, rng=None):
    """Monte-Carlo estimate of P(max margin over K random designs > 0).

    K is capped at the number of available designs. An empty pool yields 0.0
    for every K.

    Returns:
        ``{K: probability}`` for each K in *ks*.
    """
    rng = np.random.default_rng() if rng is None else rng
    margins = np.asarray(margins, dtype=float)
    n_avail = len(margins)
    if n_avail == 0:
        return {k: 0.0 for k in ks}
    bok = {}
    for k in ks:
        k_eff = min(k, n_avail)
        successes = sum(
            1 for _ in range(n_trials)
            if margins[rng.choice(n_avail, k_eff, replace=False)].max() > 0
        )
        bok[k] = successes / n_trials
    return bok


def _json_default(obj):
    """``json.dump`` fallback converting numpy/torch values to Python types."""
    if isinstance(obj, (np.generic, np.ndarray)):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_results(outdir, all_designs, summary):
    """Write score/summary JSON and copy the top designs to ``best_designs/``."""
    with open(os.path.join(outdir, 'tds_scores.json'), 'w') as f:
        json.dump(all_designs, f, indent=2, default=_json_default)
    with open(os.path.join(outdir, 'tds_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    best_dir = os.path.join(outdir, 'best_designs')
    os.makedirs(best_dir, exist_ok=True)
    for rank, design in enumerate(all_designs[:_N_BEST_TO_COPY]):
        src = design['pdb_path']
        if not os.path.exists(src):
            continue
        # Keep the source extension so mmCIF files are not mislabelled as .pdb.
        ext = os.path.splitext(src)[1] or '.pdb'
        dest = os.path.join(best_dir, f'rank_{rank:02d}_{design["design_id"]}{ext}')
        shutil.copy2(src, dest)


def _log_final_summary(n_designs, all_margins, holo_scores, bok):
    """Log the aggregate results of the TDS run."""
    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + TDS Results ({n_designs} total designs)")
    logger.info(f"  Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
    logger.info(f"  Max margin: {all_margins.max():.3f}")
    logger.info(f"  Fraction S > 0: {(all_margins > 0).mean():.1%}")
    logger.info(f"  Q(holo) mean: {holo_scores.mean():.3f}")
    logger.info("  Best-of-K:")
    for k, v in sorted(bok.items()):
        logger.info(f"    K={k:3d}: {v:.3f}")
    logger.info(f"{'='*60}")


def tds_particle_filter(args):
    """Run multi-round TDS with guided PXDesign and write ranked results.

    Args:
        args: parsed CLI namespace (see ``main``). Relative paths are resolved
            against the project root. An optional ``seed`` attribute seeds the
            best-of-K Monte-Carlo estimate.

    Writes under ``args.outdir``: ``round_<r>/generated/`` per round,
    ``tds_scores.json`` (all scored designs, best first),
    ``tds_summary.json``, and ``best_designs/`` with the top structures.
    """
    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = _resolve(args.outdir)
    os.makedirs(outdir, exist_ok=True)

    # Absolute paths so the in-process scorer and the guidance subprocess read
    # the same files regardless of the current working directory.
    guidance_args = _build_guidance_args(args)

    # Initialize scorer
    guidance = QThetaPXDesignGuidance(
        checkpoint=guidance_args['checkpoint'],
        ref_holo=guidance_args['ref_holo'],
        ref_apo=guidance_args['ref_apo'],
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    guidance._lazy_init()

    input_json = _resolve(args.input)
    all_designs = []
    round_summaries = []

    for round_idx in range(args.n_rounds):
        round_dir = os.path.join(outdir, f'round_{round_idx}')
        os.makedirs(round_dir, exist_ok=True)

        logger.info(f"\n{'='*60}")
        logger.info(f"TDS Round {round_idx + 1}/{args.n_rounds}")
        logger.info(f"{'='*60}")

        # The scale is a function of the round index (not bumped at the end of
        # a successful round), so a failed round does not stall the schedule.
        scale = _guidance_scale_for_round(args.guidance_scale, round_idx)
        guidance_args['guidance_scale'] = scale
        logger.info(f"Guidance scale for round {round_idx}: {scale:.2f}")

        # Generate particles via guided PXDesign
        gen_dir = os.path.join(round_dir, 'generated')
        success = run_guided_pxdesign_batch(
            input_json=input_json,
            outdir=gen_dir,
            n_sample=args.n_particles,
            n_step=args.N_step,
            gpu=args.gpu,
            guidance_args=guidance_args,
        )
        if not success:
            logger.warning(f"Round {round_idx} generation failed, skipping")
            continue

        pdbs = collect_pdbs(gen_dir)
        if not pdbs:
            logger.warning(f"No PDBs found in round {round_idx}")
            continue

        logger.info(f"Scoring {len(pdbs)} particles...")
        round_results = _score_particles(guidance, pdbs, round_idx)
        if not round_results:
            logger.warning(f"No scorable designs in round {round_idx}")
            continue

        margins = np.array([r['margin'] for r in round_results])
        # Importance log-weights; they only feed the ESS diagnostic, since
        # the final selection ranks the pooled designs by margin.
        ess = compute_ess(margins / args.temperature)

        round_summary = {
            'round': round_idx,
            'n_particles': len(round_results),
            'guidance_scale': float(scale),
            'margin_mean': float(margins.mean()),
            'margin_std': float(margins.std()),
            'margin_max': float(margins.max()),
            'frac_positive': float((margins > 0).mean()),
            'ess': float(ess),
        }
        round_summaries.append(round_summary)
        logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
                    f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}, "
                    f"ESS={ess:.1f}/{len(round_results)}")

        all_designs.extend(round_results)

    if not all_designs:
        logger.warning("No designs were generated or scored in any round; nothing written")
        return

    all_designs.sort(key=lambda d: d['margin'], reverse=True)
    all_margins = np.array([d['margin'] for d in all_designs])
    holo_scores = np.array([d['q_holo'] for d in all_designs])

    rng = np.random.default_rng(getattr(args, 'seed', None))
    bok = _estimate_best_of_k(all_margins, rng=rng)

    summary = {
        'method': 'PXDesign + TDS',
        'n_rounds': args.n_rounds,
        'n_particles_per_round': args.n_particles,
        'total_designs': len(all_designs),
        'guidance_scale': args.guidance_scale,
        'temperature': args.temperature,
        'margin_mean': float(all_margins.mean()),
        'margin_std': float(all_margins.std()),
        'margin_max': float(all_margins.max()),
        'frac_positive': float((all_margins > 0).mean()),
        'q_holo_mean': float(holo_scores.mean()),
        'best_of_k': {str(k): v for k, v in bok.items()},
        'round_summaries': round_summaries,
        'top5': all_designs[:5],
    }

    _write_results(outdir, all_designs, summary)
    _log_final_summary(len(all_designs), all_margins, holo_scores, bok)


def main():
    """Parse CLI arguments, validate them, and run the TDS particle filter."""
    parser = argparse.ArgumentParser(description='PXDesign + TDS')
    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json')
    parser.add_argument('--qtheta_checkpoint', default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_particles', type=int, default=16, help='Particles per round')
    parser.add_argument('--n_rounds', type=int, default=4, help='Number of TDS rounds')
    parser.add_argument('--guidance_scale', type=float, default=0.5,
                        help='Initial guidance scale')
    parser.add_argument('--guidance_start', type=float, default=0.8)
    parser.add_argument('--guidance_end', type=float, default=0.1)
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='Temperature for importance weights')
    parser.add_argument('--N_step', type=int, default=400)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=None,
                        help='Seed for the best-of-K Monte-Carlo estimate')
    parser.add_argument('--outdir', default='results/pxdesign_tds')
    args = parser.parse_args()

    # Reject values that would make the importance weights or rounds meaningless.
    if args.temperature <= 0:
        parser.error('--temperature must be > 0')
    if args.n_particles < 1:
        parser.error('--n_particles must be >= 1')
    if args.n_rounds < 1:
        parser.error('--n_rounds must be >= 1')

    tds_particle_filter(args)


if __name__ == '__main__':
    main()