"""
Core Q_theta guidance module for PXDesign integration.

Provides differentiable Q_theta scoring for PXDesign's atom coordinate format.
Key responsibilities:
- Extract binder backbone (N, CA, C, O) from PXDesign's flat atom array
- Align binder to reference receptor frames via differentiable Kabsch
- Compute selectivity gradient ∇[Q(holo,Y) - Q(apo,Y)] w.r.t. atom coords
- Works in pxdesign env (PyTorch 2.3.1) using pure-PyTorch scorer (no e3nn)

Usage:
    guidance = QThetaPXDesignGuidance(
        checkpoint='results/checkpoints_cam_v3/best_phase2.pt',
        ref_holo='data/pdbs/cam_holo/3CLN.pdb',
        ref_apo='data/pdbs/cam_apo/1CFD.pdb',
        ref_chain='A',
        device='cuda:0',
    )
    # Inside PXDesign diffusion loop (the call returns the gradient and the
    # current selectivity margin):
    grad, margin = guidance.compute_guidance_gradient(
        x_denoised, input_feature_dict, t_hat)
    x_denoised = x_denoised + scale * grad
"""
|
|
| import os |
| import sys |
| import logging |
| import numpy as np |
| import torch |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
# Directory two levels above this module. It is put at the front of sys.path so
# that the deferred `models.*` / `utils.*` imports used further down resolve
# (presumably the project root that holds those packages -- confirm layout).
_ALLO_CODE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..')
)
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)
|
|
|
|
def differentiable_kabsch(mobile, target):
    """
    Differentiable Kabsch alignment using SVD.

    Finds the proper rotation (determinant +1) and translation that
    superimpose ``mobile`` onto ``target`` in the least-squares sense. Rows of
    the two inputs are treated as paired points.

    Args:
        mobile: [N, 3] tensor (points to align FROM)
        target: [N, 3] tensor (points to align TO)

    Returns:
        R: [3, 3] rotation matrix
        t: [3] translation vector
        Such that aligned = mobile @ R.T + t, which equals
        (mobile - mobile_center) @ R.T + target_center.

    Raises:
        ValueError: if the inputs are not two [N, 3] tensors of equal shape.
    """
    if mobile.dim() != 2 or mobile.shape[-1] != 3 or mobile.shape != target.shape:
        raise ValueError(
            "differentiable_kabsch expects two [N, 3] tensors of equal shape, "
            f"got {tuple(mobile.shape)} and {tuple(target.shape)}"
        )

    mobile_center = mobile.mean(dim=0)
    target_center = target.mean(dim=0)

    # Cross-covariance of the centered point sets.
    H = (mobile - mobile_center).T @ (target - target_center)
    U, _, Vh = torch.linalg.svd(H)
    V = Vh.T

    # If V @ U.T is a reflection (det < 0), flip the axis of the smallest
    # singular value so the result is a proper rotation. The factor is built
    # on-device from existing tensors: this avoids the host round-trip of
    # torch.tensor([... tensor ...]) and is always exactly +1 or -1.
    d = torch.det(V @ U.T)
    one = torch.ones_like(d)
    column_signs = torch.stack([one, one, torch.where(d < 0, -one, one)])

    # Scaling the columns of V is the same as V @ diag(column_signs).
    R = (V * column_signs) @ U.T
    t = target_center - mobile_center @ R.T

    return R, t
|
|
|
|
class QThetaPXDesignGuidance:
    """
    Q_theta guidance for PXDesign diffusion process.

    Lazily initializes the scorer and reference structures on first use.
    Handles extraction of binder backbone from PXDesign's flat atom array
    and alignment to reference receptor frames.
    """

    # Fewest paired receptor CA atoms for which a Kabsch fit is attempted.
    _MIN_ALIGN_CA = 5

    def __init__(self, checkpoint, ref_holo, ref_apo, ref_chain='A',
                 device='cuda:0', cutoff=8.0, esm_target='cam'):
        """
        Store the configuration; nothing is loaded until first use.

        Args:
            checkpoint: path of the Q_theta checkpoint handed to the scorer
            ref_holo: path of the holo reference structure
            ref_apo: path of the apo reference structure
            ref_chain: chain ID of the receptor inside both reference files
            device: torch device string for the scorer and reference tensors
            cutoff: value forwarded as ``cutoff`` to the scorer during guidance
            esm_target: value forwarded as ``esm_target`` when loading receptors
        """
        self.checkpoint = checkpoint
        self.ref_holo = ref_holo
        self.ref_apo = ref_apo
        self.ref_chain = ref_chain
        self.device = torch.device(device)
        self.cutoff = cutoff
        self.esm_target = esm_target

        # Populated by _lazy_init().
        self._initialized = False
        self.dq = None
        self.ref_holo_ca = None
        self.ref_apo_ca = None

    def _load_reference_ca(self, pdb_path):
        """Return (CA coords as a float tensor on self.device, residue count) for the reference chain."""
        from utils.pdb_utils import load_structure, get_residues, get_backbone_coords

        model = load_structure(pdb_path)
        residues = get_residues(model[self.ref_chain])
        coords, _ = get_backbone_coords(residues)
        # Column 1 of the backbone array is taken as CA (N, CA, C, O order).
        ca = torch.from_numpy(coords[:, 1, :]).float().to(self.device)
        return ca, len(residues)

    def _lazy_init(self):
        """Initialize Q_theta scorer and load reference structures."""
        if self._initialized:
            return

        # Imported here so that importing this module does not require the
        # project packages to be importable yet.
        from models.differentiable_features import DifferentiableQTheta

        logger.info(f"Loading Q_theta checkpoint: {self.checkpoint}")
        self.dq = DifferentiableQTheta(self.checkpoint, device=str(self.device))
        for ref_path, label in ((self.ref_holo, 'holo'), (self.ref_apo, 'apo')):
            self.dq.load_receptor(ref_path, chain=self.ref_chain, label=label,
                                  esm_target=self.esm_target)

        # Reference CA traces are the Kabsch targets used for alignment.
        self.ref_holo_ca, n_holo = self._load_reference_ca(self.ref_holo)
        self.ref_apo_ca, n_apo = self._load_reference_ca(self.ref_apo)

        self._initialized = True
        logger.info(f"Q_theta guidance initialized: holo={n_holo} res, apo={n_apo} res")

    @staticmethod
    def _drop_batch_dim(tensor):
        """Squeeze a leading batch dimension when the tensor has more than one dim."""
        return tensor.squeeze(0) if tensor.dim() > 1 else tensor

    @staticmethod
    def _backbone_atom_indices(atom_to_token, tokens):
        """Return a [K, 4] index tensor (first four atoms of each token owning >= 4 atoms), or None."""
        rows = []
        # One host transfer for all token ids instead of one .item() per token.
        for token_id in tokens.tolist():
            atom_idx = torch.where(atom_to_token == token_id)[0]
            if atom_idx.numel() >= 4:
                # Assumes the first four atoms of a token are N, CA, C, O
                # -- TODO confirm against PXDesign's atom ordering.
                rows.append(atom_idx[:4])
        return torch.stack(rows) if rows else None

    def extract_binder_backbone(self, x_coords, input_feature_dict):
        """
        Extract binder backbone atoms (N, CA, C, O) from PXDesign's flat atom array.

        Binder tokens are taken from ``design_token_mask`` when that key is
        present; otherwise they are the tokens whose ``entity_id`` equals the
        largest entity id. All remaining tokens are treated as receptor.
        Tokens with fewer than four atoms are skipped.

        Args:
            x_coords: [N_sample, N_atom, 3] — current coordinates from diffusion
            input_feature_dict: dict with ``atom_to_token_idx`` and either
                ``design_token_mask`` or ``entity_id``; optionally
                ``condition_coordinate`` (directly or inside ``label_dict``)

        Returns:
            None when no binder token has at least four atoms, otherwise a dict:
                'binder_bb': [N_sample, N_binder_res, 4, 3] backbone coords
                'binder_mask': [N_binder_res] bool, all True
                'rec_bb': [N_rec_res, 4, 3] receptor backbone coords taken from
                    the condition coordinates when available, else from
                    sample 0 of ``x_coords`` (detached); None if no receptor
                    token qualifies
                'rec_mask': [N_rec_res] bool, all True (None if 'rec_bb' is None)
                'binder_atom_indices': [N_binder_res, 4] indices into the flat
                    atom array
                'all_binder_atom_indices': the same indices flattened to 1-D
        """
        atom_to_token = self._drop_batch_dim(input_feature_dict['atom_to_token_idx'])

        design_token_mask = input_feature_dict.get('design_token_mask', None)
        if design_token_mask is not None:
            # Cast to bool first: on an integer 0/1 mask `~` is a bitwise NOT
            # (~0 == -1, ~1 == -2), which would mark every token as receptor.
            design_token_mask = self._drop_batch_dim(design_token_mask).bool()
            binder_tokens = torch.where(design_token_mask)[0]
            rec_tokens = torch.where(~design_token_mask)[0]
        else:
            # Fallback: the entity with the largest id is taken as the binder.
            entity_id = self._drop_batch_dim(input_feature_dict['entity_id'])
            max_entity = entity_id.max()
            binder_tokens = torch.where(entity_id == max_entity)[0]
            rec_tokens = torch.where(entity_id != max_entity)[0]

        binder_bb_indices = self._backbone_atom_indices(atom_to_token, binder_tokens)
        if binder_bb_indices is None:
            return None
        binder_bb_indices = binder_bb_indices.to(x_coords.device)

        # Advanced indexing with a [K, 4] index yields [N_sample, K, 4, 3].
        binder_bb = x_coords[:, binder_bb_indices, :]
        binder_mask = torch.ones(binder_bb_indices.shape[0], dtype=torch.bool,
                                 device=x_coords.device)

        # Prefer the fixed condition coordinates for the receptor, if provided.
        cond_coords = input_feature_dict.get('condition_coordinate', None)
        if cond_coords is None:
            label_dict = input_feature_dict.get('label_dict', None)
            if label_dict is not None:
                cond_coords = label_dict.get('condition_coordinate', None)

        rec_bb = None
        rec_mask = None
        rec_bb_indices = self._backbone_atom_indices(atom_to_token, rec_tokens)
        if rec_bb_indices is not None:
            if cond_coords is not None:
                if cond_coords.dim() > 2:
                    cond_coords = cond_coords.squeeze(0)
                rec_bb = cond_coords[rec_bb_indices.to(cond_coords.device), :]
            else:
                # No condition coordinates: fall back to the current receptor
                # atoms of sample 0, cut off from the autograd graph.
                rec_bb = x_coords[0, rec_bb_indices.to(x_coords.device), :].detach()

            rec_mask = torch.ones(rec_bb_indices.shape[0], dtype=torch.bool,
                                  device=x_coords.device)

        return {
            'binder_bb': binder_bb,
            'binder_mask': binder_mask,
            'rec_bb': rec_bb,
            'rec_mask': rec_mask,
            'binder_atom_indices': binder_bb_indices,
            'all_binder_atom_indices': binder_bb_indices.reshape(-1),
        }

    def align_and_score(self, binder_bb, rec_bb, rec_mask, receptor_label):
        """
        Align binder to a reference receptor frame and score with Q_theta.

        Uses the receptor chain from the design to compute Kabsch alignment
        to the reference receptor, then transforms the binder accordingly.
        The first ``min(N_rec, N_ref)`` CA atoms of the two receptors are
        paired by position. The rigid transform is a constant with respect to
        autograd; gradients flow only through ``binder_bb``.

        Args:
            binder_bb: [N_binder, 4, 3] — binder backbone coords (requires_grad)
            rec_bb: [N_rec, 4, 3] — receptor backbone coords
            rec_mask: [N_rec] bool (accepted for interface compatibility;
                not used by the alignment)
            receptor_label: 'holo' selects the holo reference; any other value
                selects the apo reference

        Returns:
            score: scalar tensor, differentiable w.r.t. binder_bb. When fewer
            than five CA pairs are available a zero scalar that is not
            connected to ``binder_bb`` is returned instead.
        """
        ref_ca = self.ref_holo_ca if receptor_label == 'holo' else self.ref_apo_ca
        rec_ca = rec_bb[:, 1, :]

        n_align = min(len(rec_ca), len(ref_ca))
        if n_align < self._MIN_ALIGN_CA:
            return torch.zeros(1, device=binder_bb.device, dtype=binder_bb.dtype,
                               requires_grad=True).squeeze()

        # The transform depends only on fixed receptor coordinates, so no
        # graph needs to be recorded for it.
        with torch.no_grad():
            R, t = differentiable_kabsch(rec_ca[:n_align], ref_ca[:n_align])

        aligned_bb = (binder_bb.reshape(-1, 3) @ R.T + t).reshape(-1, 4, 3)

        binder_mask = torch.ones(aligned_bb.shape[0], dtype=torch.bool,
                                 device=binder_bb.device)
        return self.dq.score(aligned_bb, binder_mask, receptor_label=receptor_label,
                             cutoff=self.cutoff)

    def compute_guidance_gradient(self, x_denoised, input_feature_dict, t_hat=None,
                                  sample_idx=0):
        """
        Compute Q_theta selectivity gradient for guidance.

        For each selected sample the margin ``Q(holo) - Q(apo)`` is
        back-propagated to the binder backbone atoms. Samples whose gradient
        is missing or non-finite, or whose scoring raises, contribute a zero
        gradient and a margin of 0.0.

        Args:
            x_denoised: [N_sample, N_atom, 3] — denoised coordinates from diffusion net
            input_feature_dict: PXDesign input features dict
            t_hat: current noise level (currently unused by this method)
            sample_idx: which sample to compute gradient for (or -1 for all)

        Returns:
            gradient: [N_sample, N_atom, 3] — gradient to add to x_denoised
                (non-zero only at binder backbone atom positions)
            margin: float — selectivity margin averaged over the processed samples
        """
        self._lazy_init()

        extraction = self.extract_binder_backbone(x_denoised.detach(), input_feature_dict)
        if extraction is None or extraction['rec_bb'] is None:
            return torch.zeros_like(x_denoised), 0.0

        binder_bb = extraction['binder_bb']
        rec_bb = extraction['rec_bb'].float()
        rec_mask = extraction['rec_mask']
        binder_atom_indices = extraction['binder_atom_indices'].to(x_denoised.device)

        gradient = torch.zeros_like(x_denoised)
        margins = []

        sample_ids = range(x_denoised.shape[0]) if sample_idx == -1 else [sample_idx]
        for si in sample_ids:
            # Fresh float32 leaf so .grad is populated for this sample only.
            binder_si = binder_bb[si].clone().float().requires_grad_(True)

            try:
                # The sampler may call this under no_grad; re-enable locally.
                with torch.enable_grad():
                    q_holo = self.align_and_score(binder_si, rec_bb, rec_mask, 'holo')
                    q_apo = self.align_and_score(binder_si, rec_bb, rec_mask, 'apo')
                    margin = q_holo - q_apo
                    margin.backward()

                grad_bb = binder_si.grad
                # Reject NaN and +/-inf alike: either would corrupt the sample.
                if grad_bb is not None and bool(torch.isfinite(grad_bb).all()):
                    # One vectorized scatter of [K, 4, 3] into the flat array;
                    # indexed assignment needs matching dtypes, hence the cast.
                    gradient[si, binder_atom_indices] = grad_bb.to(gradient.dtype)
                    margins.append(margin.item())
                else:
                    margins.append(0.0)
            except Exception as e:
                # Guidance is best-effort: a failed sample must not abort sampling.
                logger.debug(f"Gradient computation failed for sample {si}: {e}")
                margins.append(0.0)

        avg_margin = float(np.mean(margins)) if margins else 0.0
        return gradient, avg_margin

    def score_design(self, pdb_path, rec_chain='A', binder_chain='B'):
        """
        Score a single PXDesign output PDB/CIF (post-hoc, no gradient).

        Handles PXDesign CIF files which use chain IDs like 'A0'/'B0' and
        non-standard residue name 'xpb' for designed binder residues: when
        the requested chain IDs are absent, the receptor is the chain whose
        residue count is within 30% of the holo reference and the binder is
        another non-empty chain; failing that, the first two chain IDs in
        sorted order are used.

        Args:
            pdb_path: path of the design file to score
            rec_chain: preferred receptor chain ID
            binder_chain: preferred binder chain ID

        Returns:
            dict with q_holo, q_apo, margin, n_res, or None on failure
            (fewer than two chains, empty chains, fewer than five CA atoms
            to align against either reference, or any exception while
            scoring, which is logged as a warning)
        """
        self._lazy_init()

        from utils.pdb_utils import (
            load_structure, get_residues, get_backbone_coords,
            get_aa_indices, align_structures
        )

        def residues_of(chain):
            """Standard residues of a chain, or all residues if there are none."""
            found = get_residues(chain)
            if not found:
                found = get_residues(chain, only_standard=False)
            return found

        try:
            model = load_structure(pdb_path)
            chains = {c.get_id(): c for c in model.get_chains()}
            if len(chains) < 2:
                return None

            chain_ids = sorted(chains.keys())

            rc, bc = None, None
            if rec_chain in chains and binder_chain in chains:
                rc, bc = rec_chain, binder_chain
            else:
                # Identify the receptor by length similarity to the reference.
                ref_model = load_structure(self.ref_holo)
                ref_len = len(get_residues(ref_model[self.ref_chain]))
                for cid in chain_ids:
                    n_res = len(residues_of(chains[cid]))
                    if n_res > 0 and abs(n_res - ref_len) < ref_len * 0.3:
                        rc = cid
                    elif n_res > 0:
                        bc = cid
                if rc is None or bc is None:
                    rc, bc = chain_ids[0], chain_ids[1]

            rec_res = residues_of(chains[rc])
            binder_res = residues_of(chains[bc])
            if not rec_res or not binder_res:
                return None

            rec_coords, rec_mask = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)

            # Designed residues may carry names the lookup does not know;
            # fall back to index 0 for every position in that case.
            try:
                aa_idx = get_aa_indices(binder_res)
            except Exception:
                aa_idx = np.zeros(len(binder_res), dtype=np.int64)

            device = self.device
            rec_ca = rec_coords[:, 1, :]
            binder_flat = binder_coords.reshape(-1, 3)

            def binder_in_frame(ref_ca_tensor):
                """Binder backbone moved into a reference receptor frame, or None if too few CA pairs."""
                ref_ca_np = ref_ca_tensor.cpu().numpy()
                n_align = min(len(rec_ca), len(ref_ca_np))
                if n_align < self._MIN_ALIGN_CA:
                    return None
                _, rot = align_structures(rec_ca[:n_align], ref_ca_np[:n_align])
                rec_center = rec_ca[:n_align].mean(0)
                ref_center = ref_ca_np[:n_align].mean(0)
                moved = (binder_flat - rec_center) @ rot.T + ref_center
                return moved.reshape(-1, 4, 3)

            aligned_holo = binder_in_frame(self.ref_holo_ca)
            if aligned_holo is None:
                return None
            # Same minimum-pairs guard for apo as for holo.
            aligned_apo = binder_in_frame(self.ref_apo_ca)
            if aligned_apo is None:
                return None

            with torch.no_grad():
                coords_h = torch.from_numpy(aligned_holo).float().to(device)
                coords_a = torch.from_numpy(aligned_apo).float().to(device)
                mask_t = torch.from_numpy(binder_mask).bool().to(device)
                aa_t = torch.from_numpy(np.asarray(aa_idx)).long().to(device)

                # NOTE(review): unlike align_and_score, no cutoff is passed
                # here (the scorer default applies) and binder_aa_idx IS
                # passed -- confirm this asymmetry is intended.
                q_holo = self.dq.score(coords_h, mask_t, binder_aa_idx=aa_t,
                                       receptor_label='holo').item()
                q_apo = self.dq.score(coords_a, mask_t, binder_aa_idx=aa_t,
                                      receptor_label='apo').item()

            return {
                'q_holo': q_holo,
                'q_apo': q_apo,
                'margin': q_holo - q_apo,
                'n_res': len(binder_res),
            }

        except Exception as e:
            logger.warning(f"Error scoring {pdb_path}: {e}")
            return None
|
|