"""
PXDesign + Langevin Refinement.

Post-hoc gradient ascent on existing PXDesign binder backbones using the
Q_theta selectivity gradient:

    x_{t+1} = x_t + η · ∇_x[Q(holo,Y) - Q(apo,Y)] + noise_scale · √(2η) · ε

Takes PXDesign outputs (which have full sidechains), extracts backbone coords,
refines them via Langevin dynamics, and outputs refined backbone-only PDBs.

Usage:
    python code/scripts/pxdesign_guidance/langevin_pxdesign.py \
        --designs_dir experiments/pxdesign_cam/output/ \
        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
        --n_steps 100 --step_size 0.01 \
        --gpu 0
"""

import os
import sys
import argparse
import json
import logging
import numpy as np
import torch
from glob import glob

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)

from utils.pdb_utils import (
    load_structure, get_residues, get_backbone_coords, get_aa_indices,
    align_structures
)

# Minimum number of paired receptor/reference CA atoms required before a
# superposition is attempted.
_MIN_ALIGN_RESIDUES = 5

# Substrings (lower-case) that mark a PDB file name as a PXDesign output.
_DESIGN_NAME_HINTS = ('sample', 'design', 'rank')


def write_backbone_pdb(coords, mask, out_path, chain='B'):
    """Write backbone PDB (N, CA, C, O) from [N, 4, 3] numpy coords.

    Args:
        coords: indexable as ``coords[i][j]`` -> (x, y, z) for residue ``i``
            and backbone atom ``j`` in the order N, CA, C, O.
        mask: per-residue flags; residues with a falsy flag are not written.
        out_path: destination file path (overwritten).
        chain: chain identifier written for every atom.

    Every residue is written as ALA and numbered ``i + 1`` (its position in
    ``coords``), so masked residues leave gaps in the numbering.  Fields are
    laid out with explicit widths so they land in the fixed columns of the
    PDB ATOM record.
    """
    # Names are left-justified in a 4-character field below, which yields the
    # conventional " N  ", " CA ", " C  ", " O  " spelling.
    atom_names = [' N', ' CA', ' C', ' O']
    elements = ['N', 'C', 'C', 'O']

    with open(out_path, 'w') as f:
        atom_idx = 1
        for i, residue in enumerate(coords):
            if not mask[i]:
                continue
            for j, (aname, elem) in enumerate(zip(atom_names, elements)):
                x, y, z = residue[j]
                f.write(
                    f"{'ATOM':<6s}{atom_idx:5d} {aname:<4s} ALA {chain}{i+1:4d}{'':4s}"
                    f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}{'':10s}{elem:>2s}\n"
                )
                atom_idx += 1
        f.write("END\n")


def find_pxdesign_pdbs(designs_dir):
    """Find all PXDesign output PDB files.

    Recursively collects ``*.pdb`` files under ``designs_dir`` and keeps those
    whose file name contains 'sample', 'design' or 'rank' (case-insensitive).
    If no file name matches, every PDB found is returned instead.

    Returns:
        Sorted list of paths (possibly empty).
    """
    all_pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
    hinted = [
        p for p in all_pdbs
        if any(hint in os.path.basename(p).lower() for hint in _DESIGN_NAME_HINTS)
    ]
    return hinted if hinted else all_pdbs


def _selectivity_margin(dq, coords, mask_t, aa_t, holo_tf, apo_tf):
    """Score ``coords`` in both receptor frames.

    ``holo_tf`` / ``apo_tf`` are ``(R, t)`` pairs applied as ``x @ R.T + t``.

    Returns:
        (q_holo, q_apo, q_holo - q_apo) exactly as produced by ``dq.score``.
    """
    scores = []
    for label, (rot, trans) in (('holo', holo_tf), ('apo', apo_tf)):
        aligned = (coords.reshape(-1, 3) @ rot.T + trans).reshape(-1, 4, 3)
        scores.append(dq.score(aligned, mask_t, binder_aa_idx=aa_t,
                               receptor_label=label))
    q_holo, q_apo = scores
    return q_holo, q_apo, q_holo - q_apo


def langevin_refine(dq, binder_coords_init, binder_mask, binder_aa_idx,
                    rec_coords, rec_mask, ref_holo_ca, ref_apo_ca,
                    n_steps=100, step_size=0.01, noise_scale=0.0,
                    device='cuda:0'):
    """
    Langevin refinement of binder backbone coordinates.

    Args:
        dq: DifferentiableQTheta scorer
        binder_coords_init: [N_binder, 4, 3] numpy — initial binder backbone
        binder_mask: [N_binder] numpy bool
        binder_aa_idx: [N_binder] numpy int
        rec_coords: [N_rec, 4, 3] numpy — receptor backbone
        rec_mask: [N_rec] numpy bool (currently unused; kept for interface
            compatibility)
        ref_holo_ca: [N_ref, 3] torch — holo reference CA
        ref_apo_ca: [N_ref, 3] torch — apo reference CA
        n_steps: int
        step_size: float (η)
        noise_scale: float (for stochastic Langevin, 0 = gradient ascent)
        device: str

    Returns:
        (best_coords, trajectory, best_margin, best_q_holo, best_q_apo)

        best_coords: [N_binder, 4, 3] numpy — the coordinates that achieved
            ``best_margin`` (the scored coordinates themselves, not the
            state after the following update).
        trajectory: list of per-step dicts with keys 'step', 'q_holo',
            'q_apo', 'margin', 'grad_norm'.
        best_margin / best_q_holo / best_q_apo: floats for ``best_coords``.

        If refinement cannot start, or every step fails, the initial
        coordinates are returned with an empty trajectory,
        ``best_margin == -inf`` and both Q values ``0.0``.
    """
    device = torch.device(device)

    # Convert to tensors
    x = torch.from_numpy(binder_coords_init.copy()).float().to(device)
    mask_t = torch.from_numpy(binder_mask).bool().to(device)
    aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
    rec_ca = torch.from_numpy(rec_coords[:, 1, :]).float().to(device)

    best_margin = -float('inf')
    best_coords = binder_coords_init.copy()
    best_q_holo = 0.0
    best_q_apo = 0.0
    trajectory = []

    # Receptor and references do not change during refinement, so the two
    # superpositions are computed once instead of once per step.
    n_align_h = min(len(rec_ca), len(ref_holo_ca))
    n_align_a = min(len(rec_ca), len(ref_apo_ca))
    if min(n_align_h, n_align_a) < _MIN_ALIGN_RESIDUES:
        logger.warning(
            f"Too few residues to align ({n_align_h} holo / {n_align_a} apo); "
            f"skipping refinement")
        return best_coords, trajectory, best_margin, best_q_holo, best_q_apo

    try:
        from qtheta_pxdesign import differentiable_kabsch
        R_h, t_h = differentiable_kabsch(rec_ca[:n_align_h].detach(),
                                         ref_holo_ca[:n_align_h].detach())
        R_a, t_a = differentiable_kabsch(rec_ca[:n_align_a].detach(),
                                         ref_apo_ca[:n_align_a].detach())
    except Exception as e:
        # Best effort, as before: hand back the unrefined input, but say so.
        logger.warning(f"Receptor alignment failed, skipping refinement: {e}")
        return best_coords, trajectory, best_margin, best_q_holo, best_q_apo

    # No gradient should flow through the alignment itself.
    holo_tf = (R_h.detach(), t_h.detach())
    apo_tf = (R_a.detach(), t_a.detach())

    n_skipped = 0
    last_reason = None

    for step in range(n_steps):
        x_grad = x.clone().requires_grad_(True)

        try:
            with torch.enable_grad():
                q_holo, q_apo, margin = _selectivity_margin(
                    dq, x_grad, mask_t, aa_t, holo_tf, apo_tf)
                margin.backward()
            grad = x_grad.grad
            grad_unusable = grad is None or bool(torch.isnan(grad).any())
            if not grad_unusable:
                current_margin = margin.item()
                q_holo_val = q_holo.item()
                q_apo_val = q_apo.item()
                grad_norm = grad.norm().item()
        except Exception as e:
            logger.debug(f" Step {step}: {e}")
            n_skipped += 1
            last_reason = e
            continue

        if grad_unusable:
            n_skipped += 1
            last_reason = 'gradient missing or NaN'
            continue

        trajectory.append({
            'step': step,
            'q_holo': q_holo_val,
            'q_apo': q_apo_val,
            'margin': current_margin,
            'grad_norm': grad_norm,
        })

        if current_margin > best_margin:
            best_margin = current_margin
            # x_grad is the state that was just scored; x is about to move.
            best_coords = x_grad.detach().cpu().numpy()
            best_q_holo = q_holo_val
            best_q_apo = q_apo_val

        if step % 20 == 0:
            logger.info(
                f" Step {step:3d}: Q+={q_holo_val:.3f} Q-={q_apo_val:.3f} "
                f"S={current_margin:+.3f} |∇|={grad_norm:.4f}")

        # Gradient ascent step
        x = x + step_size * grad

        # Optional noise for stochastic Langevin
        if noise_scale > 0:
            x = x + noise_scale * np.sqrt(2 * step_size) * torch.randn_like(x)

    # The state produced by the last update has not been scored yet; evaluate
    # it once so the final step can still be selected as the best result.
    if trajectory:
        try:
            with torch.no_grad():
                q_holo, q_apo, margin = _selectivity_margin(
                    dq, x, mask_t, aa_t, holo_tf, apo_tf)
            final_margin = margin.item()
            if final_margin > best_margin:
                best_margin = final_margin
                best_coords = x.detach().cpu().numpy()
                best_q_holo = q_holo.item()
                best_q_apo = q_apo.item()
        except Exception as e:
            logger.debug(f" Final evaluation: {e}")

    if n_skipped:
        logger.warning(
            f"{n_skipped}/{n_steps} refinement steps were skipped "
            f"(last reason: {last_reason})")

    return best_coords, trajectory, best_margin, best_q_holo, best_q_apo


def _score_in_reference_frame(dq, binder_coords_np, binder_mask, aa_idx,
                              rec_ca, ref_ca, label, device):
    """Score the unrefined binder after superposing the receptor on a reference.

    The first ``min(len(rec_ca), len(ref_ca))`` CA atoms are paired by index.
    The second value returned by ``align_structures`` is used as a rotation,
    applied about the receptor centroid and followed by a shift to the
    reference centroid.

    Returns:
        ``dq.score(...)`` for receptor state ``label`` as a Python float.
    """
    n_align = min(len(rec_ca), len(ref_ca))
    _, rot = align_structures(rec_ca[:n_align], ref_ca[:n_align])
    rec_center = rec_ca[:n_align].mean(0)
    ref_center = ref_ca[:n_align].mean(0)
    aligned = (binder_coords_np.reshape(-1, 3) - rec_center) @ rot.T + ref_center
    aligned = aligned.reshape(-1, 4, 3)
    with torch.no_grad():
        return dq.score(
            torch.from_numpy(aligned).float().to(device),
            torch.from_numpy(binder_mask).bool().to(device),
            binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
            receptor_label=label).item()


def main():
    """Refine every PXDesign output found and write PDBs plus JSON summaries.

    Relative paths are resolved against the repository root (the process
    changes directory to ``_ALLO_ROOT``).  For each design the binder is
    scored, refined with ``langevin_refine`` and written to
    ``<outdir>/<design_id>_refined.pdb``; per-design results go to
    ``langevin_scores.json`` and aggregates to ``langevin_summary.json``.
    Designs that raise are logged and skipped.
    """
    parser = argparse.ArgumentParser(description='PXDesign + Langevin Refinement')
    parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--step_size', type=float, default=0.01)
    parser.add_argument('--noise_scale', type=float, default=0.0,
                        help='Noise scale for stochastic Langevin (0=gradient ascent)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_langevin')
    args = parser.parse_args()

    os.chdir(_ALLO_ROOT)

    if torch.cuda.is_available():
        device = f'cuda:{args.gpu}'
    else:
        logger.warning("CUDA is not available; running on CPU")
        device = 'cpu'

    from models.differentiable_features import DifferentiableQTheta

    # Load scorer
    dq = DifferentiableQTheta(args.qtheta_checkpoint, device=device)
    dq.load_receptor(args.ref_holo, chain=args.ref_chain, label='holo')
    dq.load_receptor(args.ref_apo, chain=args.ref_chain, label='apo')

    # Load reference CA coords
    holo_model = load_structure(args.ref_holo)
    holo_res = get_residues(holo_model[args.ref_chain])
    holo_coords, _ = get_backbone_coords(holo_res)
    ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(device)

    apo_model = load_structure(args.ref_apo)
    apo_res = get_residues(apo_model[args.ref_chain])
    apo_coords, _ = get_backbone_coords(apo_res)
    ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(device)

    # Find designs
    pdbs = find_pxdesign_pdbs(args.designs_dir)
    logger.info(f"Found {len(pdbs)} PXDesign outputs to refine")

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    all_results = []

    for i, pdb_path in enumerate(pdbs):
        design_id = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
        logger.info(f"\n[{i+1}/{len(pdbs)}] Refining {design_id}...")

        try:
            model = load_structure(pdb_path)
            chains = {c.get_id(): c for c in model.get_chains()}
            chain_ids = sorted(chains.keys())

            # Identify chains: a chain whose length is within 30% of the holo
            # reference is taken as the receptor, anything else as the binder
            # (the last match of each kind wins).
            ref_len = len(holo_res)
            rec_chain_id, binder_chain_id = None, None
            for cid in chain_ids:
                cres = get_residues(chains[cid])
                if abs(len(cres) - ref_len) < ref_len * 0.3:
                    rec_chain_id = cid
                else:
                    binder_chain_id = cid

            if rec_chain_id is None or binder_chain_id is None:
                if len(chain_ids) >= 2:
                    # Length heuristic was inconclusive; fall back to order.
                    rec_chain_id, binder_chain_id = chain_ids[0], chain_ids[1]
                else:
                    logger.warning(f"Skipping {design_id}: cannot identify chains")
                    continue

            rec_res = get_residues(chains[rec_chain_id])
            binder_res = get_residues(chains[binder_chain_id])

            rec_coords_np, rec_mask = get_backbone_coords(rec_res)
            binder_coords_np, binder_mask = get_backbone_coords(binder_res)
            aa_idx = get_aa_indices(binder_res)

            # Score before refinement
            rec_ca = rec_coords_np[:, 1, :]
            q_h_init = _score_in_reference_frame(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_ca, holo_coords[:, 1, :], 'holo', device)
            q_a_init = _score_in_reference_frame(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_ca, apo_coords[:, 1, :], 'apo', device)
            margin_init = q_h_init - q_a_init

            # Run Langevin refinement
            refined_coords, trajectory, best_margin, best_qh, best_qa = langevin_refine(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_coords_np, rec_mask,
                ref_holo_ca, ref_apo_ca,
                n_steps=args.n_steps,
                step_size=args.step_size,
                noise_scale=args.noise_scale,
                device=device,
            )

            # Use best-margin values (matching the saved best_coords PDB).
            # An empty trajectory means no step succeeded, so the initial
            # scores stand.
            margin_final = best_margin if trajectory else margin_init

            # Save refined PDB
            out_pdb = os.path.join(outdir, f'{design_id}_refined.pdb')
            write_backbone_pdb(refined_coords, binder_mask, out_pdb)

            result = {
                'design_id': design_id,
                'pdb_path': pdb_path,
                'refined_pdb': out_pdb,
                'q_holo_init': q_h_init,
                'q_apo_init': q_a_init,
                'margin_init': margin_init,
                'q_holo_final': best_qh if trajectory else q_h_init,
                'q_apo_final': best_qa if trajectory else q_a_init,
                'margin_final': margin_final,
                'margin_delta': margin_final - margin_init,
                'n_steps_converged': len(trajectory),
                'n_res': len(binder_res),
            }
            all_results.append(result)

            logger.info(
                f" {design_id}: S_init={margin_init:+.3f} -> S_final={margin_final:+.3f} "
                f"(Δ={margin_final - margin_init:+.3f})")

        except Exception as e:
            # Top-level boundary: one bad design must not abort the batch.
            logger.warning(f"Failed to refine {design_id}: {e}")
            continue

    # Summary
    if all_results:
        all_results.sort(key=lambda x: x['margin_final'], reverse=True)

        margins_init = np.array([r['margin_init'] for r in all_results])
        margins_final = np.array([r['margin_final'] for r in all_results])
        deltas = margins_final - margins_init

        summary = {
            'method': 'PXDesign + Langevin',
            'n_designs': len(all_results),
            'n_steps': args.n_steps,
            'step_size': args.step_size,
            'margin_init_mean': float(margins_init.mean()),
            'margin_final_mean': float(margins_final.mean()),
            'margin_delta_mean': float(deltas.mean()),
            'frac_improved': float((deltas > 0).mean()),
            'frac_positive_init': float((margins_init > 0).mean()),
            'frac_positive_final': float((margins_final > 0).mean()),
            'q_holo_final_mean': float(np.mean([r['q_holo_final'] for r in all_results])),
        }

        with open(os.path.join(outdir, 'langevin_scores.json'), 'w') as f:
            json.dump(all_results, f, indent=2)
        with open(os.path.join(outdir, 'langevin_summary.json'), 'w') as f:
            json.dump(summary, f, indent=2)

        logger.info(f"\n{'='*60}")
        logger.info(f"PXDesign + Langevin Results ({len(all_results)} designs)")
        logger.info(f" Margin init: {margins_init.mean():.3f} ± {margins_init.std():.3f}")
        logger.info(f" Margin final: {margins_final.mean():.3f} ± {margins_final.std():.3f}")
        logger.info(f" Δ margin: {deltas.mean():+.3f} ± {deltas.std():.3f}")
        logger.info(f" % improved: {(deltas > 0).mean():.1%}")
        logger.info(f" S>0 init/final: {(margins_init > 0).mean():.1%} / "
                    f"{(margins_final > 0).mean():.1%}")
        logger.info(f"{'='*60}")
    else:
        logger.warning("No designs were refined; no score files written")


if __name__ == '__main__':
    main()