# AlloGen / code/scripts/pxdesign_guidance/smc_pxdesign.py
# Commit ad9572d (chq1155): "AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo"
"""
PXDesign + SMC Reranking.

Post-hoc "Sequential Monte Carlo": generate several independent batches
("rounds") of vanilla PXDesign binders, score every design with Q_theta, and
rank the pooled designs by selectivity margin. The PXDesign diffusion process
is not modified, and nothing is resampled or reweighted between rounds: this
is a pure generate-score-rank pipeline.

This is the simplest Q_theta integration strategy: generate a large pool of
candidates and keep the best ones by selectivity score.
Usage:
python code/scripts/pxdesign_guidance/smc_pxdesign.py \
--input experiments/pxdesign_cam/output/cam_binder.json \
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
--n_particles 16 --n_rounds 4 \
--gpu 0
"""
import argparse
import json
import logging
import math
import os
import shutil
import subprocess
import sys
from glob import glob

import numpy as np
import torch
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# This file lives at <repo>/code/scripts/pxdesign_guidance/; derive the
# <repo>/code directory and the repository root from its location.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
_ALLO_ROOT = os.path.dirname(_ALLO_CODE_DIR)

# Put <repo>/code at the front of the import path (once) so project modules
# resolve when this file is executed directly as a script.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
def _pxdesign_env(gpu):
    """Return a copy of ``os.environ`` that pins the PXDesign child process to ``gpu``.

    The Q_theta scorer in this process runs on ``cuda:{gpu}``. Merely inheriting
    CUDA_VISIBLE_DEVICES would run the child on its first visible device no
    matter what ``gpu`` is. Here ``gpu`` is interpreted the way ``cuda:{gpu}`` is
    interpreted in this process:
      - as an index into CUDA_VISIBLE_DEVICES when that variable is set;
      - as a physical device index otherwise.
    An index outside the visible list, or ``gpu=None``, leaves the environment
    unchanged.
    """
    env = os.environ.copy()
    if gpu is None:
        return env
    visible = env.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        env['CUDA_VISIBLE_DEVICES'] = str(gpu)
        return env
    devices = [d.strip() for d in visible.split(',') if d.strip()]
    if 0 <= gpu < len(devices):
        env['CUDA_VISIBLE_DEVICES'] = devices[gpu]
    return env


def run_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu, timeout=7200):
    """Run vanilla PXDesign in a subprocess and report whether it succeeded.

    Two entry points are tried in order:
      1. ``pxdesign.runner.cli infer``;
      2. ``pxdesign.runner.inference``, only if the first exits non-zero.
    Both are launched with the interpreter that runs this script, so they see
    the same installed packages.

    Args:
        input_json: PXDesign input JSON path.
        outdir: directory PXDesign writes its structures into.
        n_sample: number of designs to generate (``--N_sample``).
        n_step: number of diffusion steps (``--N_step``).
        gpu: device index; the child is pinned to it via CUDA_VISIBLE_DEVICES
            (see ``_pxdesign_env``).
        timeout: seconds allowed per invocation before it is killed.

    Returns:
        True if an invocation exited with status 0. False on a non-zero exit,
        a timeout, or a launch failure, so the caller can fall back instead of
        crashing.
    """
    env = _pxdesign_env(gpu)
    python = sys.executable or 'python'
    shared_flags = ['--dtype', 'bf16', '--N_sample', str(n_sample), '--N_step', str(n_step)]
    commands = [
        [python, '-m', 'pxdesign.runner.cli', 'infer',
         '-o', outdir, '-i', input_json] + shared_flags,
        # Alternative module entry point, tried only if the CLI invocation fails.
        [python, '-m', 'pxdesign.runner.inference',
         '--dump_dir', outdir, '--input', input_json] + shared_flags,
    ]

    logger.info(f"Running PXDesign: {n_sample} samples -> {outdir}")
    result = None
    for attempt, cmd in enumerate(commands):
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, env=env,
                                    timeout=timeout)
        except subprocess.TimeoutExpired:
            # The entry point clearly worked; retrying the other one would only
            # burn another full timeout.
            logger.error(f"PXDesign timed out after {timeout}s ({cmd[2]})")
            return False
        except OSError as err:
            logger.error(f"Could not launch PXDesign ({cmd[2]}): {err}")
            return False
        if result.returncode == 0:
            return True
        if attempt + 1 < len(commands):
            logger.warning(f"PXDesign entry point {cmd[2]} failed (exit {result.returncode}); "
                           f"trying alternative.\nstderr: {result.stderr[-1000:]}")
    logger.error(f"PXDesign failed:\nstdout: {result.stdout[-1000:]}\nstderr: {result.stderr[-1000:]}")
    return False
def collect_pdbs(outdir):
    """Return sorted structure files (``*.pdb`` / ``*.cif``) found anywhere under ``outdir``.

    Files whose base name contains 'sample', 'design' or 'rank'
    (case-insensitive) are preferred. If none match, every structure file found
    is returned. A missing directory yields an empty list.
    """
    keywords = ('sample', 'design', 'rank')
    found = sorted(
        path
        for pattern in ('*.pdb', '*.cif')
        for path in glob(os.path.join(outdir, '**/' + pattern), recursive=True)
    )
    preferred = []
    for path in found:
        name = os.path.basename(path).lower()
        if any(word in name for word in keywords):
            preferred.append(path)
    return preferred or found
def _design_id(path):
    """Return the base name of ``path`` without its directory or file extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _json_default(obj):
    """``json.dump`` hook for numpy / torch values that may appear in scorer results.

    Any other type still raises TypeError, exactly as ``json.dump`` would
    without a hook.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _best_of_k_rate(margins, k):
    """Exact probability that ``k`` designs drawn without replacement include one with margin > 0.

    The draw fails only if all ``k`` picks come from the designs whose margin is
    not > 0, so the rate is ``1 - C(n_fail, k) / C(n, k)`` (hypergeometric).
    ``k`` is capped at ``len(margins)``, which must be >= 1.
    """
    n = len(margins)
    k = min(k, n)
    n_fail = n - int(np.count_nonzero(margins > 0))
    return 1.0 - math.comb(n_fail, k) / math.comb(n, k)


def _score_designs(guidance, pdb_paths, round_idx):
    """Score each structure with Q_theta; return result dicts annotated with provenance.

    Structures for which ``guidance.score_design`` returns None are skipped.
    """
    results = []
    for pdb_path in pdb_paths:
        result = guidance.score_design(pdb_path)
        if result is None:
            continue
        result['pdb_path'] = pdb_path
        result['design_id'] = _design_id(pdb_path)
        result['round'] = round_idx
        results.append(result)
    return results


def _write_results(outdir, args, all_designs, round_summaries,
                   n_best=20, k_values=(1, 2, 5, 10)):
    """Rank designs by margin (in place), write the JSON outputs, copy the top structures and log a summary."""
    all_designs.sort(key=lambda d: d['margin'], reverse=True)
    all_margins = np.array([d['margin'] for d in all_designs])
    holo_scores = np.array([d['q_holo'] for d in all_designs])
    apo_scores = np.array([d['q_apo'] for d in all_designs])
    bok = {k: _best_of_k_rate(all_margins, k) for k in k_values}

    summary = {
        'method': 'PXDesign + SMC',
        'n_rounds': args.n_rounds,
        'n_particles_per_round': args.n_particles,
        'total_designs': len(all_designs),
        'margin_mean': float(all_margins.mean()),
        'margin_std': float(all_margins.std()),
        'margin_max': float(all_margins.max()),
        'frac_positive': float((all_margins > 0).mean()),
        'q_holo_mean': float(holo_scores.mean()),
        'q_apo_mean': float(apo_scores.mean()),
        'best_of_k': {str(k): v for k, v in bok.items()},
        'round_summaries': round_summaries,
        'top5': all_designs[:5],
    }
    with open(os.path.join(outdir, 'smc_scores.json'), 'w') as f:
        json.dump(all_designs, f, indent=2, default=_json_default)
    with open(os.path.join(outdir, 'smc_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    best_dir = os.path.join(outdir, 'best_designs')
    os.makedirs(best_dir, exist_ok=True)
    for rank, design in enumerate(all_designs[:n_best]):
        src = design['pdb_path']
        if not os.path.exists(src):
            continue
        # Keep the source extension: collected structures may be mmCIF (.cif),
        # and labelling those .pdb breaks PDB parsers downstream.
        ext = os.path.splitext(src)[1] or '.pdb'
        shutil.copy2(src, os.path.join(best_dir, f'rank_{rank:02d}_{design["design_id"]}{ext}'))

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + SMC Results ({len(all_designs)} total designs)")
    logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
    logger.info(f" Max margin: {all_margins.max():.3f}")
    logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
    logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
    logger.info(" Best-of-K:")
    for k, v in sorted(bok.items()):
        logger.info(f" K={k:3d}: {v:.3f}")
    logger.info(f"{'='*60}")


def smc_particle_filter(args):
    """Generate ``args.n_rounds`` batches of PXDesign binders, then score and rank them with Q_theta.

    Rounds are independent; nothing is resampled between them. Each round:
      1. asks ``run_pxdesign_batch`` for ``args.n_particles`` designs, written to
         ``<outdir>/round_<i>/generated``;
      2. if generation fails, scores the structures found under
         ``args.designs_dir`` instead.
    Every structure file is scored at most once per run. Without that, a
    generator that keeps failing would append the same fallback pool once per
    round.

    Outputs in ``args.outdir``:
      smc_scores.json   all scored designs, sorted by margin (descending)
      smc_summary.json  aggregate statistics, per-round summaries, best-of-K rates, top 5
      best_designs/     copies of the 20 highest-margin structures, named
                        rank_<NN>_<design_id> plus the original extension

    The working directory is changed to the repository root first, so relative
    paths in ``args`` resolve against it.
    """
    os.chdir(_ALLO_ROOT)
    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    guidance = QThetaPXDesignGuidance(
        checkpoint=args.qtheta_checkpoint,
        ref_holo=args.ref_holo,
        ref_apo=args.ref_apo,
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    # Force initialization up front (the name suggests it is otherwise deferred).
    guidance._lazy_init()

    fallback_dir = getattr(args, 'designs_dir', None)
    scored_paths = set()
    all_designs = []
    round_summaries = []

    for round_idx in range(args.n_rounds):
        round_dir = os.path.join(outdir, f'round_{round_idx}')
        os.makedirs(round_dir, exist_ok=True)
        logger.info(f"\n{'='*60}")
        logger.info(f"SMC Round {round_idx + 1}/{args.n_rounds}")
        logger.info(f"{'='*60}")

        # Generate particles via vanilla PXDesign.
        gen_dir = os.path.join(round_dir, 'generated')
        success = run_pxdesign_batch(
            input_json=args.input,
            outdir=gen_dir,
            n_sample=args.n_particles,
            n_step=args.N_step,
            gpu=args.gpu,
        )
        if success:
            pdbs = collect_pdbs(gen_dir)
            if not pdbs:
                logger.warning(f"No PDBs found in round {round_idx}")
                continue
        else:
            logger.warning(f"Round {round_idx} generation failed. "
                           f"Checking for existing outputs...")
            pdbs = collect_pdbs(fallback_dir) if fallback_dir else []
            if not pdbs:
                logger.warning(f"No existing outputs to fall back on in round {round_idx}")
                continue

        # The fallback pool is identical every round; re-scoring it would
        # duplicate designs and bias every aggregate statistic.
        pdbs = [p for p in pdbs if os.path.abspath(p) not in scored_paths]
        if not pdbs:
            logger.warning(f"Round {round_idx}: every structure found was already scored; skipping")
            continue
        scored_paths.update(os.path.abspath(p) for p in pdbs)

        logger.info(f"Scoring {len(pdbs)} particles...")
        round_results = _score_designs(guidance, pdbs, round_idx)
        if not round_results:
            continue

        margins = np.array([r['margin'] for r in round_results])
        round_summary = {
            'round': round_idx,
            'n_particles': len(round_results),
            'margin_mean': float(margins.mean()),
            'margin_std': float(margins.std()),
            'margin_max': float(margins.max()),
            'frac_positive': float((margins > 0).mean()),
        }
        round_summaries.append(round_summary)
        logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
                    f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}")
        all_designs.extend(round_results)

    if not all_designs:
        logger.warning("No designs were scored; no results written.")
        return
    _write_results(outdir, args, all_designs, round_summaries)
def main():
    """Command-line entry point: parse the options below and run the reranking pipeline."""
    cli = argparse.ArgumentParser(description='PXDesign + SMC Reranking')
    # (flag, add_argument keyword arguments), listed in --help order.
    option_specs = (
        ('--input', {'default': 'experiments/pxdesign_cam/output/cam_binder.json',
                     'help': 'PXDesign input JSON'}),
        ('--designs_dir', {'default': 'experiments/pxdesign_cam/output/',
                           'help': 'Existing PXDesign outputs (fallback if generation fails)'}),
        ('--qtheta_checkpoint', {'default': 'results/checkpoints_cam_v3/best_phase2.pt'}),
        ('--ref_holo', {'default': 'data/pdbs/cam_holo/3CLN.pdb'}),
        ('--ref_apo', {'default': 'data/pdbs/cam_apo/1CFD.pdb'}),
        ('--ref_chain', {'default': 'A'}),
        ('--n_particles', {'type': int, 'default': 16, 'help': 'Particles per round'}),
        ('--n_rounds', {'type': int, 'default': 4, 'help': 'Number of SMC rounds'}),
        ('--N_step', {'type': int, 'default': 400}),
        ('--gpu', {'type': int, 'default': 0}),
        ('--outdir', {'default': 'results/pxdesign_smc'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    smc_particle_filter(cli.parse_args())


if __name__ == '__main__':
    main()