| """ |
| PXDesign + Langevin Refinement. |
| |
Post-hoc gradient ascent on existing PXDesign binder backbones using the Q_theta
selectivity gradient:
    x_{t+1} = x_t + η · ∇_x[Q(holo,Y) - Q(apo,Y)] + σ · √(2η) · ε,   ε ~ N(0, I)
where η = --step_size and σ = --noise_scale (default 0, i.e. plain gradient ascent).
| |
| Takes PXDesign outputs (which have full sidechains), extracts backbone coords, |
| refines them via Langevin dynamics, and outputs refined backbone-only PDBs. |
| |
| Usage: |
| python code/scripts/pxdesign_guidance/langevin_pxdesign.py \ |
| --designs_dir experiments/pxdesign_cam/output/ \ |
| --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \ |
| --ref_holo data/pdbs/cam_holo/3CLN.pdb \ |
| --ref_apo data/pdbs/cam_apo/1CFD.pdb \ |
| --n_steps 100 --step_size 0.01 \ |
| --gpu 0 |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import json |
| import logging |
| import numpy as np |
| import torch |
| from glob import glob |
|
|
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
logger = logging.getLogger(__name__)

# Directory layout (see usage above): <root>/code/scripts/pxdesign_guidance/<this file>.
# _ALLO_CODE_DIR is two levels above this script; _ALLO_ROOT is its parent.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, os.pardir))

# Put the code directory first on the import path (once) so that the
# project-local `utils.*` / `models.*` packages resolve regardless of CWD.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
|
|
| from utils.pdb_utils import ( |
| load_structure, get_residues, get_backbone_coords, |
| get_aa_indices, align_structures |
| ) |
|
|
|
|
def write_backbone_pdb(coords, mask, out_path, chain='B'):
    """Write a poly-Ala backbone PDB (N, CA, C, O) from [N, 4, 3] coords.

    Every residue is labelled ``ALA``. Residues whose ``mask`` entry is falsy
    are skipped, but residue numbers still follow the array index (1-based),
    so masked residues leave gaps in the numbering. Atom serials are
    consecutive over the atoms actually written. The file ends with ``END``.

    Each ATOM line is built field-by-field at the fixed PDB column widths:
    cols 1-6 record, 7-11 serial, 13-16 atom name, 18-20 residue name,
    22 chain, 23-26 residue number, 31-54 x/y/z, 55-60 occupancy,
    61-66 B-factor, 77-78 element.

    Args:
        coords: array indexed as ``coords[residue, atom, xyz]`` with atom
            order N, CA, C, O.
        mask: per-residue truthy/falsy flags; falsy residues are not written.
        out_path: destination file path (overwritten).
        chain: single-character chain identifier.
    """
    # (atom name padded to 4 columns, element symbol)
    backbone_atoms = (
        (' N  ', 'N'),
        (' CA ', 'C'),
        (' C  ', 'C'),
        (' O  ', 'O'),
    )
    lines = []
    serial = 1
    for res_idx in range(len(coords)):
        if not mask[res_idx]:
            continue
        for atom_idx, (atom_name, element) in enumerate(backbone_atoms):
            x, y, z = coords[res_idx, atom_idx, :]
            lines.append(
                f'{"ATOM":<6s}'                 # cols  1-6   record name
                f'{serial:>5d} '                # cols  7-11  serial, 12 blank
                f'{atom_name:<4s} '             # cols 13-16  atom name, 17 altLoc
                f'ALA {chain:1s}'               # cols 18-20  resName, 21 blank, 22 chain
                f'{res_idx + 1:>4d}{"":4s}'     # cols 23-26  resSeq, 27 iCode, 28-30 blank
                f'{x:8.3f}{y:8.3f}{z:8.3f}'     # cols 31-54  coordinates
                f'{1.0:6.2f}{0.0:6.2f}'         # cols 55-66  occupancy, B-factor
                f'{"":10s}{element:>2s}\n'      # cols 67-76  blank, 77-78 element
            )
            serial += 1
    lines.append('END\n')
    # Build everything first so a malformed input cannot leave a half-written file.
    with open(out_path, 'w') as f:
        f.writelines(lines)
|
|
|
|
def find_pxdesign_pdbs(designs_dir, keywords=('sample', 'design', 'rank')):
    """Find PXDesign output PDB files under ``designs_dir`` (recursively).

    Files whose basename contains any of ``keywords`` (case-insensitive) are
    returned. If no file matches, every ``*.pdb`` found is returned instead.

    Args:
        designs_dir: root directory to search.
        keywords: basename substrings that identify design outputs.

    Returns:
        Sorted list of PDB paths (possibly empty).
    """
    # Walk the tree once; the fallback reuses the same listing.
    all_pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
    needles = tuple(k.lower() for k in keywords)
    matched = [
        path for path in all_pdbs
        if any(n in os.path.basename(path).lower() for n in needles)
    ]
    return matched or all_pdbs
|
|
|
|
def langevin_refine(dq, binder_coords_init, binder_mask, binder_aa_idx,
                    rec_coords, rec_mask, ref_holo_ca, ref_apo_ca,
                    n_steps=100, step_size=0.01, noise_scale=0.0,
                    device='cuda:0'):
    """
    Langevin refinement of binder backbone coordinates.

    Each step scores the current binder in the holo and apo reference frames,
    back-propagates the selectivity margin S = Q(holo) - Q(apo) to the binder
    coordinates, and then applies

        x <- x + step_size * dS/dx + noise_scale * sqrt(2 * step_size) * N(0, I)

    The coordinates with the highest margin observed are returned. These are
    the coordinates that were actually scored, i.e. before that step's update.

    Args:
        dq: DifferentiableQTheta scorer (``dq.score(coords, mask, binder_aa_idx=...,
            receptor_label='holo'|'apo')``)
        binder_coords_init: [N_binder, 4, 3] numpy — initial binder backbone
        binder_mask: [N_binder] numpy bool
        binder_aa_idx: [N_binder] numpy int
        rec_coords: [N_rec, 4, 3] numpy — receptor backbone (CA = index 1)
        rec_mask: [N_rec] numpy bool — currently unused
        ref_holo_ca: [N_ref, 3] torch — holo reference CA
        ref_apo_ca: [N_ref, 3] torch — apo reference CA
        n_steps: int — number of refinement steps
        step_size: float (η)
        noise_scale: float — multiplier on the Langevin noise (0 = gradient ascent)
        device: str or torch.device

    Returns:
        best_coords: [N_binder, 4, 3] numpy — coordinates that achieved
            ``best_margin`` (the unmodified input if no step succeeded)
        trajectory: list of dicts (step, q_holo, q_apo, margin, grad_norm),
            one per successful step
        best_margin: float — highest margin seen, ``-inf`` if none
        best_q_holo: float — Q(holo) at ``best_coords`` (0.0 if none)
        best_q_apo: float — Q(apo) at ``best_coords`` (0.0 if none)

    Raises:
        ImportError: if the project-local ``qtheta_pxdesign`` module is missing.
    """
    # Imported once, outside the per-step error handler, so that a missing
    # module surfaces to the caller instead of silently failing every step.
    from qtheta_pxdesign import differentiable_kabsch

    device = torch.device(device)

    x = torch.from_numpy(binder_coords_init.copy()).float().to(device)
    mask_t = torch.from_numpy(binder_mask).bool().to(device)
    aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
    rec_ca = torch.from_numpy(rec_coords[:, 1, :]).float().to(device)

    best_margin = -float('inf')
    best_coords = binder_coords_init.copy()
    best_q_holo = 0.0
    best_q_apo = 0.0
    trajectory = []

    # Receptor and reference CAs are paired by index over the first
    # min(N_rec, N_ref) residues. NOTE(review): rec_mask is not applied here;
    # this assumes receptor residues correspond 1:1 and in order with the
    # reference — confirm for designs with missing receptor residues.
    n_align_h = min(len(rec_ca), len(ref_holo_ca))
    n_align_a = min(len(rec_ca), len(ref_apo_ca))
    if min(n_align_h, n_align_a) < 5:
        # Too few paired CAs for a meaningful superposition: return the input.
        return best_coords, trajectory, best_margin, best_q_holo, best_q_apo

    # The receptor -> reference rigid transforms depend only on fixed inputs,
    # so compute them once rather than on every step. They are treated as
    # constants (detached) so gradients flow only through the binder coords.
    R_h, t_h = differentiable_kabsch(rec_ca[:n_align_h].detach(),
                                     ref_holo_ca[:n_align_h].detach())
    R_h, t_h = R_h.detach(), t_h.detach()
    R_a, t_a = differentiable_kabsch(rec_ca[:n_align_a].detach(),
                                     ref_apo_ca[:n_align_a].detach())
    R_a, t_a = R_a.detach(), t_a.detach()

    for step in range(n_steps):
        x_grad = x.detach().clone().requires_grad_(True)

        try:
            with torch.enable_grad():
                aligned_holo = (x_grad.reshape(-1, 3) @ R_h.T + t_h).reshape(-1, 4, 3)
                q_holo = dq.score(aligned_holo, mask_t, binder_aa_idx=aa_t,
                                  receptor_label='holo')

                aligned_apo = (x_grad.reshape(-1, 3) @ R_a.T + t_a).reshape(-1, 4, 3)
                q_apo = dq.score(aligned_apo, mask_t, binder_aa_idx=aa_t,
                                 receptor_label='apo')

                margin = q_holo - q_apo
                margin.backward()
        except Exception as e:
            # Best-effort: a failed step leaves x unchanged and is not recorded.
            logger.debug(f" Step {step}: {e}")
            continue

        grad = x_grad.grad
        if grad is None or not torch.isfinite(grad).all():
            continue

        current_margin = margin.item()
        q_holo_val = q_holo.item()
        q_apo_val = q_apo.item()
        grad_norm = grad.norm().item()
        trajectory.append({
            'step': step,
            'q_holo': q_holo_val,
            'q_apo': q_apo_val,
            'margin': current_margin,
            'grad_norm': grad_norm,
        })

        if current_margin > best_margin:
            best_margin = current_margin
            # Snapshot the coordinates that produced this margin (pre-update).
            best_coords = x_grad.detach().cpu().numpy()
            best_q_holo = q_holo_val
            best_q_apo = q_apo_val

        if step % 20 == 0:
            logger.info(
                f" Step {step:3d}: Q+={q_holo_val:.3f} Q-={q_apo_val:.3f} "
                f"S={current_margin:+.3f} |∇|={grad_norm:.4f}")

        # Gradient-ascent step plus optional Langevin noise.
        with torch.no_grad():
            x = x + step_size * grad
            if noise_scale > 0:
                x = x + noise_scale * np.sqrt(2 * step_size) * torch.randn_like(x)

    return best_coords, trajectory, best_margin, best_q_holo, best_q_apo
|
|
|
|
def _load_reference_backbone(pdb_path, chain):
    """Load ``chain`` of a reference structure and return (residues, backbone coords).

    The coords come straight from ``get_backbone_coords``; this script indexes
    them as [:, 1, :] for CA (presumably [N, 4, 3] in N/CA/C/O order — confirm
    against utils.pdb_utils).
    """
    model = load_structure(pdb_path)
    residues = get_residues(model[chain])
    coords, _ = get_backbone_coords(residues)
    return residues, coords


def _identify_chains(chains, ref_len):
    """Return (receptor_chain_id, binder_chain_id), or (None, None) if undecidable.

    Chains are visited in sorted id order. A chain whose residue count is
    within 30% of ``ref_len`` is taken as the receptor; any other chain as the
    binder, with later chains overriding earlier ones. If that does not yield
    both, the first two sorted chain ids are used; with fewer than two chains
    the pair cannot be identified.
    """
    chain_ids = sorted(chains.keys())
    rec_chain_id, binder_chain_id = None, None
    for cid in chain_ids:
        n_res = len(get_residues(chains[cid]))
        if abs(n_res - ref_len) < ref_len * 0.3:
            rec_chain_id = cid
        else:
            binder_chain_id = cid

    if rec_chain_id is None or binder_chain_id is None:
        if len(chain_ids) >= 2:
            return chain_ids[0], chain_ids[1]
        return None, None
    return rec_chain_id, binder_chain_id


def _score_in_reference_frame(dq, binder_coords, binder_mask, aa_idx,
                              rec_ca, ref_coords, label, device):
    """Superpose the binder onto a reference receptor and return its Q score (no grad).

    The rotation comes from ``align_structures`` on the first
    min(N_rec, N_ref) index-paired CA atoms; both sides are centred on the
    mean of those atoms. NOTE(review): this is a different alignment routine
    from the ``differentiable_kabsch`` used inside ``langevin_refine`` —
    confirm the two agree so initial and refined margins are comparable.
    """
    ref_ca = ref_coords[:, 1, :]
    n_align = min(len(rec_ca), len(ref_ca))
    _, rot = align_structures(rec_ca[:n_align], ref_ca[:n_align])
    rec_center = rec_ca[:n_align].mean(0)
    ref_center = ref_ca[:n_align].mean(0)

    aligned = (binder_coords.reshape(-1, 3) - rec_center) @ rot.T + ref_center
    aligned = aligned.reshape(-1, 4, 3)
    with torch.no_grad():
        return dq.score(
            torch.from_numpy(aligned).float().to(device),
            torch.from_numpy(binder_mask).bool().to(device),
            binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
            receptor_label=label).item()


def main():
    """Command-line entry point: refine every PXDesign output and report scores.

    For each design found under ``--designs_dir``, this entry point:

    1. Identifies the receptor and binder chains.
    2. Scores the unrefined binder against the holo and apo references.
    3. Runs ``langevin_refine`` and writes the best backbone to
       ``<outdir>/<design_id>_refined.pdb``.

    Per-design scores are written to ``langevin_scores.json`` (sorted by final
    margin, best first). Aggregate statistics are written to
    ``langevin_summary.json``.
    """
    parser = argparse.ArgumentParser(description='PXDesign + Langevin Refinement')
    parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--step_size', type=float, default=0.01)
    parser.add_argument('--noise_scale', type=float, default=0.0,
                        help='Noise scale for stochastic Langevin (0=gradient ascent)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_langevin')
    args = parser.parse_args()

    # The default paths are relative to the repository root.
    # NOTE(review): user-supplied relative paths are therefore also resolved
    # against the repository root, not against the invocation directory.
    os.chdir(_ALLO_ROOT)

    if torch.cuda.is_available():
        device = f'cuda:{args.gpu}'
    else:
        device = 'cpu'
        logger.warning("CUDA not available; running on CPU")

    from models.differentiable_features import DifferentiableQTheta

    dq = DifferentiableQTheta(args.qtheta_checkpoint, device=device)
    dq.load_receptor(args.ref_holo, chain=args.ref_chain, label='holo')
    dq.load_receptor(args.ref_apo, chain=args.ref_chain, label='apo')

    holo_res, holo_coords = _load_reference_backbone(args.ref_holo, args.ref_chain)
    _, apo_coords = _load_reference_backbone(args.ref_apo, args.ref_chain)
    ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(device)
    ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(device)

    pdbs = find_pxdesign_pdbs(args.designs_dir)
    logger.info(f"Found {len(pdbs)} PXDesign outputs to refine")

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    all_results = []
    for i, pdb_path in enumerate(pdbs):
        # splitext strips only the real extension (no substring replacement).
        design_id = os.path.splitext(os.path.basename(pdb_path))[0]
        logger.info(f"\n[{i+1}/{len(pdbs)}] Refining {design_id}...")

        try:
            model = load_structure(pdb_path)
            chains = {c.get_id(): c for c in model.get_chains()}
            rec_chain_id, binder_chain_id = _identify_chains(chains, len(holo_res))
            if rec_chain_id is None:
                logger.warning(f"Skipping {design_id}: cannot identify chains")
                continue

            rec_res = get_residues(chains[rec_chain_id])
            binder_res = get_residues(chains[binder_chain_id])
            rec_coords_np, rec_mask = get_backbone_coords(rec_res)
            binder_coords_np, binder_mask = get_backbone_coords(binder_res)
            aa_idx = get_aa_indices(binder_res)
            rec_ca = rec_coords_np[:, 1, :]

            q_h_init = _score_in_reference_frame(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_ca, holo_coords, 'holo', device)
            q_a_init = _score_in_reference_frame(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_ca, apo_coords, 'apo', device)
            margin_init = q_h_init - q_a_init

            refined_coords, trajectory, best_margin, best_qh, best_qa = langevin_refine(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_coords_np, rec_mask, ref_holo_ca, ref_apo_ca,
                n_steps=args.n_steps, step_size=args.step_size,
                noise_scale=args.noise_scale, device=device,
            )

            # Fall back to the initial scores when no step succeeded or no
            # finite margin was recorded (avoids -inf/NaN in the JSON output).
            refined = bool(trajectory) and bool(np.isfinite(best_margin))
            margin_final = best_margin if refined else margin_init

            out_pdb = os.path.join(outdir, f'{design_id}_refined.pdb')
            write_backbone_pdb(refined_coords, binder_mask, out_pdb)

            result = {
                'design_id': design_id,
                'pdb_path': pdb_path,
                'refined_pdb': out_pdb,
                'q_holo_init': q_h_init,
                'q_apo_init': q_a_init,
                'margin_init': margin_init,
                'q_holo_final': best_qh if refined else q_h_init,
                'q_apo_final': best_qa if refined else q_a_init,
                'margin_final': margin_final,
                'margin_delta': margin_final - margin_init,
                'n_steps_converged': len(trajectory),
                'n_res': len(binder_res),
            }
            all_results.append(result)

            logger.info(
                f" {design_id}: S_init={margin_init:+.3f} -> S_final={margin_final:+.3f} "
                f"(Δ={margin_final - margin_init:+.3f})")

        except Exception as e:
            # Per-design best-effort: log with traceback and move on.
            logger.warning(f"Failed to refine {design_id}: {e}", exc_info=True)
            continue

    if not all_results:
        logger.warning(f"No designs were refined; no score files written to {outdir}")
        return

    all_results.sort(key=lambda r: r['margin_final'], reverse=True)
    margins_init = np.array([r['margin_init'] for r in all_results])
    margins_final = np.array([r['margin_final'] for r in all_results])
    deltas = margins_final - margins_init

    summary = {
        'method': 'PXDesign + Langevin',
        'n_designs': len(all_results),
        'n_steps': args.n_steps,
        'step_size': args.step_size,
        'noise_scale': args.noise_scale,
        'margin_init_mean': float(margins_init.mean()),
        'margin_final_mean': float(margins_final.mean()),
        'margin_delta_mean': float(deltas.mean()),
        'frac_improved': float((deltas > 0).mean()),
        'frac_positive_init': float((margins_init > 0).mean()),
        'frac_positive_final': float((margins_final > 0).mean()),
        'q_holo_final_mean': float(np.mean([r['q_holo_final'] for r in all_results])),
    }

    with open(os.path.join(outdir, 'langevin_scores.json'), 'w') as f:
        json.dump(all_results, f, indent=2)
    with open(os.path.join(outdir, 'langevin_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + Langevin Results ({len(all_results)} designs)")
    logger.info(f" Margin init: {margins_init.mean():.3f} ± {margins_init.std():.3f}")
    logger.info(f" Margin final: {margins_final.mean():.3f} ± {margins_final.std():.3f}")
    logger.info(f" Δ margin: {deltas.mean():+.3f} ± {deltas.std():.3f}")
    logger.info(f" % improved: {(deltas > 0).mean():.1%}")
    logger.info(f" S>0 init/final: {(margins_init > 0).mean():.1%} / "
                f"{(margins_final > 0).mean():.1%}")
    logger.info(f"{'='*60}")


if __name__ == '__main__':
    main()
|
|