File size: 19,978 Bytes
ad9572d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""
Core Q_theta guidance module for PXDesign integration.

Provides differentiable Q_theta scoring for PXDesign's atom coordinate format.
Key responsibilities:
  - Extract binder backbone (N, CA, C, O) from PXDesign's flat atom array
  - Align binder to reference receptor frames via differentiable Kabsch
  - Compute the selectivity gradient ∇[Q(holo,Y) - Q(apo,Y)] w.r.t. binder backbone atom coords
  - Works in pxdesign env (PyTorch 2.3.1) using pure-PyTorch scorer (no e3nn)

Usage:
    guidance = QThetaPXDesignGuidance(
        checkpoint='results/checkpoints_cam_v3/best_phase2.pt',
        ref_holo='data/pdbs/cam_holo/3CLN.pdb',
        ref_apo='data/pdbs/cam_apo/1CFD.pdb',
        ref_chain='A',
        device='cuda:0',
    )
    # Inside PXDesign diffusion loop:
    grad, margin = guidance.compute_guidance_gradient(x_denoised, input_feature_dict, t_hat)
    x_denoised = x_denoised + scale * grad
"""

import os
import sys
import logging
import numpy as np
import torch

logger = logging.getLogger(__name__)

# Put the Allo-Designer code root (two directories above this file) at the
# front of sys.path so project packages such as `models` and `utils` resolve.
_ALLO_CODE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
if _ALLO_CODE_DIR not in sys.path:
    sys.path.insert(0, _ALLO_CODE_DIR)


def differentiable_kabsch(mobile, target):
    """
    Find the rigid transform that best superimposes ``mobile`` onto ``target``.

    Kabsch algorithm built on ``torch.linalg.svd`` so it can participate in
    autograd.

    Args:
        mobile: [N, 3] tensor of points to be moved.
        target: [N, 3] tensor of the corresponding reference points.

    Returns:
        R: [3, 3] proper rotation matrix (determinant +1).
        t: [3] translation vector.
        Applying them as ``mobile @ R.T + t`` equals
        ``(mobile - mobile.mean(0)) @ R.T + target.mean(0)``.
    """
    src_mean = mobile.mean(dim=0)
    dst_mean = target.mean(dim=0)

    # 3x3 cross-covariance of the two centred point clouds.
    cov = (mobile - src_mean).T @ (target - dst_mean)
    u, _, vh = torch.linalg.svd(cov)
    v = vh.T

    # If the optimal orthogonal map is a reflection (det = -1), flip the last
    # singular axis so the result is always a proper rotation.
    reflect = torch.sign(torch.det(v @ u.T))
    correction = torch.diag(
        torch.tensor([1.0, 1.0, reflect], device=mobile.device, dtype=mobile.dtype)
    )

    rotation = v @ correction @ u.T
    translation = dst_mean - src_mean @ rotation.T
    return rotation, translation


class QThetaPXDesignGuidance:
    """
    Q_theta guidance for the PXDesign diffusion process.

    Lazily initializes the scorer and reference structures on first use.
    Extracts the binder backbone from PXDesign's flat atom array, superimposes
    the design's receptor onto the holo/apo reference receptors, and scores the
    selectivity margin Q(holo) - Q(apo): differentiably during diffusion
    (``compute_guidance_gradient``) or post hoc on finished structures
    (``score_design``).
    """

    def __init__(self, checkpoint, ref_holo, ref_apo, ref_chain='A',
                 device='cuda:0', cutoff=8.0, esm_target='cam'):
        """
        Store configuration; no files are read until first use.

        Args:
            checkpoint: path to the Q_theta checkpoint.
            ref_holo: path to the holo reference receptor structure.
            ref_apo: path to the apo reference receptor structure.
            ref_chain: receptor chain ID in both reference structures.
            device: torch device for the scorer and cached reference tensors.
            cutoff: forwarded as ``cutoff`` to ``dq.score`` during guidance.
            esm_target: forwarded as ``esm_target`` to ``dq.load_receptor``.
        """
        self.checkpoint = checkpoint
        self.ref_holo = ref_holo
        self.ref_apo = ref_apo
        self.ref_chain = ref_chain
        self.device = torch.device(device)
        self.cutoff = cutoff
        self.esm_target = esm_target

        self._initialized = False
        self.dq = None
        self.ref_holo_ca = None     # holo reference CA coords, set by _lazy_init
        self.ref_apo_ca = None      # apo reference CA coords, set by _lazy_init
        self._ref_holo_len = None   # residue count of the holo reference chain

    def _lazy_init(self):
        """Load the Q_theta scorer and the reference receptors (only once)."""
        if self._initialized:
            return

        # Project-local imports are deferred until first use.
        from models.differentiable_features import DifferentiableQTheta
        from utils.pdb_utils import load_structure, get_residues, get_backbone_coords

        logger.info(f"Loading Q_theta checkpoint: {self.checkpoint}")
        self.dq = DifferentiableQTheta(self.checkpoint, device=str(self.device))
        self.dq.load_receptor(self.ref_holo, chain=self.ref_chain, label='holo',
                              esm_target=self.esm_target)
        self.dq.load_receptor(self.ref_apo, chain=self.ref_chain, label='apo',
                              esm_target=self.esm_target)

        # Cache reference CA coords (backbone slot 1 of N, CA, C, O) for alignment.
        holo_model = load_structure(self.ref_holo)
        holo_res = get_residues(holo_model[self.ref_chain])
        holo_coords, _ = get_backbone_coords(holo_res)
        self.ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(self.device)
        # Cached so score_design does not need to re-parse the reference file.
        self._ref_holo_len = len(holo_res)

        apo_model = load_structure(self.ref_apo)
        apo_res = get_residues(apo_model[self.ref_chain])
        apo_coords, _ = get_backbone_coords(apo_res)
        self.ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(self.device)

        self._initialized = True
        logger.info(f"Q_theta guidance initialized: holo={len(holo_res)} res, apo={len(apo_res)} res")

    @staticmethod
    def _backbone_atom_indices(atom_to_token, tokens):
        """
        Return the first four atom indices of every token that has >= 4 atoms.

        For standard amino-acid tokens these four atoms are taken to be
        N, CA, C, O (the atom ordering this module relies on throughout).
        Tokens with fewer than four atoms are skipped.

        Args:
            atom_to_token: [N_atom] token index of each atom.
            tokens: 1-D tensor of token indices, visited in the given order.

        Returns:
            [N_kept, 4] tensor of atom indices, or None if no token qualified.
        """
        rows = []
        # tolist() syncs once instead of calling .item() per token.
        for tok_idx in tokens.tolist():
            atom_indices = torch.where(atom_to_token == tok_idx)[0]
            if len(atom_indices) >= 4:
                rows.append(atom_indices[:4])
        return torch.stack(rows) if rows else None

    def extract_binder_backbone(self, x_coords, input_feature_dict):
        """
        Extract binder and receptor backbone atoms from PXDesign's flat atom array.

        Binder tokens are those with ``design_token_mask`` set. If that key is
        absent, tokens of the highest ``entity_id`` are treated as the binder.
        All other tokens are treated as receptor. For each token with at least
        four atoms, its first four atoms are used as N, CA, C, O.

        Receptor coordinates come from ``condition_coordinate`` (top level or
        under ``label_dict``) when present. Otherwise they are taken, detached,
        from sample 0 of ``x_coords``.

        Args:
            x_coords: [N_sample, N_atom, 3] current coordinates from diffusion.
            input_feature_dict: PXDesign features. Reads ``atom_to_token_idx``,
                ``design_token_mask`` or ``entity_id``, and optionally
                ``condition_coordinate``.

        Returns:
            None if there is no binder token with four atoms; otherwise a dict:
                binder_bb: [N_sample, N_binder, 4, 3] backbone coords (N, CA, C, O)
                binder_mask: [N_binder] bool, all True
                rec_bb: [N_rec, 4, 3] receptor backbone coords, or None
                rec_mask: [N_rec] bool (all True), or None
                binder_atom_indices: [N_binder, 4] indices into the atom array
                all_binder_atom_indices: [N_binder * 4] the same indices, flattened
        """
        atom_to_token = input_feature_dict['atom_to_token_idx']  # [N_atom]
        if atom_to_token.dim() > 1:
            atom_to_token = atom_to_token.squeeze(0)

        design_token_mask = input_feature_dict.get('design_token_mask', None)
        if design_token_mask is not None:
            if design_token_mask.dim() > 1:
                design_token_mask = design_token_mask.squeeze(0)
            # Cast to bool: on an int mask, `~` is bitwise NOT (0 -> -1, 1 -> -2),
            # which would mark every token (binder included) as receptor.
            design_token_mask = design_token_mask.bool()
            binder_tokens = torch.where(design_token_mask)[0]
            rec_tokens = torch.where(~design_token_mask)[0]
        else:
            # Fallback: the highest entity_id is taken to be the binder.
            entity_id = input_feature_dict['entity_id']
            if entity_id.dim() > 1:
                entity_id = entity_id.squeeze(0)
            max_entity = entity_id.max()
            binder_tokens = torch.where(entity_id == max_entity)[0]
            rec_tokens = torch.where(entity_id != max_entity)[0]

        if len(binder_tokens) == 0:
            return None

        binder_bb_indices = self._backbone_atom_indices(atom_to_token, binder_tokens)
        if binder_bb_indices is None:
            return None

        binder_bb = x_coords[:, binder_bb_indices, :]  # [N_sample, N_binder, 4, 3]
        binder_mask = torch.ones(binder_bb_indices.shape[0], dtype=torch.bool,
                                 device=x_coords.device)

        # PXDesign keeps condition_coordinate in label_dict rather than in
        # input_feature_dict, so look in both places before falling back.
        cond_coords = input_feature_dict.get('condition_coordinate', None)
        if cond_coords is None:
            label_dict = input_feature_dict.get('label_dict', None)
            if label_dict is not None:
                cond_coords = label_dict.get('condition_coordinate', None)

        rec_bb = None
        rec_mask = None
        rec_bb_indices = self._backbone_atom_indices(atom_to_token, rec_tokens)
        if rec_bb_indices is not None:
            if cond_coords is not None:
                if cond_coords.dim() > 2:
                    cond_coords = cond_coords.squeeze(0)
                rec_bb = cond_coords[rec_bb_indices, :]  # [N_rec, 4, 3]
            else:
                # Receptor atoms are conditioned and constant across samples,
                # so sample 0 is representative.
                rec_bb = x_coords[0, rec_bb_indices, :].detach()  # [N_rec, 4, 3]
            rec_mask = torch.ones(rec_bb_indices.shape[0], dtype=torch.bool,
                                  device=x_coords.device)

        return {
            'binder_bb': binder_bb,
            'binder_mask': binder_mask,
            'rec_bb': rec_bb,
            'rec_mask': rec_mask,
            'binder_atom_indices': binder_bb_indices,
            'all_binder_atom_indices': binder_bb_indices.reshape(-1),
        }

    def align_and_score(self, binder_bb, rec_bb, rec_mask, receptor_label):
        """
        Move the binder into a reference receptor frame and score it with Q_theta.

        The design's receptor CAs are Kabsch-aligned to the reference receptor
        CAs. Residues are paired by position over the first
        min(N_rec, N_ref) residues, which assumes both chains are numbered in
        the same order (TODO confirm for receptors with missing residues). The
        rigid transform is detached and then applied to the binder, so
        gradients flow only through ``binder_bb``.

        Args:
            binder_bb: [N_binder, 4, 3] binder backbone coords (requires_grad)
            rec_bb: [N_rec, 4, 3] receptor backbone coords
            rec_mask: [N_rec] bool (currently unused)
            receptor_label: 'holo' selects the holo reference; any other
                value selects the apo reference

        Returns:
            Scalar score tensor. If fewer than 5 residues can be paired, this is
            a zero tensor that is not connected to ``binder_bb``.
        """
        ref_ca = self.ref_holo_ca if receptor_label == 'holo' else self.ref_apo_ca

        rec_ca = rec_bb[:, 1, :]  # CA is backbone slot 1 -> [N_rec, 3]

        n_align = min(len(rec_ca), len(ref_ca))
        if n_align < 5:
            return torch.zeros(1, device=binder_bb.device, dtype=binder_bb.dtype,
                               requires_grad=True).squeeze()

        R, t = differentiable_kabsch(rec_ca[:n_align].detach(), ref_ca[:n_align].detach())
        R = R.detach()
        t = t.detach()

        aligned_bb = (binder_bb.reshape(-1, 3) @ R.T + t).reshape(-1, 4, 3)

        binder_mask = torch.ones(aligned_bb.shape[0], dtype=torch.bool,
                                 device=binder_bb.device)
        return self.dq.score(aligned_bb, binder_mask, receptor_label=receptor_label,
                             cutoff=self.cutoff)

    def compute_guidance_gradient(self, x_denoised, input_feature_dict, t_hat=None,
                                  sample_idx=0):
        """
        Compute the Q_theta selectivity gradient d[Q(holo) - Q(apo)]/dx for guidance.

        Failures are best-effort: if scoring raises, or the gradient is missing
        or non-finite for a sample, that sample gets a zero gradient and a
        margin of 0.0.

        Args:
            x_denoised: [N_sample, N_atom, 3] denoised coordinates from diffusion net
            input_feature_dict: PXDesign input features dict
            t_hat: current noise level (currently unused)
            sample_idx: which sample to compute the gradient for, or -1 for all

        Returns:
            gradient: [N_sample, N_atom, 3] gradient to add to x_denoised
                      (non-zero only at binder backbone atom positions)
            margin: float, mean selectivity margin over the processed samples
        """
        self._lazy_init()

        extraction = self.extract_binder_backbone(x_denoised.detach(), input_feature_dict)
        if extraction is None or extraction['rec_bb'] is None:
            return torch.zeros_like(x_denoised), 0.0

        binder_bb = extraction['binder_bb']                     # [N_sample, N_binder, 4, 3]
        rec_bb = extraction['rec_bb'].float()                   # Q_theta scores in float32
        rec_mask = extraction['rec_mask']                       # [N_rec]
        binder_atom_indices = extraction['binder_atom_indices']  # [N_binder, 4]

        gradient = torch.zeros_like(x_denoised)
        margins = []

        indices = range(x_denoised.shape[0]) if sample_idx == -1 else [sample_idx]
        for si in indices:
            # Fresh float32 leaf so this sample's gradient is isolated.
            binder_si = binder_bb[si].clone().float().requires_grad_(True)

            try:
                with torch.enable_grad():
                    q_holo = self.align_and_score(binder_si, rec_bb, rec_mask, 'holo')
                    q_apo = self.align_and_score(binder_si, rec_bb, rec_mask, 'apo')
                    margin = q_holo - q_apo
                    margin.backward()

                grad_bb = binder_si.grad  # [N_binder, 4, 3]
                if grad_bb is None or not torch.isfinite(grad_bb).all():
                    margins.append(0.0)
                    continue

                # Scatter all backbone gradients back into the flat atom array at once.
                gradient[si, binder_atom_indices] = grad_bb.to(gradient.dtype)
                margins.append(margin.item())
            except Exception as e:
                logger.debug(f"Gradient computation failed for sample {si}: {e}")
                margins.append(0.0)

        avg_margin = float(np.mean(margins)) if margins else 0.0
        return gradient, avg_margin

    def score_design(self, pdb_path, rec_chain='A', binder_chain='B'):
        """
        Score a single PXDesign output PDB/CIF (post hoc, no gradient).

        Handles PXDesign CIF files that use chain IDs like 'A0'/'B0' and the
        non-standard residue name 'xpb' for designed binder residues. If the
        requested chain IDs are not both present, the receptor is the chain
        whose residue count is within 30% of the holo reference length, and the
        binder is another non-empty chain. If that fails, the first two
        chain IDs in sorted order are used.

        Returns:
            dict with q_holo, q_apo, margin, n_res; or None on failure or when
            fewer than 5 receptor residues can be paired with either reference.
        """
        self._lazy_init()

        from utils.pdb_utils import (
            load_structure, get_residues, get_backbone_coords,
            get_aa_indices, align_structures
        )

        def residues_of(chain):
            # Prefer standard residues; fall back to all residues (PXDesign
            # binders use the non-standard name 'xpb').
            res = get_residues(chain)
            if not res:
                res = get_residues(chain, only_standard=False)
            return res

        try:
            model = load_structure(pdb_path)
            chains = {c.get_id(): c for c in model.get_chains()}
            if len(chains) < 2:
                return None
            chain_ids = sorted(chains.keys())

            if rec_chain in chains and binder_chain in chains:
                rc, bc = rec_chain, binder_chain
            else:
                rc, bc = None, None
                ref_len = self._ref_holo_len
                for cid in chain_ids:
                    n_res = len(residues_of(chains[cid]))
                    if n_res > 0 and abs(n_res - ref_len) < ref_len * 0.3:
                        rc = cid
                    elif n_res > 0:
                        bc = cid
                if rc is None or bc is None:
                    rc, bc = chain_ids[0], chain_ids[1]

            rec_res = residues_of(chains[rc])
            binder_res = residues_of(chains[bc])
            if not rec_res or not binder_res:
                return None

            rec_coords, _ = get_backbone_coords(rec_res)
            binder_coords, binder_mask = get_backbone_coords(binder_res)

            # Amino-acid indices: non-standard residues (e.g. PXDesign 'xpb')
            # fall back to index 0 for every residue (presumably ALA in the
            # scorer's alphabet - TODO confirm).
            try:
                aa_idx = get_aa_indices(binder_res)
            except Exception:
                aa_idx = np.zeros(len(binder_res), dtype=np.int64)

            rec_ca = rec_coords[:, 1, :]

            def align_binder(ref_ca_np):
                # Superimpose the design receptor onto the reference by its first
                # n paired CAs and carry the binder along. None if < 5 pairs.
                n_align = min(len(rec_ca), len(ref_ca_np))
                if n_align < 5:
                    return None
                _, rot = align_structures(rec_ca[:n_align], ref_ca_np[:n_align])
                center = rec_ca[:n_align].mean(0)
                ref_center = ref_ca_np[:n_align].mean(0)
                aligned = (binder_coords.reshape(-1, 3) - center) @ rot.T + ref_center
                return aligned.reshape(-1, 4, 3)

            aligned_holo = align_binder(self.ref_holo_ca.cpu().numpy())
            if aligned_holo is None:
                return None
            aligned_apo = align_binder(self.ref_apo_ca.cpu().numpy())
            if aligned_apo is None:
                return None

            device = self.device
            with torch.no_grad():
                coords_h = torch.from_numpy(aligned_holo).float().to(device)
                coords_a = torch.from_numpy(aligned_apo).float().to(device)
                mask_t = torch.from_numpy(binder_mask).bool().to(device)
                aa_t = torch.from_numpy(aa_idx).long().to(device)

                q_holo = self.dq.score(coords_h, mask_t, binder_aa_idx=aa_t,
                                       receptor_label='holo').item()
                q_apo = self.dq.score(coords_a, mask_t, binder_aa_idx=aa_t,
                                      receptor_label='apo').item()

            return {
                'q_holo': q_holo,
                'q_apo': q_apo,
                'margin': q_holo - q_apo,
                'n_res': len(binder_res),
            }

        except Exception as e:
            logger.warning(f"Error scoring {pdb_path}: {e}")
            return None