# AlloGen / code/scripts/pxdesign_guidance/langevin_pxdesign.py
# chq1155 — "AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo" (commit ad9572d)
"""
PXDesign + Langevin Refinement.
Post-hoc gradient ascent on existing PXDesign binder backbones using Q_theta
selectivity gradient:
x_{t+1} = x_t + η · ∇_x[Q(holo,Y) - Q(apo,Y)] + σ · √(2η) · ε,   ε ~ N(0, I)
(σ is --noise_scale; the default σ = 0 reduces this to plain gradient ascent.)
Takes PXDesign outputs (which have full sidechains), extracts backbone coords,
refines them via Langevin dynamics, and writes refined backbone-only PDBs plus
per-design (langevin_scores.json) and summary (langevin_summary.json) score files.
Usage:
python code/scripts/pxdesign_guidance/langevin_pxdesign.py \
--designs_dir experiments/pxdesign_cam/output/ \
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
--n_steps 100 --step_size 0.01 \
--gpu 0
"""
import os
import sys
import argparse
import json
import logging
import numpy as np
import torch
from glob import glob
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s',
)
logger = logging.getLogger(__name__)

# Repository locations, derived from this file's absolute path so they do not
# depend on the caller's working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir, os.pardir))
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, os.pardir))

# Put the code directory on the import path so the project packages
# (`utils.pdb_utils`, `models.differentiable_features`) can be imported.
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
from utils.pdb_utils import (
load_structure, get_residues, get_backbone_coords,
get_aa_indices, align_structures
)
def write_backbone_pdb(coords, mask, out_path, chain='B'):
    """Write a backbone-only PDB (N, CA, C, O) from [N, 4, 3] coordinates.

    Residues whose ``mask`` entry is falsy are skipped. Kept residues use their
    original 1-based index as the residue number, so gaps are preserved. Every
    residue is written as ALA with occupancy 1.00 and B-factor 0.00, and the
    file is terminated with ``END``.

    ATOM records follow the fixed-column PDB layout:
      * serial in columns 7-11
      * atom name in 13-16
      * residue name in 18-20
      * chain in 22
      * residue number in 23-26
      * x/y/z in 31-54
      * occupancy in 55-60
      * B-factor in 61-66
      * element in 77-78

    This lets column-based parsers read the file correctly.

    Args:
        coords: array indexable as ``coords[i, j, :]`` -> (x, y, z) for
            residue ``i`` and backbone atom ``j`` in the order N, CA, C, O.
        mask: per-residue truthy flags, same length as ``coords``.
        out_path: destination path (overwritten).
        chain: chain identifier (one character).
    """
    # Atom names are 4-wide with the element symbol in column 14, as PDB
    # requires for one-letter elements.
    atom_names = (' N  ', ' CA ', ' C  ', ' O  ')
    elements = ('N', 'C', 'C', 'O')

    lines = []
    serial = 1
    for i in range(len(coords)):
        if not mask[i]:
            continue
        for j, (aname, elem) in enumerate(zip(atom_names, elements)):
            x, y, z = coords[i, j, :]
            lines.append(
                f"ATOM  {serial:5d} {aname:<4s} ALA {chain:1s}{i + 1:4d}    "
                f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}"
                f"{'':10s}{elem:>2s}\n"
            )
            serial += 1
    lines.append("END\n")

    with open(out_path, 'w') as f:
        f.writelines(lines)
def find_pxdesign_pdbs(designs_dir):
    """Find all PXDesign output PDB files under ``designs_dir`` (recursively).

    Files whose basename contains 'sample', 'design' or 'rank'
    (case-insensitive) are preferred. If none match, every ``*.pdb`` found is
    returned instead.

    Args:
        designs_dir: root directory to search. Glob metacharacters in the
            directory name (``[``, ``*``, ``?``) are treated literally.

    Returns:
        Sorted list of PDB paths (empty if nothing was found).
    """
    from glob import escape

    # Escape the root so that e.g. "run[1]" is not read as a character class.
    all_pdbs = sorted(
        glob(os.path.join(escape(designs_dir), '**/*.pdb'), recursive=True))

    keywords = ('sample', 'design', 'rank')
    preferred = [
        path for path in all_pdbs
        if any(k in os.path.basename(path).lower() for k in keywords)
    ]
    # Reuse the single directory scan for the fallback instead of re-globbing.
    return preferred or all_pdbs
def _apply_rigid(coords, R, t):
    """Apply ``p @ R.T + t`` to every atom of a [N, 4, 3] coordinate tensor."""
    return (coords.reshape(-1, 3) @ R.T + t).reshape(-1, 4, 3)


def langevin_refine(dq, binder_coords_init, binder_mask, binder_aa_idx,
                    rec_coords, rec_mask, ref_holo_ca, ref_apo_ca,
                    n_steps=100, step_size=0.01, noise_scale=0.0,
                    device='cuda:0', kabsch_fn=None):
    """
    Langevin refinement of binder backbone coordinates.

    Each step does the following:
      1. Score the current binder coordinates against the holo and apo
         receptors.
      2. Take a gradient-ascent step on the margin S = Q(holo) - Q(apo).
      3. Optionally add Gaussian noise:

        x <- x + step_size * dS/dx + noise_scale * sqrt(2 * step_size) * eps

    The binder is placed in each reference frame with the rigid transform
    that aligns the receptor CA trace onto the reference CA trace. Residues
    are paired by index over the first min(len) positions. The receptor
    never moves, so both transforms are computed once, before the loop.

    Args:
        dq: DifferentiableQTheta scorer. It must provide
            ``score(coords, mask, binder_aa_idx=..., receptor_label=...)``
            returning a scalar tensor.
        binder_coords_init: [N_binder, 4, 3] numpy — initial binder backbone
        binder_mask: [N_binder] numpy bool
        binder_aa_idx: [N_binder] numpy int
        rec_coords: [N_rec, 4, 3] numpy — receptor backbone. Only the CA
            column (index 1) is used.
        rec_mask: [N_rec] numpy bool — currently unused. Masked receptor
            residues still take part in the alignment.
        ref_holo_ca: [N_ref, 3] torch — holo reference CA
        ref_apo_ca: [N_ref, 3] torch — apo reference CA
        n_steps: int — number of update steps
        step_size: float (η)
        noise_scale: float (for stochastic Langevin, 0 = gradient ascent)
        device: str or torch.device
        kabsch_fn: optional callable ``(mobile_ca, ref_ca) -> (R, t)`` such
            that ``mobile @ R.T + t`` lies in the reference frame. Defaults to
            ``qtheta_pxdesign.differentiable_kabsch``.

    Returns:
        best_coords: [N_binder, 4, 3] numpy — the coordinates that were scored
            with ``best_margin``. This is the initial copy if no step
            succeeded.
        trajectory: list of per-step dicts (step, q_holo, q_apo, margin,
            grad_norm) for the steps that succeeded
        best_margin: float — best margin seen (-inf if no step succeeded)
        best_q_holo: float — Q(holo) at ``best_coords`` (0.0 if none)
        best_q_apo: float — Q(apo) at ``best_coords`` (0.0 if none)

    Raises:
        ImportError: if ``kabsch_fn`` is None and ``qtheta_pxdesign`` cannot be
            imported. Errors raised by the alignment itself also propagate.
            Per-step scoring errors are logged at debug level, and that step is
            skipped.
    """
    device = torch.device(device)

    best_margin = -float('inf')
    best_coords = binder_coords_init.copy()
    best_q_holo = 0.0
    best_q_apo = 0.0
    trajectory = []

    rec_ca = torch.from_numpy(rec_coords[:, 1, :]).float().to(device)
    n_align_h = min(len(rec_ca), len(ref_holo_ca))
    n_align_a = min(len(rec_ca), len(ref_apo_ca))
    if min(n_align_h, n_align_a) < 5:
        # Too few paired residues for a meaningful rigid alignment.
        logger.warning(
            f"  Too few residues to align receptor (holo={n_align_h}, "
            f"apo={n_align_a}); skipping refinement")
        return best_coords, trajectory, best_margin, best_q_holo, best_q_apo

    if kabsch_fn is None:
        from qtheta_pxdesign import differentiable_kabsch as kabsch_fn

    # Loop-invariant: only the binder moves, so align the receptor once.
    R_h, t_h = kabsch_fn(rec_ca[:n_align_h].detach(),
                         ref_holo_ca[:n_align_h].detach())
    R_a, t_a = kabsch_fn(rec_ca[:n_align_a].detach(),
                         ref_apo_ca[:n_align_a].detach())
    R_h, t_h = R_h.detach(), t_h.detach()
    R_a, t_a = R_a.detach(), t_a.detach()

    mask_t = torch.from_numpy(binder_mask).bool().to(device)
    aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
    x = torch.from_numpy(binder_coords_init.copy()).float().to(device)

    for step in range(n_steps):
        # Leaf copy of the coordinates scored at this step. The best-so-far
        # snapshot is taken from this copy, not from `x` after the update, so
        # the saved coordinates always match the margin they are reported with.
        x_scored = x.detach().clone().requires_grad_(True)
        try:
            with torch.enable_grad():
                q_holo = dq.score(_apply_rigid(x_scored, R_h, t_h), mask_t,
                                  binder_aa_idx=aa_t, receptor_label='holo')
                q_apo = dq.score(_apply_rigid(x_scored, R_a, t_a), mask_t,
                                 binder_aa_idx=aa_t, receptor_label='apo')
                margin = q_holo - q_apo
                margin.backward()
        except Exception as e:
            # Best-effort: a failing step is skipped rather than aborting.
            logger.debug(f"  Step {step}: {e}")
            continue

        grad = x_scored.grad
        if grad is None or not torch.isfinite(grad).all():
            # Missing or NaN/inf gradient: leave x unchanged for this step.
            continue

        current_margin = margin.item()
        q_holo_val = q_holo.item()
        q_apo_val = q_apo.item()
        grad_norm = grad.norm().item()
        trajectory.append({
            'step': step,
            'q_holo': q_holo_val,
            'q_apo': q_apo_val,
            'margin': current_margin,
            'grad_norm': grad_norm,
        })

        if current_margin > best_margin:
            best_margin = current_margin
            best_coords = x_scored.detach().cpu().numpy()
            best_q_holo = q_holo_val
            best_q_apo = q_apo_val

        if step % 20 == 0:
            logger.info(
                f"  Step {step:3d}: Q+={q_holo_val:.3f} Q-={q_apo_val:.3f} "
                f"S={current_margin:+.3f} |∇|={grad_norm:.4f}")

        # Gradient ascent step
        x = x + step_size * grad
        # Optional noise for stochastic Langevin
        if noise_scale > 0:
            x = x + noise_scale * np.sqrt(2 * step_size) * torch.randn_like(x)

    return best_coords, trajectory, best_margin, best_q_holo, best_q_apo
def _identify_chains(chains, ref_len):
    """Guess (receptor_chain_id, binder_chain_id) for a designed complex.

    Chains are examined in sorted-ID order:
      * A chain whose residue count is within 30% of ``ref_len`` is taken as
        the receptor.
      * Any other chain is taken as the binder.
      * With more than two chains, the last chain in each role wins.

    If either role is left unassigned, fall back to the first two chains in
    sorted-ID order.

    Returns:
        (rec_chain_id, binder_chain_id), or (None, None) when the model has
        fewer than two chains and the size heuristic did not assign both roles.
    """
    chain_ids = sorted(chains.keys())
    rec_chain_id, binder_chain_id = None, None
    for cid in chain_ids:
        n_res = len(get_residues(chains[cid]))
        if abs(n_res - ref_len) < ref_len * 0.3:
            rec_chain_id = cid
        else:
            binder_chain_id = cid
    if rec_chain_id is None or binder_chain_id is None:
        if len(chain_ids) >= 2:
            return chain_ids[0], chain_ids[1]
        return None, None
    return rec_chain_id, binder_chain_id


def _score_initial(dq, binder_coords, binder_mask, aa_idx, rec_ca, ref_ca,
                   receptor_label, device):
    """Score the unrefined binder in the frame of one reference receptor.

    Steps:
      1. Align the design's receptor CA trace onto ``ref_ca`` with
         ``align_structures``, pairing residues by index over the first
         min(len) positions.
      2. Apply the returned rotation plus the centroid shift to the binder.
      3. Return the Q_theta score (no gradients) as a float.
    """
    n_align = min(len(rec_ca), len(ref_ca))
    _, R = align_structures(rec_ca[:n_align], ref_ca[:n_align])
    center = rec_ca[:n_align].mean(0)
    ref_center = ref_ca[:n_align].mean(0)
    aligned = (binder_coords.reshape(-1, 3) - center) @ R.T + ref_center
    aligned = aligned.reshape(-1, 4, 3)
    with torch.no_grad():
        return dq.score(
            torch.from_numpy(aligned).float().to(device),
            torch.from_numpy(binder_mask).bool().to(device),
            binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
            receptor_label=receptor_label).item()


def main():
    """Refine every PXDesign binder found under --designs_dir.

    Files written to --outdir:
      * ``<design_id>_refined.pdb`` — best-margin binder backbone per design
      * ``langevin_scores.json`` — per-design scores, sorted by final margin
      * ``langevin_summary.json`` — aggregate statistics

    NOTE(review): the initial scores use ``align_structures`` (numpy), while
    refinement aligns with ``differentiable_kabsch``. ``margin_delta`` is only
    an apples-to-apples comparison if the two alignments agree — confirm.
    """
    parser = argparse.ArgumentParser(description='PXDesign + Langevin Refinement')
    parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/')
    parser.add_argument('--qtheta_checkpoint',
                        default='results/checkpoints_cam_v3/best_phase2.pt')
    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
    parser.add_argument('--ref_chain', default='A')
    parser.add_argument('--n_steps', type=int, default=100)
    parser.add_argument('--step_size', type=float, default=0.01)
    parser.add_argument('--noise_scale', type=float, default=0.0,
                        help='Noise scale for stochastic Langevin (0=gradient ascent)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--outdir', default='results/pxdesign_langevin')
    args = parser.parse_args()

    # All relative paths (defaults and user-supplied) are resolved against the
    # repository root from here on, not against the caller's CWD.
    os.chdir(_ALLO_ROOT)

    if torch.cuda.is_available():
        device = f'cuda:{args.gpu}'
    else:
        logger.warning("CUDA not available; falling back to CPU")
        device = 'cpu'

    from models.differentiable_features import DifferentiableQTheta

    # Load scorer
    dq = DifferentiableQTheta(args.qtheta_checkpoint, device=device)
    dq.load_receptor(args.ref_holo, chain=args.ref_chain, label='holo')
    dq.load_receptor(args.ref_apo, chain=args.ref_chain, label='apo')

    # Load reference CA coords
    holo_model = load_structure(args.ref_holo)
    holo_res = get_residues(holo_model[args.ref_chain])
    holo_coords, _ = get_backbone_coords(holo_res)
    holo_ca = holo_coords[:, 1, :]
    ref_holo_ca = torch.from_numpy(holo_ca).float().to(device)

    apo_model = load_structure(args.ref_apo)
    apo_res = get_residues(apo_model[args.ref_chain])
    apo_coords, _ = get_backbone_coords(apo_res)
    apo_ca = apo_coords[:, 1, :]
    ref_apo_ca = torch.from_numpy(apo_ca).float().to(device)

    # Find designs
    pdbs = find_pxdesign_pdbs(args.designs_dir)
    logger.info(f"Found {len(pdbs)} PXDesign outputs to refine")

    outdir = args.outdir
    os.makedirs(outdir, exist_ok=True)

    all_results = []
    seen_ids = set()
    for i, pdb_path in enumerate(pdbs):
        design_id = os.path.splitext(os.path.basename(pdb_path))[0]
        if design_id in seen_ids:
            # Same file name in another subdirectory (recursive search).
            # Disambiguate with the relative path so outputs are not overwritten.
            rel = os.path.splitext(os.path.relpath(pdb_path, args.designs_dir))[0]
            design_id = rel.replace(os.sep, '__')
        seen_ids.add(design_id)
        logger.info(f"\n[{i+1}/{len(pdbs)}] Refining {design_id}...")

        try:
            model = load_structure(pdb_path)
            chains = {c.get_id(): c for c in model.get_chains()}
            rec_chain_id, binder_chain_id = _identify_chains(chains, len(holo_res))
            if rec_chain_id is None:
                logger.warning(f"Skipping {design_id}: cannot identify chains")
                continue

            rec_res = get_residues(chains[rec_chain_id])
            binder_res = get_residues(chains[binder_chain_id])
            rec_coords_np, rec_mask = get_backbone_coords(rec_res)
            binder_coords_np, binder_mask = get_backbone_coords(binder_res)
            aa_idx = get_aa_indices(binder_res)

            # Score before refinement
            rec_ca = rec_coords_np[:, 1, :]
            q_h_init = _score_initial(dq, binder_coords_np, binder_mask, aa_idx,
                                      rec_ca, holo_ca, 'holo', device)
            q_a_init = _score_initial(dq, binder_coords_np, binder_mask, aa_idx,
                                      rec_ca, apo_ca, 'apo', device)
            margin_init = q_h_init - q_a_init

            # Run Langevin refinement
            refined_coords, trajectory, best_margin, best_qh, best_qa = langevin_refine(
                dq, binder_coords_np, binder_mask, aa_idx,
                rec_coords_np, rec_mask, ref_holo_ca, ref_apo_ca,
                n_steps=args.n_steps, step_size=args.step_size,
                noise_scale=args.noise_scale, device=device,
            )

            # Best-margin values belong to the saved best coordinates; with no
            # successful step, fall back to the initial scores.
            margin_final = best_margin if trajectory else margin_init

            # Save refined PDB
            out_pdb = os.path.join(outdir, f'{design_id}_refined.pdb')
            write_backbone_pdb(refined_coords, binder_mask, out_pdb)

            all_results.append({
                'design_id': design_id,
                'pdb_path': pdb_path,
                'refined_pdb': out_pdb,
                'q_holo_init': q_h_init,
                'q_apo_init': q_a_init,
                'margin_init': margin_init,
                'q_holo_final': best_qh if trajectory else q_h_init,
                'q_apo_final': best_qa if trajectory else q_a_init,
                'margin_final': margin_final,
                'margin_delta': margin_final - margin_init,
                'n_steps_converged': len(trajectory),
                'n_res': len(binder_res),
            })
            logger.info(
                f"  {design_id}: S_init={margin_init:+.3f} -> S_final={margin_final:+.3f} "
                f"(Δ={margin_final - margin_init:+.3f})")
        except Exception as e:
            # One bad design should not stop the batch; keep the traceback.
            logger.warning(f"Failed to refine {design_id}: {e}", exc_info=True)
            continue

    # Summary
    if not all_results:
        logger.warning("No designs were refined; no score files written")
        return

    all_results.sort(key=lambda r: r['margin_final'], reverse=True)
    margins_init = np.array([r['margin_init'] for r in all_results])
    margins_final = np.array([r['margin_final'] for r in all_results])
    deltas = margins_final - margins_init

    summary = {
        'method': 'PXDesign + Langevin',
        'n_designs': len(all_results),
        'n_steps': args.n_steps,
        'step_size': args.step_size,
        'margin_init_mean': float(margins_init.mean()),
        'margin_final_mean': float(margins_final.mean()),
        'margin_delta_mean': float(deltas.mean()),
        'frac_improved': float((deltas > 0).mean()),
        'frac_positive_init': float((margins_init > 0).mean()),
        'frac_positive_final': float((margins_final > 0).mean()),
        'q_holo_final_mean': float(np.mean([r['q_holo_final'] for r in all_results])),
        'q_apo_final_mean': float(np.mean([r['q_apo_final'] for r in all_results])),
    }

    with open(os.path.join(outdir, 'langevin_scores.json'), 'w') as f:
        json.dump(all_results, f, indent=2)
    with open(os.path.join(outdir, 'langevin_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"\n{'='*60}")
    logger.info(f"PXDesign + Langevin Results ({len(all_results)} designs)")
    logger.info(f"  Margin init:  {margins_init.mean():.3f} ± {margins_init.std():.3f}")
    logger.info(f"  Margin final: {margins_final.mean():.3f} ± {margins_final.std():.3f}")
    logger.info(f"  Δ margin:     {deltas.mean():+.3f} ± {deltas.std():.3f}")
    logger.info(f"  % improved:   {(deltas > 0).mean():.1%}")
    logger.info(f"  S>0 init/final: {(margins_init > 0).mean():.1%} / "
                f"{(margins_final > 0).mean():.1%}")
    logger.info(f"{'='*60}")


if __name__ == '__main__':
    main()