| """ |
| PXDesign + Twisted Diffusion Sampling (TDS). |
| |
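Multi-round particle filtering with guided PXDesign:
  Round r:
    1. Generate N particles via PXDesign with Q_theta classifier guidance
    2. Score each particle with Q_theta selectivity margin
    3. Compute importance weights w_i ~ exp(margin_i / temperature)
    4. Resample particles (keep best, discard worst)
    5. Add perturbation noise for diversity

This combines in-process guidance (the "twisted proposal") with post-hoc
importance-weighted resampling for highest-quality designs.

Implementation status: steps 4-5 are not yet carried out. Each round re-runs
PXDesign from the same input JSON, so the importance weights of step 3 are
only used to report the effective sample size (ESS). Between rounds the
guidance scale grows by 20% of its initial value, and at the end the designs
from all rounds are pooled and ranked by margin.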
| |
| Usage: |
| python code/scripts/pxdesign_guidance/tds_pxdesign.py \ |
| --input experiments/pxdesign_cam/output/cam_binder.json \ |
| --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \ |
| --ref_holo data/pdbs/cam_holo/3CLN.pdb \ |
| --ref_apo data/pdbs/cam_apo/1CFD.pdb \ |
| --n_particles 16 --n_rounds 4 \ |
| --guidance_scale 0.5 \ |
| --gpu 0 |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import json |
| import logging |
| import shutil |
| import subprocess |
| from glob import glob |
|
|
| import numpy as np |
| import torch |
|
|
# Module-wide logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Locations derived from this file's own path:
#   _SCRIPT_DIR    - directory containing this script (and guided_pxdesign.py)
#   _ALLO_CODE_DIR - two directory levels above _SCRIPT_DIR
#   _ALLO_ROOT     - parent of _ALLO_CODE_DIR; relative CLI paths are joined onto it
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, os.pardir))

# Make packages under _ALLO_CODE_DIR importable; prepend it only once.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
|
|
|
|
def compute_ess(log_weights):
    """Compute the effective sample size (ESS) of a set of importance weights.

    The weights are given in log space and need not be normalised. They are
    shifted by their maximum before exponentiation so that large log-weights
    cannot overflow, normalised to sum to one, and the ESS is
    ``1 / sum(w_i ** 2)``. It ranges from 1 (one particle carries all the
    weight) to ``len(log_weights)`` (uniform weights).

    Args:
        log_weights: Array-like (NumPy array, list or tuple) of unnormalised
            log-weights; multi-dimensional input is flattened.

    Returns:
        The effective sample size as a Python float; ``0.0`` for empty input.
    """
    log_w = np.asarray(log_weights, dtype=float).ravel()
    if log_w.size == 0:
        return 0.0
    # Shift so the largest weight becomes exp(0) == 1 (overflow-safe).
    weights = np.exp(log_w - log_w.max())
    weights /= weights.sum()
    return float(1.0 / np.square(weights).sum())
|
|
|
|
def run_guided_pxdesign_batch(input_json, outdir, n_sample, n_step,
                              gpu, guidance_args, timeout=7200):
    """Run guided PXDesign (``guided_pxdesign.py`` next to this file) as a subprocess.

    Args:
        input_json: Path to the PXDesign input JSON.
        outdir: Directory passed as ``--outdir`` for the generated structures.
        n_sample: Number of samples (particles) to generate (``--N_sample``).
        n_step: Number of diffusion steps (``--N_step``).
        gpu: GPU index forwarded as ``--gpu``.
        guidance_args: Dict with keys ``checkpoint``, ``ref_holo``, ``ref_apo``,
            ``ref_chain`` and ``guidance_scale``; ``guidance_start`` and
            ``guidance_end`` are optional (defaults 0.8 and 0.1).
        timeout: Wall-clock limit for the subprocess, in seconds.

    Returns:
        True if the subprocess exited with status 0; False if it exited with a
        non-zero status or exceeded ``timeout`` (the child is killed in that case).
    """
    pxdesign_python = 'python'
    cmd = [
        pxdesign_python,
        os.path.join(_SCRIPT_DIR, 'guided_pxdesign.py'),
        '--input', input_json,
        '--qtheta_checkpoint', guidance_args['checkpoint'],
        '--ref_holo', guidance_args['ref_holo'],
        '--ref_apo', guidance_args['ref_apo'],
        '--ref_chain', guidance_args['ref_chain'],
        '--guidance_scale', str(guidance_args['guidance_scale']),
        '--guidance_start', str(guidance_args.get('guidance_start', 0.8)),
        '--guidance_end', str(guidance_args.get('guidance_end', 0.1)),
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
        '--gpu', str(gpu),
        '--outdir', outdir,
    ]

    # The child inherits an unmodified copy of this process's environment.
    env = os.environ.copy()

    logger.info(f"Running guided PXDesign: {n_sample} samples -> {outdir}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, env=env,
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child; report the round as failed
        # instead of aborting the caller's whole multi-round run.
        logger.error(f"PXDesign timed out after {timeout}s (outdir: {outdir})")
        return False

    if result.returncode != 0:
        logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
        return False
    return True
|
|
|
|
def run_vanilla_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu,
                               timeout=7200):
    """Run vanilla PXDesign (no guidance) as a subprocess.

    Invokes ``python -m pxdesign.runner.inference`` with bf16 precision.

    Args:
        input_json: Path to the PXDesign input JSON.
        outdir: Directory passed as ``--dump_dir`` for the generated structures.
        n_sample: Number of samples to generate (``--N_sample``).
        n_step: Number of diffusion steps (``--N_step``).
        gpu: GPU index for the child. It is interpreted relative to any device
            list already in ``CUDA_VISIBLE_DEVICES`` (the same convention as
            ``cuda:{gpu}`` in this process) and pinned via that variable.
        timeout: Wall-clock limit for the subprocess, in seconds.

    Returns:
        True if the subprocess exited with status 0; False if it exited with a
        non-zero status or exceeded ``timeout`` (the child is killed in that case).
    """
    pxdesign_env = 'python'

    cmd = [
        pxdesign_env, '-m', 'pxdesign.runner.inference',
        '--dump_dir', outdir,
        '--input', input_json,
        '--dtype', 'bf16',
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
    ]

    env = os.environ.copy()
    # The inference runner receives no GPU flag, so honour ``gpu`` through the
    # environment. Map it through an existing device list when one is set.
    gpu_idx = int(gpu)
    visible = [dev.strip() for dev in env.get('CUDA_VISIBLE_DEVICES', '').split(',')
               if dev.strip()]
    if 0 <= gpu_idx < len(visible):
        env['CUDA_VISIBLE_DEVICES'] = visible[gpu_idx]
    else:
        env['CUDA_VISIBLE_DEVICES'] = str(gpu_idx)

    logger.info(f"Running vanilla PXDesign: {n_sample} samples -> {outdir}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, env=env,
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed the child; report failure to the caller.
        logger.error(f"PXDesign timed out after {timeout}s (outdir: {outdir})")
        return False

    if result.returncode != 0:
        logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
        return False
    return True
|
|
|
|
def collect_pdbs(outdir):
    """Collect structure files (PDB / mmCIF) written under a PXDesign output dir.

    Searches ``outdir`` recursively for ``*.pdb`` and ``*.cif`` files. If any
    file name contains ``sample``, ``design`` or ``rank`` (case-insensitive),
    only those files are returned; otherwise every structure file found is
    returned.

    Args:
        outdir: Root directory to search. It is matched literally, so
            directory names containing glob metacharacters (``*``, ``?``,
            ``[``) are handled correctly.

    Returns:
        Sorted list of matching file paths (possibly empty).
    """
    from glob import escape

    # Escape the root so only the appended pattern is treated as a wildcard.
    root = escape(outdir)
    paths = []
    for ext in ('*.pdb', '*.cif'):
        paths.extend(glob(os.path.join(root, '**', ext), recursive=True))
    paths.sort()

    keywords = ('sample', 'design', 'rank')
    preferred = [p for p in paths
                 if any(k in os.path.basename(p).lower() for k in keywords)]
    return preferred if preferred else paths
|
|
|
|
def _score_particles(guidance, pdb_paths, round_idx):
    """Score each structure with Q_theta and tag the result with its path, id and round."""
    scored = []
    for pdb_path in pdb_paths:
        result = guidance.score_design(pdb_path)
        if result is None:
            # Structures for which the scorer returns None are dropped from the round.
            continue
        result['pdb_path'] = pdb_path
        # Strip only the trailing extension ("sample_3.cif" -> "sample_3").
        result['design_id'] = os.path.splitext(os.path.basename(pdb_path))[0]
        result['round'] = round_idx
        scored.append(result)
    return scored


def _best_of_k(margins, ks=(1, 2, 5, 10)):
    """Exact probability that K designs drawn without replacement include one with margin > 0."""
    margins = np.asarray(margins)
    n_total = int(margins.size)
    n_fail = n_total - int(np.count_nonzero(margins > 0))
    probs = {}
    for k in ks:
        draws = min(k, n_total)
        # Hypergeometric: P(every draw fails) = prod_j (n_fail - j) / (n_total - j).
        p_all_fail = 1.0
        for j in range(draws):
            p_all_fail *= max(n_fail - j, 0) / (n_total - j)
        probs[k] = 1.0 - p_all_fail
    return probs


def _json_default(obj):
    """Convert NumPy / torch scalars and arrays to JSON-native values for json.dump."""
    # NOTE(review): score_design's result schema is not visible here; this guards
    # against it returning NumPy or torch values, which json cannot encode.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_tds_outputs(args, outdir, all_designs, round_summaries):
    """Rank pooled designs, write score/summary JSON, copy the top structures and log a report."""
    all_designs.sort(key=lambda d: d['margin'], reverse=True)
    all_margins = np.array([d['margin'] for d in all_designs])
    holo_scores = np.array([d['q_holo'] for d in all_designs])

    bok = _best_of_k(all_margins, ks=(1, 2, 5, 10))

    summary = {
        'method': 'PXDesign + TDS',
        'n_rounds': args.n_rounds,
        'n_particles_per_round': args.n_particles,
        'total_designs': len(all_designs),
        'guidance_scale': args.guidance_scale,
        'temperature': args.temperature,
        'margin_mean': float(all_margins.mean()),
        'margin_std': float(all_margins.std()),
        'margin_max': float(all_margins.max()),
        'frac_positive': float((all_margins > 0).mean()),
        'q_holo_mean': float(holo_scores.mean()),
        'best_of_k': {str(k): v for k, v in bok.items()},
        'round_summaries': round_summaries,
        'top5': all_designs[:5],
    }

    with open(os.path.join(outdir, 'tds_scores.json'), 'w') as f:
        json.dump(all_designs, f, indent=2, default=_json_default)
    with open(os.path.join(outdir, 'tds_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    best_dir = os.path.join(outdir, 'best_designs')
    os.makedirs(best_dir, exist_ok=True)
    for rank, design in enumerate(all_designs[:20]):
        src = design['pdb_path']
        if not os.path.exists(src):
            continue
        # Keep the source extension so mmCIF files are not relabelled as .pdb.
        ext = os.path.splitext(src)[1] or '.pdb'
        dest = os.path.join(best_dir, f'rank_{rank:02d}_{design["design_id"]}{ext}')
        shutil.copy2(src, dest)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + TDS Results ({len(all_designs)} total designs)")
    logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
    logger.info(f" Max margin: {all_margins.max():.3f}")
    logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
    logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
    logger.info(" Best-of-K:")
    for k, v in sorted(bok.items()):
        logger.info(f" K={k:3d}: {v:.3f}")
    logger.info(f"{'='*60}")


def tds_particle_filter(args):
    """Run multi-round guided PXDesign generation scored by Q_theta.

    Each round generates ``args.n_particles`` structures with guided PXDesign
    (as a subprocess) and scores them with ``QThetaPXDesignGuidance.score_design``.
    It then records margin statistics plus the effective sample size of the
    importance weights ``exp(margin / args.temperature)``. Round ``r`` uses
    guidance scale ``args.guidance_scale * (1 + 0.2 * r)``.

    Particles are not resampled or perturbed between rounds. Every round
    re-runs PXDesign from ``args.input``, and the ESS is reported as a
    diagnostic only. Scored designs from all rounds are pooled and ranked by
    margin.

    Outputs, under ``args.outdir`` resolved against ``_ALLO_ROOT``:
        round_<r>/generated/  raw PXDesign output of round r
        tds_scores.json       every scored design, best margin first
        tds_summary.json      aggregate statistics, best-of-K, per-round stats
        best_designs/         copies of the top 20 structures, rank-prefixed

    Args:
        args: Parsed command-line namespace (see ``main``).
    """
    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = os.path.join(_ALLO_ROOT, args.outdir)
    os.makedirs(outdir, exist_ok=True)

    # Resolve paths once so the in-process scorer and the guided-PXDesign
    # subprocess use the same files regardless of the working directory.
    checkpoint = os.path.join(_ALLO_ROOT, args.qtheta_checkpoint)
    ref_holo = os.path.join(_ALLO_ROOT, args.ref_holo)
    ref_apo = os.path.join(_ALLO_ROOT, args.ref_apo)
    input_json = os.path.join(_ALLO_ROOT, args.input)

    guidance = QThetaPXDesignGuidance(
        checkpoint=checkpoint,
        ref_holo=ref_holo,
        ref_apo=ref_apo,
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    guidance._lazy_init()

    guidance_args = {
        'checkpoint': checkpoint,
        'ref_holo': ref_holo,
        'ref_apo': ref_apo,
        'ref_chain': args.ref_chain,
        'guidance_scale': args.guidance_scale,
        'guidance_start': args.guidance_start,
        'guidance_end': args.guidance_end,
    }

    all_designs = []
    round_summaries = []

    for round_idx in range(args.n_rounds):
        round_dir = os.path.join(outdir, f'round_{round_idx}')
        os.makedirs(round_dir, exist_ok=True)

        logger.info(f"\n{'='*60}")
        logger.info(f"TDS Round {round_idx + 1}/{args.n_rounds}")
        logger.info(f"{'='*60}")

        # Schedule: +20% of the initial scale per round. It is derived from
        # round_idx rather than accumulated, so a failed round cannot shift it.
        guidance_args['guidance_scale'] = args.guidance_scale * (1.0 + 0.2 * round_idx)
        logger.info(f"Guidance scale for this round: {guidance_args['guidance_scale']:.2f}")

        gen_dir = os.path.join(round_dir, 'generated')
        success = run_guided_pxdesign_batch(
            input_json=input_json,
            outdir=gen_dir,
            n_sample=args.n_particles,
            n_step=args.N_step,
            gpu=args.gpu,
            guidance_args=guidance_args,
        )
        if not success:
            logger.warning(f"Round {round_idx} generation failed, skipping")
            continue

        pdbs = collect_pdbs(gen_dir)
        if not pdbs:
            logger.warning(f"No PDBs found in round {round_idx}")
            continue

        logger.info(f"Scoring {len(pdbs)} particles...")
        round_results = _score_particles(guidance, pdbs, round_idx)
        if not round_results:
            logger.warning(f"No scorable designs in round {round_idx}")
            continue

        margins = np.array([r['margin'] for r in round_results])
        # Importance log-weights log w_i = margin_i / T (diagnostic ESS only).
        ess = compute_ess(margins / args.temperature)

        round_summary = {
            'round': round_idx,
            'n_particles': len(round_results),
            'guidance_scale': float(guidance_args['guidance_scale']),
            'margin_mean': float(margins.mean()),
            'margin_std': float(margins.std()),
            'margin_max': float(margins.max()),
            'frac_positive': float((margins > 0).mean()),
            'ess': float(ess),
        }
        round_summaries.append(round_summary)

        logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
                    f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}, "
                    f"ESS={ess:.1f}/{len(round_results)}")

        all_designs.extend(round_results)

    if not all_designs:
        logger.warning("No designs were generated and scored; no outputs written")
        return

    _write_tds_outputs(args, outdir, all_designs, round_summaries)
|
|
|
|
def main():
    """Parse command-line arguments, validate them and run PXDesign + TDS.

    Relative paths given for ``--input``, ``--qtheta_checkpoint``,
    ``--ref_holo``, ``--ref_apo`` and ``--outdir`` are resolved against the
    repository root by ``tds_particle_filter``.
    """
    parser = argparse.ArgumentParser(description='PXDesign + TDS')
    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_particles', type=int, default=16,
                        help='Particles per round')
    parser.add_argument('--n_rounds', type=int, default=4,
                        help='Number of TDS rounds')
    parser.add_argument('--guidance_scale', type=float, default=0.5,
                        help='Initial guidance scale')
    parser.add_argument('--guidance_start', type=float, default=0.8)
    parser.add_argument('--guidance_end', type=float, default=0.1)
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='Temperature for importance weights')
    parser.add_argument('--N_step', type=int, default=400)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_tds')
    args = parser.parse_args()

    # Fail fast on values that would waste a long GPU run. The temperature
    # divides the margins, so 0 gives inf/NaN log-weights and a negative value
    # inverts the weighting. `not x > 0` also rejects NaN. Non-positive counts
    # would generate nothing.
    if not args.temperature > 0:
        parser.error('--temperature must be > 0')
    if args.n_particles < 1:
        parser.error('--n_particles must be >= 1')
    if args.n_rounds < 1:
        parser.error('--n_rounds must be >= 1')
    if args.N_step < 1:
        parser.error('--N_step must be >= 1')

    tds_particle_filter(args)


if __name__ == '__main__':
    main()
|
|