# AlloGen / code/scripts/pxdesign_guidance/tds_pxdesign.py
# (chq1155) AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
# Revision: ad9572d
"""
PXDesign + Twisted Diffusion Sampling (TDS).
Multi-round particle filtering with guided PXDesign:
Round r:
1. Generate N particles via PXDesign with Q_theta classifier guidance
2. Score each particle with Q_theta selectivity margin
3. Compute importance weights w_i ~ exp(margin_i / temperature)
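4. Resample particles (keep best, discard worst)
5. Add perturbation noise for diversity
Note: PXDesign samples cannot easily be perturbed and re-denoised, so this
implementation only approximates steps 4-5. The importance weights are used
only to report the effective sample size (ESS). Each round draws fresh samples
with an increased guidance scale. All particles from all rounds are pooled and
ranked by margin at the end.
This combines in-process guidance (the "twisted proposal") with post-hoc
importance-weighted selection for highest-quality designs.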
Usage:
python code/scripts/pxdesign_guidance/tds_pxdesign.py \
--input experiments/pxdesign_cam/output/cam_binder.json \
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
--n_particles 16 --n_rounds 4 \
--guidance_scale 0.5 \
--gpu 0
"""
import os
import sys
import argparse
import json
import logging
import shutil
import subprocess
from glob import glob
import numpy as np
import torch
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
logger = logging.getLogger(__name__)

# Locate the repository relative to this file so default relative paths
# (results/..., data/...) resolve the same way from any working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))  # code/scripts/pxdesign_guidance
_ALLO_CODE_DIR = os.path.dirname(os.path.dirname(_SCRIPT_DIR))  # code/
_ALLO_ROOT = os.path.dirname(_ALLO_CODE_DIR)  # repository root

# Put code/ at the front of the import path (only once).
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
def compute_ess(log_weights):
    """Return the effective sample size (ESS) of unnormalised log-weights.

    ESS = 1 / sum_i w_i**2, where w are the weights normalised to sum to 1.
    The maximum log-weight is subtracted before exponentiating so that large
    log-weights do not overflow.

    Args:
        log_weights: numpy array of log importance weights (must support
            ``.max()``).

    Returns:
        The ESS as a numpy float. It lies between 1 (one particle dominates)
        and ``len(log_weights)`` (uniform weights) for finite inputs.
    """
    stabilised = np.exp(log_weights - log_weights.max())
    normalised = stabilised / stabilised.sum()
    return 1.0 / np.square(normalised).sum()
def run_guided_pxdesign_batch(input_json, outdir, n_sample, n_step,
                              gpu, guidance_args, timeout=7200):
    """Run guided PXDesign (``guided_pxdesign.py``) as a subprocess.

    Args:
        input_json: Path to the PXDesign input JSON.
        outdir: Directory passed to the subprocess as ``--outdir``.
        n_sample: Number of samples to generate (``--N_sample``).
        n_step: Number of diffusion steps (``--N_step``).
        gpu: GPU index forwarded as ``--gpu``.
        guidance_args: Dict with keys ``checkpoint``, ``ref_holo``,
            ``ref_apo``, ``ref_chain``, ``guidance_scale`` and, optionally,
            ``guidance_start`` (default 0.8) and ``guidance_end``
            (default 0.1).
        timeout: Seconds to wait before the subprocess is killed.

    Returns:
        True if the subprocess exited with status 0. False otherwise,
        including when it timed out or could not be launched, so a single
        bad round does not abort a multi-round run.
    """
    pxdesign_python = 'python'
    cmd = [
        pxdesign_python,
        os.path.join(_SCRIPT_DIR, 'guided_pxdesign.py'),
        '--input', input_json,
        '--qtheta_checkpoint', guidance_args['checkpoint'],
        '--ref_holo', guidance_args['ref_holo'],
        '--ref_apo', guidance_args['ref_apo'],
        '--ref_chain', guidance_args['ref_chain'],
        '--guidance_scale', str(guidance_args['guidance_scale']),
        '--guidance_start', str(guidance_args.get('guidance_start', 0.8)),
        '--guidance_end', str(guidance_args.get('guidance_end', 0.1)),
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
        '--gpu', str(gpu),
        '--outdir', outdir,
    ]
    # No explicit env: the child inherits this process's environment,
    # including CUDA_VISIBLE_DEVICES.
    logger.info(f"Running guided PXDesign: {n_sample} samples -> {outdir}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True,
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        logger.error(f"PXDesign timed out after {timeout}s")
        return False
    except OSError as exc:
        # e.g. the interpreter is not on PATH.
        logger.error(f"Failed to launch PXDesign: {exc}")
        return False
    if result.returncode != 0:
        logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
        return False
    return True
def run_vanilla_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu,
                               timeout=7200):
    """Run vanilla PXDesign (no guidance) as a subprocess.

    Args:
        input_json: Path to the PXDesign input JSON.
        outdir: Dump directory passed as ``--dump_dir``.
        n_sample: Number of samples to generate (``--N_sample``).
        n_step: Number of diffusion steps (``--N_step``).
        gpu: Accepted for signature parity with
            ``run_guided_pxdesign_batch`` but not used. Device selection
            comes from the inherited CUDA_VISIBLE_DEVICES.
        timeout: Seconds to wait before the subprocess is killed.

    Returns:
        True if the subprocess exited with status 0. False otherwise,
        including on timeout or launch failure.
    """
    pxdesign_env = 'python'
    cmd = [
        pxdesign_env, '-m', 'pxdesign.runner.inference',
        '--dump_dir', outdir,
        '--input', input_json,
        '--dtype', 'bf16',
        '--N_sample', str(n_sample),
        '--N_step', str(n_step),
    ]
    # No explicit env: the child inherits this process's environment,
    # including CUDA_VISIBLE_DEVICES.
    logger.info(f"Running vanilla PXDesign: {n_sample} samples -> {outdir}")
    try:
        result = subprocess.run(cmd, capture_output=True, text=True,
                                timeout=timeout)
    except subprocess.TimeoutExpired:
        logger.error(f"PXDesign timed out after {timeout}s")
        return False
    except OSError as exc:
        logger.error(f"Failed to launch PXDesign: {exc}")
        return False
    if result.returncode != 0:
        logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
        return False
    return True
def collect_pdbs(outdir):
    """Return the sorted .pdb/.cif files found anywhere under *outdir*.

    Files whose basename contains 'sample', 'design' or 'rank'
    (case-insensitive) are preferred. If none match, every structure file
    found is returned instead, so callers still get something to score.
    """
    found = sorted(
        path
        for pattern in ('*.pdb', '*.cif')
        for path in glob(os.path.join(outdir, '**/' + pattern), recursive=True)
    )
    keywords = ('sample', 'design', 'rank')
    preferred = [
        path for path in found
        if any(word in os.path.basename(path).lower() for word in keywords)
    ]
    return preferred or found
def _score_particles(guidance, pdbs, round_idx):
    """Score structure files with Q_theta and tag each result with provenance.

    Files for which ``guidance.score_design`` returns None are skipped.
    ``design_id`` is the file's basename without its extension.
    """
    scored = []
    for pdb_path in pdbs:
        result = guidance.score_design(pdb_path)
        if result is None:
            continue
        result['pdb_path'] = pdb_path
        result['design_id'] = os.path.splitext(os.path.basename(pdb_path))[0]
        result['round'] = round_idx
        scored.append(result)
    return scored


def _best_of_k_probability(margins, k):
    """Return P(max margin > 0) for ``k`` designs drawn without replacement.

    Exact hypergeometric form:
    1 - prod_{i<k} (n_fail - i) / (n - i),
    where ``n_fail`` counts margins that are not > 0 and ``k`` is clipped to
    ``n``. This replaces an unseeded Monte Carlo estimate, so the reported
    value is exact and reproducible.
    """
    margins = np.asarray(margins)
    n = len(margins)
    k = min(k, n)
    n_fail = n - int(np.count_nonzero(margins > 0))
    p_all_fail = 1.0
    for i in range(k):
        p_all_fail *= max(n_fail - i, 0) / (n - i)
    return 1.0 - p_all_fail


def _save_results(outdir, all_designs, summary, n_best=20):
    """Write tds_scores.json / tds_summary.json and copy the top designs.

    Copied structures keep their original extension (.pdb or .cif), so CIF
    output is never mislabelled as PDB.
    """
    with open(os.path.join(outdir, 'tds_scores.json'), 'w') as f:
        json.dump(all_designs, f, indent=2)
    with open(os.path.join(outdir, 'tds_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)
    best_dir = os.path.join(outdir, 'best_designs')
    os.makedirs(best_dir, exist_ok=True)
    for rank, design in enumerate(all_designs[:n_best]):
        src = design['pdb_path']
        if not os.path.exists(src):
            continue
        ext = os.path.splitext(src)[1] or '.pdb'
        dest = os.path.join(best_dir, f'rank_{rank:02d}_{design["design_id"]}{ext}')
        shutil.copy2(src, dest)


def tds_particle_filter(args):
    """Run multi-round TDS-style sampling with guided PXDesign.

    Each round:
    - generates ``args.n_particles`` designs with Q_theta-guided PXDesign;
    - scores them with the Q_theta selectivity margin;
    - logs the ESS of importance weights exp(margin / temperature).

    PXDesign samples cannot be perturbed and re-denoised, so no true
    resampling happens. Each round draws fresh samples, and the guidance
    scale is ramped up after every successful round.

    All scored designs are pooled and ranked by margin, then written to
    ``tds_scores.json``, ``tds_summary.json`` and ``best_designs/`` under
    the output directory.

    Args:
        args: argparse namespace from ``main``. Relative paths are resolved
            against the repository root (``_ALLO_ROOT``).
    """
    from qtheta_pxdesign import QThetaPXDesignGuidance

    outdir = os.path.join(_ALLO_ROOT, args.outdir)
    os.makedirs(outdir, exist_ok=True)

    # Resolve every input path once so the in-process scorer and the
    # guided-PXDesign subprocess see the same files regardless of cwd.
    input_json = os.path.join(_ALLO_ROOT, args.input)
    checkpoint = os.path.join(_ALLO_ROOT, args.qtheta_checkpoint)
    ref_holo = os.path.join(_ALLO_ROOT, args.ref_holo)
    ref_apo = os.path.join(_ALLO_ROOT, args.ref_apo)

    # Initialize scorer
    guidance = QThetaPXDesignGuidance(
        checkpoint=checkpoint,
        ref_holo=ref_holo,
        ref_apo=ref_apo,
        ref_chain=args.ref_chain,
        device=f'cuda:{args.gpu}',
    )
    guidance._lazy_init()
    guidance_args = {
        'checkpoint': checkpoint,
        'ref_holo': ref_holo,
        'ref_apo': ref_apo,
        'ref_chain': args.ref_chain,
        'guidance_scale': args.guidance_scale,
        'guidance_start': args.guidance_start,
        'guidance_end': args.guidance_end,
    }

    all_designs = []
    round_summaries = []
    for round_idx in range(args.n_rounds):
        round_dir = os.path.join(outdir, f'round_{round_idx}')
        os.makedirs(round_dir, exist_ok=True)
        logger.info(f"\n{'='*60}")
        logger.info(f"TDS Round {round_idx + 1}/{args.n_rounds}")
        logger.info(f"{'='*60}")

        # Generate particles via guided PXDesign
        gen_dir = os.path.join(round_dir, 'generated')
        success = run_guided_pxdesign_batch(
            input_json=input_json,
            outdir=gen_dir,
            n_sample=args.n_particles,
            n_step=args.N_step,
            gpu=args.gpu,
            guidance_args=guidance_args,
        )
        if not success:
            logger.warning(f"Round {round_idx} generation failed, skipping")
            continue

        # Collect and score particles
        pdbs = collect_pdbs(gen_dir)
        if not pdbs:
            logger.warning(f"No PDBs found in round {round_idx}")
            continue
        logger.info(f"Scoring {len(pdbs)} particles...")
        round_results = _score_particles(guidance, pdbs, round_idx)
        if not round_results:
            logger.warning(f"No scorable designs in round {round_idx}")
            continue

        margins = np.array([r['margin'] for r in round_results])
        # Importance weights are used only for the ESS diagnostic; see docstring.
        log_weights = margins / args.temperature
        ess = compute_ess(log_weights)
        round_summary = {
            'round': round_idx,
            'n_particles': len(round_results),
            'margin_mean': float(margins.mean()),
            'margin_std': float(margins.std()),
            'margin_max': float(margins.max()),
            'frac_positive': float((margins > 0).mean()),
            'ess': float(ess),
        }
        round_summaries.append(round_summary)
        logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
                    f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}, "
                    f"ESS={ess:.1f}/{len(round_results)}")

        # Add to design pool
        all_designs.extend(round_results)

        # PXDesign samples cannot be perturbed and re-denoised. Instead of
        # resampling, later rounds draw fresh samples with stronger guidance.
        if round_idx < args.n_rounds - 1:
            guidance_args['guidance_scale'] = args.guidance_scale * (1.0 + 0.2 * (round_idx + 1))
            logger.info(f"Increasing guidance scale to {guidance_args['guidance_scale']:.2f} "
                        f"for next round")

    if not all_designs:
        logger.warning("No designs were generated or scored; nothing to summarize")
        return

    # Final summary
    all_designs.sort(key=lambda d: d['margin'], reverse=True)
    all_margins = np.array([d['margin'] for d in all_designs])
    holo_scores = np.array([d['q_holo'] for d in all_designs])
    bok = {k: _best_of_k_probability(all_margins, k) for k in (1, 2, 5, 10)}
    summary = {
        'method': 'PXDesign + TDS',
        'n_rounds': args.n_rounds,
        'n_particles_per_round': args.n_particles,
        'total_designs': len(all_designs),
        'guidance_scale': args.guidance_scale,
        'temperature': args.temperature,
        'margin_mean': float(all_margins.mean()),
        'margin_std': float(all_margins.std()),
        'margin_max': float(all_margins.max()),
        'frac_positive': float((all_margins > 0).mean()),
        'q_holo_mean': float(holo_scores.mean()),
        'best_of_k': {str(k): v for k, v in bok.items()},
        'round_summaries': round_summaries,
        'top5': all_designs[:5],
    }
    _save_results(outdir, all_designs, summary)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + TDS Results ({len(all_designs)} total designs)")
    logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
    logger.info(f" Max margin: {all_margins.max():.3f}")
    logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
    logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
    logger.info(" Best-of-K:")
    for k, v in sorted(bok.items()):
        logger.info(f" K={k:3d}: {v:.3f}")
    logger.info(f"{'='*60}")
def main(argv=None):
    """Parse command-line arguments and run PXDesign + TDS.

    Args:
        argv: Optional argument list. Defaults to ``sys.argv[1:]`` via
            argparse, which lets ``main`` be called programmatically.

    Invalid values are rejected with ``parser.error``, which exits with
    status 2 before any model or subprocess work starts.
    """
    parser = argparse.ArgumentParser(description='PXDesign + TDS')
    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_particles', type=int, default=16,
                        help='Particles per round')
    parser.add_argument('--n_rounds', type=int, default=4,
                        help='Number of TDS rounds')
    parser.add_argument('--guidance_scale', type=float, default=0.5,
                        help='Initial guidance scale')
    parser.add_argument('--guidance_start', type=float, default=0.8)
    parser.add_argument('--guidance_end', type=float, default=0.1)
    parser.add_argument('--temperature', type=float, default=0.5,
                        help='Temperature for importance weights')
    parser.add_argument('--N_step', type=int, default=400)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_tds')
    args = parser.parse_args(argv)

    # Importance weights are margin / temperature, so temperature must be positive.
    if args.temperature <= 0:
        parser.error('--temperature must be > 0')
    if args.n_particles < 1:
        parser.error('--n_particles must be >= 1')
    if args.n_rounds < 1:
        parser.error('--n_rounds must be >= 1')

    tds_particle_filter(args)


if __name__ == '__main__':
    main()