chq1155 and Claude Opus 4.7 (1M context) committed
Commit ad9572d · 0 Parent(s)

AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo

Single commit, no prior history.

Contents:
- Q_theta scorer (graph transformer, SE(3)-invariant + ESM-2 conditioning,
~898K params, v4-S2 target-swap checkpoint via Git LFS).
- PXDesign guidance scripts (Langevin / SMC / TDS / classifier) under
code/scripts/pxdesign_guidance/.
- CaM inference sample (96 binder-CaM graphs + matching ESM-2 features).
- Colab demo at notebooks/AlloGen_CaM_demo.ipynb (one-click for biology users:
load scorer, score 96 designs, view ROC/best-of-K, guidance recipe).
- README with method figure, inference quickstart, full Python scoring API.

MIT license.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +14 -0
  2. .gitignore +41 -0
  3. LICENSE +21 -0
  4. README.md +133 -0
  5. checkpoints/Q_theta_phase1.pt +3 -0
  6. checkpoints/Q_theta_phase2.pt +3 -0
  7. checkpoints/Q_theta_train_curve.csv +16 -0
  8. code/__init__.py +0 -0
  9. code/data/__init__.py +0 -0
  10. code/data/dataset.py +832 -0
  11. code/models/__init__.py +0 -0
  12. code/models/differentiable_features.py +622 -0
  13. code/models/features.py +250 -0
  14. code/models/scorer.py +585 -0
  15. code/requirements.txt +22 -0
  16. code/scripts/README.md +55 -0
  17. code/scripts/evaluate.py +332 -0
  18. code/scripts/pxdesign_guidance/__init__.py +1 -0
  19. code/scripts/pxdesign_guidance/convert_cif_to_pdb.py +132 -0
  20. code/scripts/pxdesign_guidance/guided_pxdesign.py +408 -0
  21. code/scripts/pxdesign_guidance/iterative_refinement.py +338 -0
  22. code/scripts/pxdesign_guidance/langevin_pxdesign.py +374 -0
  23. code/scripts/pxdesign_guidance/qtheta_pxdesign.py +477 -0
  24. code/scripts/pxdesign_guidance/smc_pxdesign.py +262 -0
  25. code/scripts/pxdesign_guidance/tds_pxdesign.py +323 -0
  26. code/scripts/rescore.py +178 -0
  27. code/trainers/__init__.py +0 -0
  28. code/trainers/trainer.py +674 -0
  29. code/utils/__init__.py +0 -0
  30. code/utils/anm.py +208 -0
  31. code/utils/path_utils.py +448 -0
  32. code/utils/pdb_utils.py +472 -0
  33. code/utils/sam.py +54 -0
  34. data/sample/README.md +49 -0
  35. data/sample/cam/test.pkl +3 -0
  36. data/sample/esm2_embeddings/cam/1IWQ_A.pt +3 -0
  37. data/sample/esm2_embeddings/cam/1IWQ_B.pt +3 -0
  38. data/sample/esm2_embeddings/cam/1K93_A.pt +3 -0
  39. data/sample/esm2_embeddings/cam/1K93_B.pt +3 -0
  40. data/sample/esm2_embeddings/cam/1NWD_A.pt +3 -0
  41. data/sample/esm2_embeddings/cam/1NWD_B.pt +3 -0
  42. data/sample/esm2_embeddings/cam/1SY9_A.pt +3 -0
  43. data/sample/esm2_embeddings/cam/1SY9_B.pt +3 -0
  44. data/sample/esm2_embeddings/cam/2BBM_A.pt +3 -0
  45. data/sample/esm2_embeddings/cam/2BBM_B.pt +3 -0
  46. data/sample/esm2_embeddings/cam/2HQW_A.pt +3 -0
  47. data/sample/esm2_embeddings/cam/2HQW_B.pt +3 -0
  48. data/sample/esm2_embeddings/cam/2O5G_A.pt +3 -0
  49. data/sample/esm2_embeddings/cam/2O5G_B.pt +3 -0
  50. data/sample/esm2_embeddings/cam/3D33_A.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1,14 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.svg -text
.gitignore ADDED
@@ -0,0 +1,41 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ .pytest_cache/
+ .coverage
+ *.egg-info/
+
+ # Env
+ .env
+ .venv/
+ venv/
+ env/
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+
+ # Logs / runs / caches
+ *.log
+ logs/
+ outputs/
+ wandb/
+ .ipynb_checkpoints/
+
+ # Local scoring outputs
+ results/
+ /tmp_*
+
+ # Misc
+ *.tmp
+ *.bak
+ .allogen_test
+ .agent*_done.txt
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Hanqun Cao
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,133 @@
+ ---
+ license: mit
+ tags:
+ - protein-design
+ - allosteric
+ - state-selectivity
+ - guided-generation
+ - rfdiffusion
+ - pxdesign
+ - proteina
+ library_name: pytorch
+ ---
+
+ # AlloGen
+
+ <p align="center">
+   <img src="figures/allogen_main.png" alt="AlloGen method overview" width="100%"/>
+ </p>
+
+ State-selectivity scoring + guided generation for allosteric binder design.
+
+ 🧪 **One-click demo for biology users:**
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/#fileId=https%3A//huggingface.co/ChatterjeeLab/AlloGen/raw/main/notebooks/AlloGen_CaM_demo.ipynb): score CaM binders and run Q_θ-guided PXDesign sampling in 5 minutes. Notebook lives at [`notebooks/AlloGen_CaM_demo.ipynb`](notebooks/AlloGen_CaM_demo.ipynb).
+
+ AlloGen trains a scorer Q_θ(X, Y) ∈ (0,1) that ranks how well a binder Y discriminates a target's **holo** (active) state X¹ from its **apo** (inactive) state X⁰. The selectivity score is:
+
+ S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y)
+
+ Q_θ serves as both a re-ranker (best-of-K) and a gradient signal for guided generation on top of frozen priors (RFdiffusion, PXDesign, Proteina-ComplexA) via Langevin, SMC, TDS, or classifier guidance.
+
+ This repository accompanies the paper *AlloGen: State-Selective Scoring for Allosteric Binder Design* (NeurIPS 2026).
+
+ ## Installation
+
+ ```bash
+ conda env create -f environment.yml
+ conda activate allogen
+ ```
+
+ Or pip-only:
+
+ ```bash
+ python -m venv .venv && source .venv/bin/activate
+ pip install -r requirements.txt
+ ```
+
+ Python 3.10 + PyTorch 2.x are required. A CUDA GPU is recommended for guidance, but CPU works for scoring single designs.
+
+ ## Inference quickstart
+
+ ```bash
+ # Score the bundled CaM inference sample against the v4-S2 (target-swap) checkpoint
+ python code/scripts/evaluate.py \
+     --target cam \
+     --checkpoint checkpoints/Q_theta_phase2.pt \
+     --data_dir data/sample/ \
+     --outdir /tmp/cam_inference \
+     --no_wandb
+ ```
+
+ See [`inference.md`](inference.md) for the scoring API + guidance command lines.
+
+ ## Repo layout
+
+ ```
+ code/
+   data/        dataset / graph construction, PDB I/O, target YAMLs
+   models/      Q_θ scorer (graph transformer) + differentiable wrapper
+   trainers/    two-phase training loop (DockQ regression + selectivity)
+   utils/       PDB I/O, backbone frames, SAM optimizer
+   scripts/     evaluate, rescore, PXDesign guidance (see scripts/README.md)
+ checkpoints/   Q_θ paper weights (v4-S2 target-swap split, via Git LFS)
+ data/sample/   tiny CaM inference sample (test split only)
+ ```
+
+ ## Checkpoints
+
+ Paper weights for the **v4-S2 target-swap** split are bundled via **Git LFS**:
+
+ ```bash
+ git lfs install
+ git lfs pull
+ ```
+
+ | File | Use |
+ |---|---|
+ | `checkpoints/Q_theta_phase1.pt` | Phase 1 (DockQ regression) intermediate checkpoint |
+ | `checkpoints/Q_theta_phase2.pt` | Phase 2 (selectivity), main paper result |
+ | `checkpoints/Q_theta_train_curve.csv` | Training curve metadata |
+
+ ## Scoring a single design
+
+ ```python
+ import sys; sys.path.insert(0, 'code')
+ from models.differentiable_features import DifferentiableQTheta
+
+ scorer = DifferentiableQTheta(
+     checkpoint='checkpoints/Q_theta_phase2.pt',
+     device='cuda:0',
+ )
+ scorer.load_receptor(
+     holo_path='your_holo.pdb', rec_chain='A',
+     apo_path='your_apo.pdb', apo_chain='A',
+ )
+ q_holo = scorer.score('design.pdb', binder_chain='B', state='holo')
+ q_apo = scorer.score('design.pdb', binder_chain='B', state='apo')
+ print(f'S = {q_holo - q_apo:.3f}')
+ ```
+
+ ## Guidance methods
+
+ The shipped guidance code wraps **PXDesign** as the prior and uses Q_θ as the gradient / classifier signal. All four method variants (Langevin, SMC, TDS, classifier guidance) live in `code/scripts/pxdesign_guidance/`.
+
+ See [`inference.md`](inference.md) §3 for command lines.
+
+ To deploy Q_θ with **RFdiffusion**, **Proteina-ComplexA**, or any other backbone prior, see [`code/scripts/README.md`](code/scripts/README.md). Q_θ exposes `DifferentiableQTheta` for `∇_x S(x)`, and the PXDesign code is a worked template to mirror.
+
+ ## Citation
+
+ ```bibtex
+ @inproceedings{cao2026allogen,
+   title = {AlloGen: State-Selective Scoring for Allosteric Binder Design},
+   author = {Cao, Hanqun and others},
+   booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
+   year = {2026}
+ }
+ ```
+
+ (BibTeX key will be finalized at camera-ready.)
+
+ ## License
+
+ MIT; see [`LICENSE`](LICENSE).
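
The README defines the selectivity score S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y) and mentions best-of-K re-ranking. A minimal sketch of that arithmetic, assuming the holo and apo scores have already been computed, might look like the following. The function names and the example scores are hypothetical and are not part of the AlloGen API.

```python
from typing import Dict, List, Tuple


def selectivity(q_holo: float, q_apo: float) -> float:
    """S(Y) = Q(X1, Y) - Q(X0, Y); positive values favor the holo state."""
    return q_holo - q_apo


def best_of_k(scores: Dict[str, Tuple[float, float]], k: int) -> List[Tuple[str, float]]:
    """Rank designs by selectivity (highest first) and keep the top k."""
    ranked = sorted(
        ((name, selectivity(q_holo, q_apo)) for name, (q_holo, q_apo) in scores.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    return ranked[:k]


# Hypothetical (q_holo, q_apo) pairs, for illustration only.
example_scores = {
    "design_a": (0.82, 0.31),
    "design_b": (0.64, 0.60),
    "design_c": (0.45, 0.70),
}
top = best_of_k(example_scores, k=2)
print(top)
```

Because both Q_θ values lie in (0, 1), S(Y) lies in (−1, 1); a design with high Q_θ in both states gets a score near zero and ranks below one that binds the holo state only.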
checkpoints/Q_theta_phase1.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1684955f481c1406b12cc0e1ec3509a2a2e2def8b0a9071ec0c96be00d330e7c
+ size 3617774
checkpoints/Q_theta_phase2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:716e4716c0014a46cfd4a2e26c238486c8ad1e3a03809320247cfb734538f8c6
+ size 3618158
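
The two checkpoint entries above are Git LFS pointer files, not the weights themselves: each is three `key value` lines giving the spec version, the SHA-256 object ID, and the size in bytes. As a rough sketch, a pointer can be parsed with plain string handling; `parse_lfs_pointer` is a hypothetical helper written for this note, not something shipped in the repository.

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    hash_algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size_bytes": int(fields["size"]),
    }


# Pointer contents copied from checkpoints/Q_theta_phase1.pt in this commit.
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:1684955f481c1406b12cc0e1ec3509a2a2e2def8b0a9071ec0c96be00d330e7c
size 3617774
"""
info = parse_lfs_pointer(pointer_text)
print(info["hash_algo"], info["size_bytes"])
```

If a checkout still contains this short text file where a multi-megabyte `.pt` is expected, the LFS objects have not been fetched yet.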
checkpoints/Q_theta_train_curve.csv ADDED
@@ -0,0 +1,16 @@
+ epoch,loss,test_rho,cam_rho,rho_cam,rho_bcl2,rho_era,rho_mdm2,rho_ran,rho_a2a,rho_pai1,rho_integrin
+ 1,5.527279376983643,0.44785173068985473,0.48303455690607044,0.48303455690607044,0.5733270876635986,0.6666325520384175,0.655984079559599,0.4957646288057514,0.28072319745279745,0.32040567098387046,0.10694207210873281
+ 2,5.452569961547852,0.45406694660326774,0.4863130719981931,0.4863130719981931,0.5755749254666933,0.6683595793753047,0.6593234065151792,0.49658672598782,0.29627094377326,0.33010796670396186,0.11999895300572927
+ 3,5.449374318122864,0.44742845709394635,0.4966343232141348,0.4966343232141348,0.5758960451528496,0.6612355916106455,0.6531535587702088,0.4753625218665822,0.23926254059823043,0.33456826378855364,0.14331481175036578
+ 4,5.44292688369751,0.4461415166784438,0.49274867569754505,0.49274867569754505,0.5728775201029797,0.6614514700277563,0.6451799702468982,0.46946643826626333,0.242717595336111,0.34013213953325055,0.1445583242167464
+ 5,5.441892623901367,0.43839115930587513,0.49080585193925014,0.49080585193925014,0.5653633194469204,0.6586450506053149,0.6361543905961949,0.4544470120828291,0.21334963006412605,0.33976427997988223,0.1485997397324834
+ 6,5.439353764057159,0.4284122153367078,0.46931336411311275,0.46931336411311275,0.5655559912586142,0.6597244426908693,0.634838828707037,0.4423570277236172,0.19089177426790224,0.3362236317787114,0.1283926621537984
+ 7,5.439265787601471,0.42722475631179213,0.45923496586695794,0.45923496586695794,0.5661982306309269,0.6614514700277563,0.635386597507575,0.4425217082123247,0.20816704795730515,0.3344762989002116,0.1103617313912795
+ 8,5.439032256603241,0.423813755390132,0.45146367083377825,0.45146367083377825,0.5662624545681583,0.6633943757817542,0.6325423252001796,0.4388258604749022,0.2029844658504843,0.33617764933454036,0.09885924107725882
+ 9,5.435689151287079,0.4221364421714867,0.45122081786399143,0.45122081786399143,0.5656202151958456,0.660156199525091,0.6294276063720167,0.4378717142280125,0.19952941111260372,0.33502808823026414,0.09823748484406851
+ 10,5.43562650680542,0.419534099250895,0.4498851265301637,0.4498851265301637,0.5641430646395261,0.662099105279089,0.6279539020257997,0.4329763310874754,0.1822541374232008,0.33530398289529045,0.10165714412661521
+ 11,5.435124695301056,0.41942099211035866,0.451949376773352,0.451949376773352,0.5633723773927508,0.6612355916106455,0.6249514872613453,0.43204285159336153,0.1822541374232008,0.3360397020020272,0.10352241282618613
+ 12,5.440709412097931,0.41971570222256677,0.4518279502884586,0.4518279502884586,0.5642072885767574,0.6612355916106455,0.6261524531671271,0.43317581854869713,0.1822541374232008,0.33534996533946143,0.10352241282618613
+ 13,5.430392742156982,0.4198042517465396,0.4523136562280323,0.4523136562280323,0.5629228098321319,0.6627467405304217,0.6259874349510655,0.4334288217301594,0.1822541374232008,0.33525800045111936,0.10352241282618613
+ 14,5.436960756778717,0.4199312463250485,0.4515850973186718,0.4515850973186718,0.5639503928278323,0.6627467405304217,0.6264251916075623,0.4329262960644863,0.1822541374232008,0.3360397020020272,0.10352241282618613
+ 15,5.433559775352478,0.419866834088955,0.4515850973186718,0.4515850973186718,0.5635650492044447,0.6627467405304217,0.6264251916075623,0.432842324243296,0.1822541374232008,0.3359937195578562,0.10352241282618613
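
One way to inspect a training curve like the CSV above is to load it with the standard `csv` module and pick the epoch that maximizes a chosen column. The sketch below works on a small excerpt (first three columns, values rounded from epochs 1, 2, 3, and 15 of the file); `best_epoch` is a hypothetical helper, not part of the repository.

```python
import csv
import io

# Excerpt of Q_theta_train_curve.csv: first three columns, values rounded.
CURVE_EXCERPT = """epoch,loss,test_rho
1,5.5273,0.4479
2,5.4526,0.4541
3,5.4494,0.4474
15,5.4336,0.4199
"""


def best_epoch(csv_text: str, metric: str = "test_rho"):
    """Return (epoch, value) for the row with the highest value of `metric`."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = max(rows, key=lambda row: float(row[metric]))
    return int(best["epoch"]), float(best[metric])


epoch, rho = best_epoch(CURVE_EXCERPT)
print(f"best epoch by test_rho: {epoch} (value = {rho:.4f})")
```

In this excerpt the highest `test_rho` is at epoch 2 even though the loss keeps decreasing afterwards, so the choice of selection metric matters when reading the curve.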
code/__init__.py ADDED
File without changes
code/data/__init__.py ADDED
File without changes
code/data/dataset.py ADDED
@@ -0,0 +1,832 @@
+ """
+ PyTorch Dataset for two-state complex scoring.
+
+ Loads preprocessed graph data and provides batched tensors
+ with padding for variable-sized interface graphs.
+ """
+
+ import os
+ import json
+ import pickle
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+
+ # Global ESM embedding cache: {file_path: tensor}
+ _ESM_CACHE = {}
+
+
+ def preload_esm_cache(esm_dir, targets):
+     """Preload all ESM .pt files into global cache before DataLoader workers fork.
+
+     This ensures forked workers inherit the populated cache via copy-on-write,
+     avoiding redundant I/O across workers.
+     """
+     import glob as glob_mod
+     n = 0
+     for target in targets:
+         target_dir = os.path.join(esm_dir, target)
+         if not os.path.isdir(target_dir):
+             continue
+         for pt_file in glob_mod.glob(os.path.join(target_dir, '*.pt')):
+             if pt_file not in _ESM_CACHE:
+                 _ESM_CACHE[pt_file] = torch.load(pt_file, map_location='cpu', weights_only=True)
+                 n += 1
+     return n
+
+
+ def load_esm_for_sample(sample, esm_dir, target_name, max_nodes=128):
+     """Load and index ESM-2 embeddings for a sample's interface residues.
+
+     Returns: esm_feats [max_nodes, 1280] or None if unavailable.
+     """
+     graph = sample['graph']
+     rec_idx = graph.get('rec_iface_idx')
+     binder_idx = graph.get('binder_iface_idx')
+     if rec_idx is None or binder_idx is None:
+         return None
+
+     # Get PDB ID (strip chain suffix like "2G1T_AE" -> "2G1T")
+     pdb_id = sample.get('pdb', '')
+     base_pdb = pdb_id.split('_')[0] if '_' in pdb_id else pdb_id
+     rec_chain = sample.get('rec_chain_id', 'A')
+     binder_chain = sample.get('binder_chain_id', 'B')
+
+     # Load ESM embeddings (cached)
+     rec_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{rec_chain}.pt')
+     binder_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{binder_chain}.pt')
+
+     def _load_cached(path):
+         if path not in _ESM_CACHE:
+             if not os.path.exists(path):
+                 return None
+             _ESM_CACHE[path] = torch.load(path, map_location='cpu', weights_only=True)
+         return _ESM_CACHE[path]
+
+     rec_esm = _load_cached(rec_path)
+     binder_esm = _load_cached(binder_path)
+     if rec_esm is None or binder_esm is None:
+         return None
+
+     esm_dim = rec_esm.shape[-1]  # 1280
+     n_rec = len(rec_idx)
+     n_binder = len(binder_idx)
+
+     # Index ESM embeddings by interface residue indices (clamp to valid range)
+     rec_idx_safe = np.clip(rec_idx, 0, len(rec_esm) - 1)
+     binder_idx_safe = np.clip(binder_idx, 0, len(binder_esm) - 1)
+
+     esm_feats = np.zeros((max_nodes, esm_dim), dtype=np.float32)
+     esm_feats[:n_rec] = rec_esm[rec_idx_safe].numpy()
+     esm_feats[n_rec:n_rec + n_binder] = binder_esm[binder_idx_safe].numpy()
+
+     return esm_feats
+
+
+ def load_rosetta_labels(rosetta_dir, target):
+     """Load Rosetta dG labels for a target and normalize to (0, 1)."""
+     path = os.path.join(rosetta_dir, f'{target}_rosetta.json')
+     if not os.path.exists(path):
+         return None
+     with open(path) as f:
+         raw = json.load(f)
+     if not raw:
+         return None
+     # Filter outliers: dG values outside [-500, 500] are failed Rosetta runs
+     dG_MIN, dG_MAX = -500.0, 500.0
+     # Normalize: sigmoid(-dG / tau) maps dG to (0, 1)
+     # More negative dG = better binding = higher score
+     tau = 15.0  # temperature; dG=-30 -> 0.88, dG=-15 -> 0.73, dG=0 -> 0.5
+     labels = {}
+     for pdb_id, metrics in raw.items():
+         dG = metrics.get('dG_separated', 0.0)
+         if not np.isfinite(dG) or dG < dG_MIN or dG > dG_MAX:
+             continue  # skip failed Rosetta runs
+         labels[pdb_id] = 1.0 / (1.0 + np.exp(dG / tau))
+         # Store upper- and lower-case aliases so lookups are case-insensitive
+         labels[pdb_id.upper()] = labels[pdb_id]
+         labels[pdb_id.lower()] = labels[pdb_id]
+     return labels
+
+
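
The normalization in `load_rosetta_labels` is a logistic squashing of the Rosetta dG with temperature tau = 15. As a standalone check of the values quoted in the code comment, the mapping can be reproduced with the standard library alone; `dg_to_label` is a hypothetical name used only for this sketch.

```python
import math


def dg_to_label(dg: float, tau: float = 15.0) -> float:
    """sigmoid(-dg / tau): more negative dG (better binding) gives a higher label."""
    return 1.0 / (1.0 + math.exp(dg / tau))


for dg in (-30.0, -15.0, 0.0):
    print(f"dG = {dg:6.1f} -> label = {dg_to_label(dg):.2f}")
```

The output stays strictly between 0 and 1, and dG = 0 maps to exactly 0.5, so the label separates favorable from unfavorable binding energies around the midpoint.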
+ def apply_rosetta_labels(samples, rosetta_labels, label_source='rosetta', alpha=0.5):
+     """Replace or combine sample labels with Rosetta-derived labels.
+
+     Returns the number of samples whose label was updated.
+     """
+     if rosetta_labels is None:
+         return 0
+     n_replaced = 0
+     for s in samples:
+         pdb_id = s.get('pdb', '')
+         # Strip chain suffixes: "2G1T_AE" -> "2G1T"
+         base_pdb = pdb_id.split('_')[0] if '_' in pdb_id else pdb_id
+         rosetta_val = rosetta_labels.get(base_pdb) or rosetta_labels.get(base_pdb.upper())
+         if rosetta_val is None:
+             continue
+         if s['type'] == 'positive':
+             new_label = rosetta_val
+         elif s['type'].startswith('negative'):
+             continue  # apo mismatch keeps its existing label (0)
+         elif s['type'].startswith('decoy'):
+             # Scale Rosetta label by DockQ-proxy quality
+             new_label = s['label'] * rosetta_val
+         else:
+             continue
+         if label_source == 'rosetta':
+             s['label'] = float(new_label)
+         elif label_source == 'combined':
+             s['label'] = float(alpha * s['label'] + (1 - alpha) * new_label)
+         n_replaced += 1
+     return n_replaced
+
+
+ class TwoStateComplexDataset(Dataset):
+     """
+     Dataset of protein complex interface graphs with two-state labels.
+
+     Each sample contains:
+         node_feats: [N, node_dim] interface residue features
+         edge_feats: [N, N, edge_dim] pairwise SE(3)-invariant features
+         node_mask: [N] bool
+         label: scalar float in [0, 1] (DockQ proxy / selectivity label)
+         type: str (positive / negative_apo / decoy_*)
+         pdb: str
+     """
+
+     def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
+                  rosetta_labels: dict = None, label_source: str = 'dockq',
+                  esm_dir: str = None, target_name: str = None,
+                  binder_dropout: float = 0.0):
+         with open(data_path, 'rb') as f:
+             self.samples = pickle.load(f)
+         self.max_nodes = max_nodes
+         self.augment = augment
+         self.esm_dir = esm_dir
+         self.target_name = target_name
+         self.binder_dropout = binder_dropout
+         if label_source != 'dockq' and rosetta_labels:
+             apply_rosetta_labels(self.samples, rosetta_labels, label_source)
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         sample = self.samples[idx]
+         graph = sample['graph']
+
+         node_feats = graph['node_feats']  # [N, node_dim]
+         edge_feats = graph['edge_feats']  # [N, N, edge_dim]
+         node_mask = graph['node_mask']  # [N]
+
+         N = len(node_feats)
+         assert N <= self.max_nodes, f"Too many nodes: {N} > {self.max_nodes}"
+
+         # Pad to max_nodes
+         node_dim = node_feats.shape[-1]
+         edge_dim = edge_feats.shape[-1]
+
+         node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
+         edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
+         node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
+
+         node_feats_pad[:N] = node_feats
+         edge_feats_pad[:N, :N] = edge_feats
+         node_mask_pad[:N] = node_mask
+
+         # Optional augmentation: Gaussian noise (std 0.01) added to every node
+         # feature, including padded rows (the noise is on features, not coordinates)
+         if self.augment:
+             noise = np.random.randn(*node_feats_pad.shape) * 0.01
+             node_feats_pad = node_feats_pad + noise.astype(np.float32)
+
+         # Binder-dropout: simulate backbone-only designs by masking binder
+         # sequence features (AA one-hot → UNK, chi angles → 0)
+         apply_binder_drop = (self.binder_dropout > 0
+                              and np.random.rand() < self.binder_dropout)
+         if apply_binder_drop:
+             n_rec = graph.get('n_rec', N // 2)
+             # Zero out binder AA one-hot (dims 0-20), set UNK (dim 20 = 1)
+             node_feats_pad[n_rec:N, :21] = 0.0
+             node_feats_pad[n_rec:N, 20] = 1.0  # UNK
+             # Zero out binder chi angles (dims 27-30)
+             node_feats_pad[n_rec:N, 27:31] = 0.0
+             # Keep backbone torsions (dims 21-26) and chain indicator (dim 31)
+
+         result = {
+             'node_feats': torch.from_numpy(node_feats_pad),  # [max_nodes, node_dim]
+             'edge_feats': torch.from_numpy(edge_feats_pad),  # [max_nodes, max_nodes, edge_dim]
+             'node_mask': torch.from_numpy(node_mask_pad),  # [max_nodes]
+             'label': torch.tensor(sample['label'], dtype=torch.float32),
+             'type': sample['type'],
+             'pdb': sample['pdb'],
+         }
+
+         # ESM-2 features (lazy load; zero-fill if unavailable)
+         if self.esm_dir:
+             esm = load_esm_for_sample(sample, self.esm_dir,
+                                       self.target_name or '', self.max_nodes)
+             if esm is not None:
+                 esm_feats = esm
+             else:
+                 esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
+             # Zero binder ESM if binder-dropout active
+             if apply_binder_drop:
+                 n_rec = graph.get('n_rec', N // 2)
+                 n_binder = graph.get('n_binder', N - n_rec)
+                 esm_feats[n_rec:n_rec + n_binder] = 0.0
+             result['esm_feats'] = torch.from_numpy(esm_feats)
+
+         return result
+
+
+ def collate_fn(batch):
+     """Collate a list of samples into batched tensors."""
+     node_feats = torch.stack([s['node_feats'] for s in batch])
+     edge_feats = torch.stack([s['edge_feats'] for s in batch])
+     node_mask = torch.stack([s['node_mask'] for s in batch])
+     labels = torch.stack([s['label'] for s in batch])
+     types = [s['type'] for s in batch]
+     pdbs = [s['pdb'] for s in batch]
+
+     result = {
+         'node_feats': node_feats,  # [B, N, node_dim]
+         'edge_feats': edge_feats,  # [B, N, N, edge_dim]
+         'node_mask': node_mask,  # [B, N]
+         'label': labels,  # [B]
+         'type': types,
+         'pdb': pdbs,
+     }
+
+     # Stack ESM features if present (handle mixed availability with zero-fill)
+     has_esm = any('esm_feats' in s for s in batch)
+     if has_esm:
+         esm_list = []
+         for s in batch:
+             if 'esm_feats' in s:
+                 esm_list.append(s['esm_feats'])
+             else:
+                 # Get shape from a sample that has ESM
+                 ref = next(x['esm_feats'] for x in batch if 'esm_feats' in x)
+                 esm_list.append(torch.zeros_like(ref))
+         result['esm_feats'] = torch.stack(esm_list)
+
+     return result
+
+
+ class TwoStateDatasetPaired(Dataset):
+     """
+     Paired dataset: returns (positive, negative) pairs for selectivity training.
+     Groups samples by PDB ID and pairs positive (holo) with negative (apo) examples.
+     """
+
+     def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
+                  esm_dir: str = None, target_name: str = None,
+                  binder_dropout: float = 0.0):
+         with open(data_path, 'rb') as f:
+             samples = pickle.load(f)
+         self.max_nodes = max_nodes
+         self.augment = augment
+         self.esm_dir = esm_dir
+         self.target_name = target_name
+         self.binder_dropout = binder_dropout
+
+         # Group by PDB
+         from collections import defaultdict
+         by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
+         for s in samples:
+             pdb = s['pdb']
+             t = s['type']
+             if t == 'positive':
+                 by_pdb[pdb]['positive'].append(s)
+             elif t.startswith('negative'):
+                 by_pdb[pdb]['negative'].append(s)
+             elif t.startswith('decoy'):
+                 by_pdb[pdb]['decoy'].append(s)
+
+         # Build pairs: (positive, negative) per PDB
+         self.pairs = []
+         for pdb, groups in by_pdb.items():
+             if len(groups['positive']) > 0 and len(groups['negative']) > 0:
+                 for pos in groups['positive']:
+                     for neg in groups['negative']:
+                         self.pairs.append((pos, neg))
+             # Also add (positive, decoy_large_rmsd) pairs
+             if len(groups['positive']) > 0 and len(groups['decoy']) > 0:
+                 large_decoys = [s for s in groups['decoy'] if 'rmsd' in s['type'] and
+                                 float(s['type'].replace('decoy_rmsd', '')) > 4.0]
+                 for pos in groups['positive']:
+                     for neg in large_decoys[:3]:  # limit to 3 hard decoys per positive
+                         self.pairs.append((pos, neg))
+
+     def __len__(self):
+         return len(self.pairs)
+
+     def _prepare(self, sample, apply_binder_drop=False):
+         graph = sample['graph']
+         node_feats = graph['node_feats']
+         edge_feats = graph['edge_feats']
+         node_mask = graph['node_mask']
+         N = len(node_feats)
+         node_dim = node_feats.shape[-1]
+         edge_dim = edge_feats.shape[-1]
+
+         node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
+         edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
+         node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
+
+         n = min(N, self.max_nodes)
+         node_feats_pad[:n] = node_feats[:n]
+         edge_feats_pad[:n, :n] = edge_feats[:n, :n]
+         node_mask_pad[:n] = node_mask[:n]
+
+         # Binder-dropout: simulate backbone-only designs
+         if apply_binder_drop:
+             n_rec = graph.get('n_rec', n // 2)
+             node_feats_pad[n_rec:n, :21] = 0.0
+             node_feats_pad[n_rec:n, 20] = 1.0  # UNK
+             node_feats_pad[n_rec:n, 27:31] = 0.0
+
+         result = {
+             'node_feats': torch.from_numpy(node_feats_pad),
+             'edge_feats': torch.from_numpy(edge_feats_pad),
+             'node_mask': torch.from_numpy(node_mask_pad),
+             'label': torch.tensor(sample['label'], dtype=torch.float32),
+             'contact_energy': torch.tensor(
+                 sample.get('contact_energy', 0.5), dtype=torch.float32
+             ),
+         }
+
+         # ESM-2 features (zero-fill if unavailable)
+         if self.esm_dir:
+             esm = load_esm_for_sample(sample, self.esm_dir,
+                                       self.target_name or '', self.max_nodes)
+             if esm is not None:
+                 esm_feats = esm
+             else:
+                 esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
+             if apply_binder_drop:
+                 n_rec = graph.get('n_rec', n // 2)
+                 n_binder = graph.get('n_binder', n - n_rec)
+                 esm_feats[n_rec:n_rec + n_binder] = 0.0
+             result['esm_feats'] = torch.from_numpy(esm_feats)
+
+         return result
+
+     def __getitem__(self, idx):
+         pos_sample, neg_sample = self.pairs[idx]
+         # Same dropout decision for both pos and neg in a pair
+         drop = (self.binder_dropout > 0
+                 and np.random.rand() < self.binder_dropout)
+         return {
+             'pos': self._prepare(pos_sample, apply_binder_drop=drop),
+             'neg': self._prepare(neg_sample, apply_binder_drop=drop),
+         }
+
+
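
The pairing logic in `TwoStateDatasetPaired.__init__` can be read in isolation as: group samples by PDB ID, pair every positive with every negative, then add up to three decoys whose RMSD exceeds 4.0 as extra negatives. The sketch below mirrors that logic on plain dictionaries. `build_pairs` and the toy samples are hypothetical, and the `decoy_rmsd<value>` type format is an assumption inferred from the string parsing in the original code.

```python
from collections import defaultdict


def build_pairs(samples, rmsd_cutoff=4.0, max_decoys=3):
    """Pair each positive with its negatives and hard decoys, per PDB ID."""
    by_pdb = defaultdict(lambda: {"positive": [], "negative": [], "decoy": []})
    for s in samples:
        t = s["type"]
        if t == "positive":
            by_pdb[s["pdb"]]["positive"].append(s)
        elif t.startswith("negative"):
            by_pdb[s["pdb"]]["negative"].append(s)
        elif t.startswith("decoy"):
            by_pdb[s["pdb"]]["decoy"].append(s)

    pairs = []
    for groups in by_pdb.values():
        # (positive, negative) pairs
        for pos in groups["positive"]:
            for neg in groups["negative"]:
                pairs.append((pos, neg))
        # (positive, large-RMSD decoy) pairs; assumes types like "decoy_rmsd6.5"
        hard = [d for d in groups["decoy"]
                if "rmsd" in d["type"]
                and float(d["type"].replace("decoy_rmsd", "")) > rmsd_cutoff]
        for pos in groups["positive"]:
            for neg in hard[:max_decoys]:
                pairs.append((pos, neg))
    return pairs


# Toy samples for illustration only.
toy = [
    {"pdb": "1ABC", "type": "positive"},
    {"pdb": "1ABC", "type": "negative_apo"},
    {"pdb": "1ABC", "type": "decoy_rmsd2.0"},
    {"pdb": "1ABC", "type": "decoy_rmsd6.5"},
    {"pdb": "2XYZ", "type": "positive"},
]
pairs = build_pairs(toy)
print([(p["type"], n["type"]) for p, n in pairs])
```

A PDB entry with a positive but no negative or large-RMSD decoy (here `2XYZ`) contributes no pairs, which is why the paired dataset can be much smaller than the unpaired one.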
+ def collate_paired_fn(batch):
+     """Collate paired (positive, negative) samples."""
+     pos_batch = {
+         'node_feats': torch.stack([s['pos']['node_feats'] for s in batch]),
+         'edge_feats': torch.stack([s['pos']['edge_feats'] for s in batch]),
+         'node_mask': torch.stack([s['pos']['node_mask'] for s in batch]),
+         'label': torch.stack([s['pos']['label'] for s in batch]),
+         'contact_energy': torch.stack([s['pos']['contact_energy'] for s in batch]),
+     }
+     neg_batch = {
+         'node_feats': torch.stack([s['neg']['node_feats'] for s in batch]),
+         'edge_feats': torch.stack([s['neg']['edge_feats'] for s in batch]),
+         'node_mask': torch.stack([s['neg']['node_mask'] for s in batch]),
+         'label': torch.stack([s['neg']['label'] for s in batch]),
+         'contact_energy': torch.stack([s['neg']['contact_energy'] for s in batch]),
+     }
+     # ESM features (handle mixed availability)
+     has_pos_esm = any('esm_feats' in s['pos'] for s in batch)
+     if has_pos_esm:
+         def _stack_esm(batch_list, key):
+             esm_list = []
+             ref = next((x[key]['esm_feats'] for x in batch_list if 'esm_feats' in x[key]), None)
+             for s in batch_list:
+                 if 'esm_feats' in s[key]:
+                     esm_list.append(s[key]['esm_feats'])
+                 else:
+                     esm_list.append(torch.zeros_like(ref))
+             return torch.stack(esm_list)
+         pos_batch['esm_feats'] = _stack_esm(batch, 'pos')
+         neg_batch['esm_feats'] = _stack_esm(batch, 'neg')
+     return {'pos': pos_batch, 'neg': neg_batch}
+
+
416
+ class PathAwareDatasetPaired(Dataset):
417
+ """
418
+ Paired dataset with transition-path frames for path-aware Phase 2 training.
419
+
420
+ Extends TwoStateDatasetPaired: each sample returns (positive, negative, path_frames)
421
+ where path_frames is a list of prepared graph dicts for intermediate conformations
422
+ stored in the positive sample's 'path_graphs' field.
423
+ """
424
+
425
+ def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False):
426
+ with open(data_path, 'rb') as f:
427
+ samples = pickle.load(f)
428
+ self.max_nodes = max_nodes
429
+ self.augment = augment
430
+
431
+ from collections import defaultdict
432
+ by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
433
+ for s in samples:
434
+ pdb = s['pdb']
435
+ t = s['type']
436
+ if t == 'positive':
437
+ by_pdb[pdb]['positive'].append(s)
438
+ elif t.startswith('negative'):
439
+ by_pdb[pdb]['negative'].append(s)
440
+ elif t.startswith('decoy'):
441
+ by_pdb[pdb]['decoy'].append(s)
442
+
443
+ self.pairs = []
444
+ for pdb, groups in by_pdb.items():
445
+ if len(groups['positive']) > 0 and len(groups['negative']) > 0:
446
+ for pos in groups['positive']:
447
+ for neg in groups['negative']:
448
+ self.pairs.append((pos, neg))
449
+ if len(groups['positive']) > 0 and len(groups['decoy']) > 0:
450
+ large_decoys = [s for s in groups['decoy'] if 'rmsd' in s['type'] and
451
+ float(s['type'].replace('decoy_rmsd', '')) > 4.0]
452
+ for pos in groups['positive']:
453
+ for neg in large_decoys[:3]:
454
+ self.pairs.append((pos, neg))
455
+
456
+ def _prepare(self, sample):
457
+ graph = sample['graph']
458
+ node_feats = graph['node_feats']
459
+ edge_feats = graph['edge_feats']
460
+ node_mask = graph['node_mask']
461
+ N = len(node_feats)
462
+ node_dim = node_feats.shape[-1]
463
+ edge_dim = edge_feats.shape[-1]
464
+
465
+ node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
466
+ edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
467
+ node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
468
+
469
+ n = min(N, self.max_nodes)
470
+ node_feats_pad[:n] = node_feats[:n]
471
+ edge_feats_pad[:n, :n] = edge_feats[:n, :n]
472
+ node_mask_pad[:n] = node_mask[:n]
473
+
474
+ return {
475
+ 'node_feats': torch.from_numpy(node_feats_pad),
476
+ 'edge_feats': torch.from_numpy(edge_feats_pad),
477
+ 'node_mask': torch.from_numpy(node_mask_pad),
478
+ 'label': torch.tensor(sample.get('label', 0.0), dtype=torch.float32),
479
+ 'contact_energy': torch.tensor(
480
+ sample.get('contact_energy', 0.5), dtype=torch.float32
481
+ ),
482
+ }
483
+
484
+ def _prepare_graph_only(self, path_entry):
485
+ """Prepare a path frame graph (no label/contact_energy needed)."""
486
+ graph = path_entry['graph']
487
+ node_feats = graph['node_feats']
488
+ edge_feats = graph['edge_feats']
489
+ node_mask = graph['node_mask']
490
+ N = len(node_feats)
491
+ node_dim = node_feats.shape[-1]
492
+ edge_dim = edge_feats.shape[-1]
493
+
494
+ node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
495
+ edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
496
+ node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
497
+
498
+ n = min(N, self.max_nodes)
499
+ node_feats_pad[:n] = node_feats[:n]
500
+ edge_feats_pad[:n, :n] = edge_feats[:n, :n]
501
+ node_mask_pad[:n] = node_mask[:n]
502
+
503
+ return {
504
+ 'node_feats': torch.from_numpy(node_feats_pad),
505
+ 'edge_feats': torch.from_numpy(edge_feats_pad),
506
+ 'node_mask': torch.from_numpy(node_mask_pad),
507
+ }
508
+
509
+ def __len__(self):
510
+ return len(self.pairs)
511
+
512
+ def __getitem__(self, idx):
513
+ pos_sample, neg_sample = self.pairs[idx]
514
+ result = {
515
+ 'pos': self._prepare(pos_sample),
516
+ 'neg': self._prepare(neg_sample),
517
+ }
518
+
519
+ # Prepare path frames if available
520
+ path_graphs = pos_sample.get('path_graphs', [])
521
+ prepared_paths = []
522
+ path_taus = []
523
+ for pg in path_graphs:
524
+ prepared_paths.append(self._prepare_graph_only(pg))
525
+ path_taus.append(pg['tau'])
526
+
527
+ result['path'] = prepared_paths
528
+ result['path_taus'] = path_taus
529
+
530
+ return result
531
+
532
+
533
+ def collate_path_paired_fn(batch):
534
+ """Collate paired samples with variable-length path frames."""
535
+ pos_batch = {
536
+ 'node_feats': torch.stack([s['pos']['node_feats'] for s in batch]),
537
+ 'edge_feats': torch.stack([s['pos']['edge_feats'] for s in batch]),
538
+ 'node_mask': torch.stack([s['pos']['node_mask'] for s in batch]),
539
+ 'label': torch.stack([s['pos']['label'] for s in batch]),
540
+ 'contact_energy': torch.stack([s['pos']['contact_energy'] for s in batch]),
541
+ }
542
+ neg_batch = {
543
+ 'node_feats': torch.stack([s['neg']['node_feats'] for s in batch]),
544
+ 'edge_feats': torch.stack([s['neg']['edge_feats'] for s in batch]),
545
+ 'node_mask': torch.stack([s['neg']['node_mask'] for s in batch]),
546
+ 'label': torch.stack([s['neg']['label'] for s in batch]),
547
+ 'contact_energy': torch.stack([s['neg']['contact_energy'] for s in batch]),
548
+ }
549
+
550
+ # Collate path frames: find max K across batch, pad shorter ones
551
+ max_k = max((len(s['path']) for s in batch), default=0)
552
+ path_batches = []
553
+ path_taus = []
554
+
555
+ if max_k > 0:
556
+ # Build a zero-filled placeholder for padding (graph-only keys)
557
+ ref = batch[0]['path'][0] if batch[0]['path'] else batch[0]['pos']
558
+ zero_placeholder = {
559
+ 'node_feats': torch.zeros_like(ref['node_feats']),
560
+ 'edge_feats': torch.zeros_like(ref['edge_feats']),
561
+ 'node_mask': torch.zeros_like(ref['node_mask']),
562
+ }
563
+
564
+ for k_idx in range(max_k):
565
+ frames_at_k = []
566
+ taus_at_k = []
567
+ for s in batch:
568
+ if k_idx < len(s['path']):
569
+ frames_at_k.append(s['path'][k_idx])
570
+ taus_at_k.append(s['path_taus'][k_idx])
571
+ else:
572
+ frames_at_k.append(zero_placeholder)
573
+ taus_at_k.append(1.0)
574
+
575
+ path_batches.append({
576
+ 'node_feats': torch.stack([f['node_feats'] for f in frames_at_k]),
577
+ 'edge_feats': torch.stack([f['edge_feats'] for f in frames_at_k]),
578
+ 'node_mask': torch.stack([f['node_mask'] for f in frames_at_k]),
579
+ })
580
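+ path_taus.append(next((t for s, t in zip(batch, taus_at_k) if k_idx < len(s['path'])), 1.0)) # tau from the first sample with a real frame, not a padded slot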
581
+
582
+ result = {'pos': pos_batch, 'neg': neg_batch}
583
+ if path_batches:
584
+ result['path'] = path_batches
585
+ result['path_taus'] = path_taus
586
+ return result
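The padding scheme in collate_path_paired_fn can be illustrated without tensors. The following is a minimal sketch in plain Python, assuming that all samples share one tau schedule; `pad_path_frames`, `pad_frame`, and `pad_tau` are hypothetical names used only for illustration and are not part of the repository.

```python
def pad_path_frames(paths, taus, pad_frame=None, pad_tau=1.0):
    """Regroup per-sample frame lists into per-index groups, padding short ones."""
    max_k = max((len(p) for p in paths), default=0)
    frames_by_k, tau_by_k = [], []
    for k in range(max_k):
        # Samples with fewer than k+1 frames contribute a placeholder frame.
        frames_by_k.append([p[k] if k < len(p) else pad_frame for p in paths])
        # Assumption: one shared tau schedule, so take tau from the first
        # sample that really has frame k; fall back to pad_tau otherwise.
        tau_by_k.append(
            next((t[k] for p, t in zip(paths, taus) if k < len(p)), pad_tau)
        )
    return frames_by_k, tau_by_k


# Toy batch: sample 0 has one path frame, sample 1 has three.
paths = [["a0"], ["b0", "b1", "b2"]]
taus = [[0.25], [0.25, 0.5, 0.75]]
frames_by_k, tau_by_k = pad_path_frames(paths, taus)
```

Each entry of `frames_by_k` has one slot per sample, so a downstream loop can stack index k across the batch and skip placeholder slots via the all-False node mask.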
587
+
588
+
589
+ class MultiTargetDataset(Dataset):
590
+ """
591
+ Pooled dataset combining samples from multiple targets.
592
+ Supports balanced sampling across targets.
593
+ """
594
+
595
+ def __init__(self, data_paths: list, max_nodes: int = 128, augment: bool = False,
596
+ balance: bool = True, rosetta_dir: str = None, label_source: str = 'dockq',
597
+ esm_dir: str = None, binder_dropout: float = 0.0):
598
+ """
599
+ Args:
600
+ data_paths: list of (target_name, pkl_path) tuples
601
+ max_nodes: max interface graph size
602
+ augment: apply noise augmentation
603
+ balance: if True, oversample smaller targets to balance
604
+ rosetta_dir: directory containing Rosetta label JSONs
605
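+ label_source: 'dockq', 'rosetta', or 'combined'
+ esm_dir: directory with precomputed ESM-2 embeddings (zero-filled when a sample has none)
+ binder_dropout: per-sample probability of masking binder identity features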
606
+ """
607
+ self.max_nodes = max_nodes
608
+ self.augment = augment
609
+ self.esm_dir = esm_dir
610
+ self.binder_dropout = binder_dropout
611
+
612
+ # Load all samples with target labels
613
+ self.samples = []
614
+ self.target_indices = {} # target_name -> list of indices
615
+
616
+ for target_name, path in data_paths:
617
+ if not os.path.exists(path):
618
+ continue
619
+ with open(path, 'rb') as f:
620
+ target_samples = pickle.load(f)
621
+
622
+ # Apply Rosetta labels if requested
623
+ if label_source != 'dockq' and rosetta_dir:
624
+ rl = load_rosetta_labels(rosetta_dir, target_name)
625
+ if rl:
626
+ apply_rosetta_labels(target_samples, rl, label_source)
627
+
628
+ start_idx = len(self.samples)
629
+ for s in target_samples:
630
+ s['_target'] = target_name
631
+ self.samples.append(s)
632
+ end_idx = len(self.samples)
633
+ self.target_indices[target_name] = list(range(start_idx, end_idx))
634
+
635
+ # Build balanced sampling weights
636
+ if balance and len(self.target_indices) > 1:
637
+ non_empty = {k: v for k, v in self.target_indices.items() if len(v) > 0}
638
+ max_count = max(len(idxs) for idxs in non_empty.values()) if non_empty else 1
639
+ self.weights = np.zeros(len(self.samples))
640
+ for target_name, idxs in self.target_indices.items():
641
+ if len(idxs) == 0:
642
+ continue
643
+ weight = max_count / len(idxs)
644
+ for i in idxs:
645
+ self.weights[i] = weight
646
+ self.weights /= self.weights.sum()
647
+ else:
648
+ self.weights = None
649
+
650
+ def __len__(self):
651
+ return len(self.samples)
652
+
653
+ def __getitem__(self, idx):
654
+ sample = self.samples[idx]
655
+ graph = sample['graph']
656
+ node_feats = graph['node_feats']
657
+ edge_feats = graph['edge_feats']
658
+ node_mask = graph['node_mask']
659
+ N = len(node_feats)
660
+ node_dim = node_feats.shape[-1]
661
+ edge_dim = edge_feats.shape[-1]
662
+
663
+ node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
664
+ edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
665
+ node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
666
+
667
+ n = min(N, self.max_nodes)
668
+ node_feats_pad[:n] = node_feats[:n]
669
+ edge_feats_pad[:n, :n] = edge_feats[:n, :n]
670
+ node_mask_pad[:n] = node_mask[:n]
671
+
672
+ if self.augment:
673
+ noise = np.random.randn(*node_feats_pad.shape) * 0.01
674
+ node_feats_pad = node_feats_pad + noise.astype(np.float32)
675
+
676
+ # Binder-dropout: simulate backbone-only designs
677
+ apply_binder_drop = (self.binder_dropout > 0
678
+ and np.random.rand() < self.binder_dropout)
679
+ if apply_binder_drop:
680
+ n_rec = graph.get('n_rec', N // 2)
681
+ node_feats_pad[n_rec:N, :21] = 0.0
682
+ node_feats_pad[n_rec:N, 20] = 1.0 # UNK
683
+ node_feats_pad[n_rec:N, 27:31] = 0.0
684
+
685
+ result = {
686
+ 'node_feats': torch.from_numpy(node_feats_pad),
687
+ 'edge_feats': torch.from_numpy(edge_feats_pad),
688
+ 'node_mask': torch.from_numpy(node_mask_pad),
689
+ 'label': torch.tensor(sample['label'], dtype=torch.float32),
690
+ 'type': sample['type'],
691
+ 'pdb': sample['pdb'],
692
+ 'target': sample.get('_target', 'unknown'),
693
+ }
694
+
695
+ # ESM-2 features (zero-fill if unavailable)
696
+ if self.esm_dir:
697
+ target_name = sample.get('_target', 'unknown')
698
+ esm = load_esm_for_sample(sample, self.esm_dir, target_name, self.max_nodes)
699
+ if esm is not None:
700
+ esm_feats = esm
701
+ else:
702
+ esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
703
+ if apply_binder_drop:
704
+ n_rec = graph.get('n_rec', N // 2)
705
+ n_binder = graph.get('n_binder', N - n_rec)
706
+ esm_feats[n_rec:n_rec + n_binder] = 0.0
707
+ result['esm_feats'] = torch.from_numpy(esm_feats)
708
+
709
+ return result
710
+
711
+ @staticmethod
712
+ def get_pooled_dataloaders(data_dir, targets, batch_size=16, max_nodes=128,
713
+ num_workers=4, paired=False,
714
+ rosetta_dir=None, label_source='dockq',
715
+ esm_dir=None, binder_dropout=0.0):
716
+ """Build pooled dataloaders from multiple targets.
717
+
718
+ Args:
719
+ data_dir: root data directory
720
+ targets: list of target names
721
+ batch_size: batch size
722
+ max_nodes: max interface nodes
723
+ num_workers: dataloader workers
724
+ paired: if True, build paired dataloaders for Phase 2
725
+ rosetta_dir: directory with Rosetta label JSONs
726
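+ label_source: 'dockq', 'rosetta', or 'combined'
+ esm_dir: directory with precomputed ESM-2 embeddings (preloaded into the cache)
+ binder_dropout: binder-dropout probability (applied to the train split only)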
727
+ """
728
+ from torch.utils.data import WeightedRandomSampler
729
+
730
+ # Preload ESM embeddings into global cache before creating datasets/workers
731
+ if esm_dir:
732
+ preload_esm_cache(esm_dir, targets)
733
+
734
+ loaders = {}
735
+ for split in ['train', 'val', 'test']:
736
+ data_paths = []
737
+ for target in targets:
738
+ path = os.path.join(data_dir, target, f"{split}.pkl")
739
+ if os.path.exists(path):
740
+ data_paths.append((target, path))
741
+
742
+ if not data_paths:
743
+ continue
744
+
745
+ augment = (split == 'train')
746
+ bd = binder_dropout if split == 'train' else 0.0
747
+
748
+ if paired:
749
+ # For paired mode, concatenate paired datasets
750
+ all_pairs = []
751
+ for target, path in data_paths:
752
+ ds = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=augment,
753
+ esm_dir=esm_dir, target_name=target,
754
+ binder_dropout=bd)
755
+ all_pairs.append(ds)
756
+
757
+ if not all_pairs:
758
+ continue
759
+
760
+ # Use ConcatDataset
761
+ from torch.utils.data import ConcatDataset
762
+ concat_ds = ConcatDataset(all_pairs)
763
+ p_batch = min(batch_size, max(1, len(concat_ds) // 2))
764
+ loaders[split] = DataLoader(
765
+ concat_ds, batch_size=p_batch,
766
+ shuffle=(split == 'train'),
767
+ num_workers=num_workers,
768
+ collate_fn=collate_paired_fn,
769
+ pin_memory=True,
770
+ )
771
+ else:
772
+ dataset = MultiTargetDataset(data_paths, max_nodes=max_nodes,
773
+ augment=augment, balance=(split == 'train'),
774
+ rosetta_dir=rosetta_dir, label_source=label_source,
775
+ esm_dir=esm_dir, binder_dropout=bd)
776
+
777
+ sampler = None
778
+ shuffle = (split == 'train')
779
+ if split == 'train' and dataset.weights is not None:
780
+ sampler = WeightedRandomSampler(
781
+ weights=dataset.weights,
782
+ num_samples=len(dataset),
783
+ replacement=True
784
+ )
785
+ shuffle = False
786
+
787
+ loaders[split] = DataLoader(
788
+ dataset, batch_size=batch_size,
789
+ shuffle=shuffle, sampler=sampler,
790
+ num_workers=num_workers,
791
+ collate_fn=collate_fn,
792
+ pin_memory=True,
793
+ drop_last=(split == 'train' and len(dataset) > batch_size),
794
+ )
795
+
796
+ return loaders
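The balanced weights built in MultiTargetDataset.__init__ and handed to WeightedRandomSampler above can be checked on toy data. This is a minimal sketch in plain Python that assumes at least one non-empty target; `balanced_weights` and the target names are hypothetical and only mirror the arithmetic of the constructor.

```python
def balanced_weights(target_indices):
    """Per-sample weights so that every non-empty target gets equal total mass."""
    sizes = [len(idxs) for idxs in target_indices.values() if idxs]
    max_count = max(sizes)
    n_samples = sum(len(idxs) for idxs in target_indices.values())
    weights = [0.0] * n_samples
    for idxs in target_indices.values():
        for i in idxs:
            # Smaller targets get proportionally larger per-sample weights.
            weights[i] = max_count / len(idxs)
    total = sum(weights)
    return [w / total for w in weights]


# Two targets: one with 2 samples, one with 6.
target_indices = {"small": [0, 1], "large": [2, 3, 4, 5, 6, 7]}
weights = balanced_weights(target_indices)
mass = {t: sum(weights[i] for i in idxs) for t, idxs in target_indices.items()}
```

With these weights each target is drawn with the same overall probability, which is the effect the sampler is meant to achieve when target sizes differ.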
797
+
798
+
799
+ def get_dataloaders(data_dir: str, target: str, batch_size: int = 16,
800
+ max_nodes: int = 128, num_workers: int = 4,
801
+ paired: bool = False, esm_dir: str = None,
802
+ binder_dropout: float = 0.0):
803
+ """Build train/val/test dataloaders for a given target."""
804
+ loaders = {}
805
+ for split in ['train', 'val', 'test']:
806
+ path = os.path.join(data_dir, target, f"{split}.pkl")
807
+ if not os.path.exists(path):
808
+ continue
809
+
810
+ augment = (split == 'train')
811
+ bd = binder_dropout if split == 'train' else 0.0
812
+ if paired and split == 'train':
813
+ dataset = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=augment,
814
+ esm_dir=esm_dir, target_name=target,
815
+ binder_dropout=bd)
816
+ collate = collate_paired_fn
817
+ else:
818
+ dataset = TwoStateComplexDataset(path, max_nodes=max_nodes, augment=augment,
819
+ esm_dir=esm_dir, target_name=target,
820
+ binder_dropout=bd)
821
+ collate = collate_fn
822
+
823
+ loaders[split] = DataLoader(
824
+ dataset,
825
+ batch_size=batch_size,
826
+ shuffle=(split == 'train'),
827
+ num_workers=num_workers,
828
+ collate_fn=collate,
829
+ pin_memory=True,
830
+ drop_last=(split == 'train' and len(dataset) > batch_size),
831
+ )
832
+ return loaders
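Binder dropout, which the loaders above enable only for the training split, hides the binder's sequence identity so the scorer also sees backbone-only designs. A minimal NumPy sketch of the masking step follows, assuming the 32-column node layout AA(21) + torsions(6) + chi(4) + chain(1); `mask_binder_identity` is a hypothetical helper, not a function from the repository.

```python
import numpy as np

NUM_AA = 21                 # 20 amino acids plus UNK at index 20
CHI_SLICE = slice(27, 31)   # chi1/chi2 sin/cos columns


def mask_binder_identity(node_feats, n_rec, n_total):
    """Return a copy with binder rows reset to UNK and chi angles zeroed."""
    out = node_feats.copy()
    out[n_rec:n_total, :NUM_AA] = 0.0        # clear the amino-acid one-hot
    out[n_rec:n_total, NUM_AA - 1] = 1.0     # mark every binder residue as UNK
    out[n_rec:n_total, CHI_SLICE] = 0.0      # side-chain angles are unknown
    return out


rng = np.random.default_rng(0)
feats = rng.random((6, 32)).astype(np.float32)   # 4 receptor + 2 binder rows
masked = mask_binder_identity(feats, n_rec=4, n_total=6)
```

Backbone torsions (columns 21 to 26) and the chain flag (column 31) are left untouched, so the geometric signal of the binder is preserved.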
code/models/__init__.py ADDED
File without changes
code/models/differentiable_features.py ADDED
@@ -0,0 +1,622 @@
1
+ """
2
+ Differentiable feature extraction for Q_theta guidance.
3
+
4
+ This module re-implements the key feature extraction functions from features.py
5
+ and pdb_utils.py using PyTorch operations, enabling gradient computation through
6
+ Q_theta with respect to backbone coordinates.
7
+
8
+ The differentiable path:
9
+ coords (N,4,3) → backbone frames → torsions, distances, directions, rotations
10
+ → node_feats, edge_feats → Q_theta → score → backward() → ∇coords
11
+
12
+ Non-differentiable features (AA one-hot, chain_id, seq_sep, same_chain) are
13
+ treated as constants.
14
+ """
15
+
16
+ import os
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import numpy as np
20
+
21
+
22
+ # ── Differentiable backbone frame computation ────────────────────────────────
23
+
24
+ def compute_backbone_frames_torch(coords, mask):
25
+ """
26
+ Compute SE(3)-equivariant backbone frames from N, CA, C atoms.
27
+ Differentiable w.r.t. coords.
28
+
29
+ Args:
30
+ coords: [N, 4, 3] backbone coords (N, CA, C, O) — requires_grad=True for binder
31
+ mask: [N] bool tensor
32
+
33
+ Returns:
34
+ origins: [N, 3] = CA positions
35
+ rotations: [N, 3, 3] = rotation matrices (columns are x, y, z axes)
36
+ """
37
+ N_res = coords.shape[0]
38
+ device = coords.device
39
+
40
+ origins = coords[:, 1, :] # CA positions [N, 3]
41
+ rotations = torch.eye(3, device=device, dtype=coords.dtype).unsqueeze(0).expand(N_res, -1, -1).clone()
42
+
43
+ ca = coords[:, 1, :] # [N, 3]
44
+ n_atom = coords[:, 0, :] # [N, 3]
45
+ c_atom = coords[:, 2, :] # [N, 3]
46
+
47
+ # z-axis: CA -> C
48
+ z = c_atom - ca # [N, 3]
49
+ z_norm = torch.norm(z, dim=-1, keepdim=True).clamp(min=1e-6) # [N, 1]
50
+ z = z / z_norm # [N, 3]
51
+
52
+ # y-axis: CA -> N, orthogonalized against z
53
+ y = n_atom - ca # [N, 3]
54
+ y_proj = (y * z).sum(dim=-1, keepdim=True) # [N, 1]
55
+ y = y - y_proj * z # [N, 3]
56
+ y_norm = torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-6) # [N, 1]
57
+ y = y / y_norm # [N, 3]
58
+
59
+ # x-axis: y cross z
60
+ x = torch.cross(y, z, dim=-1) # [N, 3]
61
+
62
+ # Stack columns: [N, 3, 3] where columns are x, y, z
63
+ rot = torch.stack([x, y, z], dim=-1) # [N, 3, 3]
64
+
65
+ # Apply mask: identity for masked residues
66
+ mask_f = mask.float().unsqueeze(-1).unsqueeze(-1) # [N, 1, 1]
67
+ eye = torch.eye(3, device=device, dtype=coords.dtype).unsqueeze(0) # [1, 3, 3]
68
+ rotations = rot * mask_f + eye * (1 - mask_f)
69
+
70
+ return origins, rotations
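The Gram-Schmidt construction used by compute_backbone_frames_torch can be verified for a single residue. The following is a minimal NumPy sketch (not the PyTorch code itself); `backbone_frame` and the atom coordinates are made up for illustration.

```python
import numpy as np


def backbone_frame(n, ca, c, eps=1e-6):
    """Build a right-handed frame at CA with columns (x, y, z)."""
    z = c - ca
    z = z / max(np.linalg.norm(z), eps)        # z-axis: CA -> C
    y = n - ca
    y = y - np.dot(y, z) * z                   # remove the component along z
    y = y / max(np.linalg.norm(y), eps)        # y-axis: CA -> N, orthogonalized
    x = np.cross(y, z)                         # x-axis completes the frame
    return ca, np.stack([x, y, z], axis=-1)


# Hypothetical atom positions for one residue.
n_atom = np.array([-0.5, 1.4, 0.0])
ca_atom = np.zeros(3)
c_atom = np.array([1.5, 0.0, 0.3])
origin, R = backbone_frame(n_atom, ca_atom, c_atom)
```

Because x is defined as y cross z, the resulting matrix is a proper rotation (determinant +1) whenever N, CA, and C are not collinear.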
71
+
72
+
73
+ # ── Differentiable torsion angle computation ─────────────────────────────────
74
+
75
+ def _dihedral_torch(p0, p1, p2, p3):
76
+ """
77
+ Compute dihedral angle for batches of 4 points. Returns sin, cos.
78
+ Differentiable w.r.t. all inputs.
79
+
80
+ Args:
81
+ p0, p1, p2, p3: [N, 3] tensors
82
+
83
+ Returns:
84
+ sin_angle: [N]
85
+ cos_angle: [N]
86
+ """
87
+ b1 = p1 - p0 # [N, 3]
88
+ b2 = p2 - p1
89
+ b3 = p3 - p2
90
+
91
+ n1 = torch.cross(b1, b2, dim=-1) # [N, 3]
92
+ n2 = torch.cross(b2, b3, dim=-1)
93
+
94
+ n1_norm = torch.norm(n1, dim=-1, keepdim=True).clamp(min=1e-8)
95
+ n2_norm = torch.norm(n2, dim=-1, keepdim=True).clamp(min=1e-8)
96
+ n1 = n1 / n1_norm
97
+ n2 = n2 / n2_norm
98
+
99
+ b2_norm = torch.norm(b2, dim=-1, keepdim=True).clamp(min=1e-8)
100
+ m1 = torch.cross(n1, b2 / b2_norm, dim=-1) # [N, 3]
101
+
102
+ cos_angle = (n1 * n2).sum(dim=-1) # [N]
103
+ sin_angle = (m1 * n2).sum(dim=-1) # [N]
104
+
105
+ return sin_angle, cos_angle
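The sin/cos dihedral formula above can be sanity-checked on planar and perpendicular arrangements. This is a minimal NumPy sketch of the same cross-product construction for a single set of four points; `dihedral_sin_cos` and the test points are hypothetical.

```python
import numpy as np


def _unit(v, eps=1e-8):
    """Normalize a vector, guarding against zero length."""
    return v / max(np.linalg.norm(v), eps)


def dihedral_sin_cos(p0, p1, p2, p3):
    """Return (sin, cos) of the dihedral angle defined by four points."""
    b1, b2, b3 = p1 - p0, p2 - p1, p3 - p2
    n1 = _unit(np.cross(b1, b2))        # normal of the first plane
    n2 = _unit(np.cross(b2, b3))        # normal of the second plane
    m1 = np.cross(n1, _unit(b2))        # in-plane vector orthogonal to n1
    return float(np.dot(m1, n2)), float(np.dot(n1, n2))


p0 = np.array([0.0, 1.0, 0.0])
p1 = np.array([0.0, 0.0, 0.0])
p2 = np.array([1.0, 0.0, 0.0])
sin_cis, cos_cis = dihedral_sin_cos(p0, p1, p2, np.array([1.0, 1.0, 0.0]))
sin_trans, cos_trans = dihedral_sin_cos(p0, p1, p2, np.array([1.0, -1.0, 0.0]))
sin_perp, cos_perp = dihedral_sin_cos(p0, p1, p2, np.array([1.0, 0.0, 1.0]))
```

Returning the pair (sin, cos) rather than the angle avoids the discontinuity at plus or minus 180 degrees, which matters when the features are differentiated.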
106
+
107
+
108
+ def compute_torsion_angles_torch(coords, mask):
109
+ """
110
+ Compute backbone torsion angles (phi, psi, omega) as sin/cos pairs.
111
+ Differentiable w.r.t. coords.
112
+
113
+ Args:
114
+ coords: [N, 4, 3] backbone coords (N, CA, C, O)
115
+ mask: [N] bool tensor
116
+
117
+ Returns:
118
+ torsions: [N, 6] (sin_phi, cos_phi, sin_psi, cos_psi, sin_omega, cos_omega)
119
+ """
120
+ N = coords.shape[0]
121
+ device = coords.device
122
+ torsions = torch.zeros(N, 6, device=device, dtype=coords.dtype)
123
+
124
+ if N < 2:
125
+ return torsions
126
+
127
+ n_atoms = coords[:, 0, :] # N atoms [N, 3]
128
+ ca_atoms = coords[:, 1, :] # CA atoms
129
+ c_atoms = coords[:, 2, :] # C atoms
130
+
131
+ # Phi: C_{i-1} - N_i - CA_i - C_i (for i >= 1)
132
+ if N > 1:
133
+ phi_mask = mask[1:] & mask[:-1] # [N-1]
134
+ sin_phi, cos_phi = _dihedral_torch(
135
+ c_atoms[:-1], # C_{i-1}
136
+ n_atoms[1:], # N_i
137
+ ca_atoms[1:], # CA_i
138
+ c_atoms[1:] # C_i
139
+ )
140
+ torsions[1:, 0] = sin_phi * phi_mask.float()
141
+ torsions[1:, 1] = cos_phi * phi_mask.float()
142
+
143
+ # Psi: N_i - CA_i - C_i - N_{i+1} (for i < N-1)
144
+ if N > 1:
145
+ psi_mask = mask[:-1] & mask[1:] # [N-1]
146
+ sin_psi, cos_psi = _dihedral_torch(
147
+ n_atoms[:-1], # N_i
148
+ ca_atoms[:-1], # CA_i
149
+ c_atoms[:-1], # C_i
150
+ n_atoms[1:] # N_{i+1}
151
+ )
152
+ torsions[:-1, 2] = sin_psi * psi_mask.float()
153
+ torsions[:-1, 3] = cos_psi * psi_mask.float()
154
+
155
+ # Omega: CA_{i-1} - C_{i-1} - N_i - CA_i (for i >= 1)
156
+ if N > 1:
157
+ omega_mask = mask[1:] & mask[:-1] # [N-1]
158
+ sin_omega, cos_omega = _dihedral_torch(
159
+ ca_atoms[:-1], # CA_{i-1}
160
+ c_atoms[:-1], # C_{i-1}
161
+ n_atoms[1:], # N_i
162
+ ca_atoms[1:] # CA_i
163
+ )
164
+ torsions[1:, 4] = sin_omega * omega_mask.float()
165
+ torsions[1:, 5] = cos_omega * omega_mask.float()
166
+
167
+ return torsions
168
+
169
+
170
+ # ── Differentiable RBF distance encoding ─────────────────────────────────────
171
+
172
+ def rbf_encode_torch(distances, d_min=0.0, d_max=20.0, n_bins=16):
173
+ """
174
+ RBF encoding of distances using Gaussian basis functions.
175
+ Differentiable w.r.t. distances.
176
+
177
+ Args:
178
+ distances: [...] tensor
179
+ Returns:
180
+ encoded: [..., n_bins] tensor
181
+ """
182
+ centers = torch.linspace(d_min, d_max, n_bins, device=distances.device, dtype=distances.dtype)
183
+ sigma = (d_max - d_min) / (n_bins - 1)
184
+ return torch.exp(-((distances.unsqueeze(-1) - centers) ** 2) / (2 * sigma ** 2))
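A NumPy analogue of the encoding above shows what the defaults produce: 16 Gaussian centers spaced 20/15 Å apart, with the width equal to the spacing. This is a sketch for intuition only; `rbf_encode` here is a stand-in and not the PyTorch function.

```python
import numpy as np


def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """Encode distances with Gaussian basis functions on evenly spaced centers."""
    distances = np.asarray(distances, dtype=np.float64)
    centers = np.linspace(d_min, d_max, n_bins)
    sigma = (d_max - d_min) / (n_bins - 1)     # width equals the center spacing
    return np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))


# 0 Å, 8 Å, and 20 Å each coincide with a center (indices 0, 6, and 15).
enc = rbf_encode(np.array([0.0, 8.0, 20.0]))
```

A distance that sits exactly on a center activates that bin at 1 and its immediate neighbors at exp(-1/2), so adjacent bins overlap smoothly and the encoding stays differentiable in the distance.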
185
+
186
+
187
+ # ── Differentiable edge feature computation ──────────────────────────────────
188
+
189
+ def compute_edge_features_torch(origins, rotations, seq_idx, chain_ids, mask,
190
+ n_bins_rbf=16, n_bins_sep=8, max_sep=32):
191
+ """
192
+ Compute SE(3)-invariant edge features between all residue pairs.
193
+ Differentiable w.r.t. origins and rotations (which derive from coords).
194
+
195
+ Args:
196
+ origins: [N, 3] CA positions
197
+ rotations: [N, 3, 3] backbone frame rotations
198
+ seq_idx: [N] int tensor — sequence indices (non-differentiable)
199
+ chain_ids: [N] int tensor — chain labels (non-differentiable)
200
+ mask: [N] bool tensor
201
+
202
+ Returns:
203
+ edge_feats: [N, N, 37]
204
+ """
205
+ N = origins.shape[0]
206
+ device = origins.device
207
+ dtype = origins.dtype
208
+
209
+ # --- Distance features (differentiable) ---
210
+ diff = origins.unsqueeze(1) - origins.unsqueeze(0) # [N, N, 3]
211
+ dist = torch.norm(diff, dim=-1).clamp(min=1e-8) # [N, N]
212
+ dist_rbf = rbf_encode_torch(dist, d_min=0., d_max=20., n_bins=n_bins_rbf) # [N, N, 16]
213
+
214
+ # --- Direction in local frame (differentiable) ---
215
+ unit_diff = diff / dist.unsqueeze(-1) # [N, N, 3]
216
+ # local_dir[i,j] = R_i^T @ (ca_j - ca_i) / dist
217
+ # rotations: [N, 3, 3], unit_diff: [N, N, 3]
218
+ local_dir = torch.einsum('ikl,ijl->ijk', rotations, unit_diff) # [N, N, 3]
219
+
220
+ # --- Relative rotation (differentiable) ---
221
+ # rel_rot[i,j] = R_i^T @ R_j -> [N, N, 3, 3] -> flatten to [N, N, 9]
222
+ rel_rot = torch.einsum('ikl,jlm->ijkm', rotations, rotations) # [N, N, 3, 3]
223
+ rel_rot_flat = rel_rot.reshape(N, N, 9) # [N, N, 9]
224
+
225
+ # --- Sequence separation (non-differentiable, constant) ---
226
+ sep = seq_idx.unsqueeze(1) - seq_idx.unsqueeze(0) # [N, N]
227
+ bins = torch.linspace(-max_sep, max_sep, n_bins_sep + 1, device=device)
228
+ sep_clipped = sep.float().clamp(-max_sep, max_sep)
229
+ # Hard one-hot binning; sequence separation is a constant, so no soft assignment is needed
230
+ sep_enc = torch.zeros(N, N, n_bins_sep, device=device, dtype=dtype)
231
+ bin_idx = torch.bucketize(sep_clipped, bins) - 1
232
+ bin_idx = bin_idx.clamp(0, n_bins_sep - 1)
233
+ # Scatter one-hot
234
+ sep_enc.scatter_(2, bin_idx.unsqueeze(-1).long(), 1.0)
235
+
236
+ # Cross-chain pairs get an all-zero separation encoding
237
+ same_chain = (chain_ids.unsqueeze(1) == chain_ids.unsqueeze(0)) # [N, N]
238
+ cross_chain = ~same_chain
239
+ sep_enc[cross_chain] = 0.0
240
+
241
+ # --- Same chain indicator (non-differentiable, constant) ---
242
+ same_chain_feat = same_chain.float().unsqueeze(-1) # [N, N, 1]
243
+
244
+ # --- Concatenate ---
245
+ edge_feats = torch.cat([
246
+ dist_rbf, # [N, N, 16]
247
+ local_dir, # [N, N, 3]
248
+ rel_rot_flat, # [N, N, 9]
249
+ sep_enc, # [N, N, 8]
250
+ same_chain_feat # [N, N, 1]
251
+ ], dim=-1) # [N, N, 37]
252
+
253
+ # Zero out edges involving masked residues
254
+ mask_2d = mask.unsqueeze(1) & mask.unsqueeze(0) # [N, N]
255
+ edge_feats = edge_feats * mask_2d.unsqueeze(-1).float()
256
+
257
+ return edge_feats
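The invariance property claimed for the local-direction feature can be tested numerically. The following is a minimal NumPy sketch, assuming the convention local_dir[i,j] = R_i^T (ca_j - ca_i) / dist with frame axes stored as columns; it illustrates the property and is not a copy of the function above. `frames` and `local_directions` are hypothetical helpers.

```python
import numpy as np


def frames(coords, eps=1e-6):
    """coords: [N, 3, 3] holding N, CA, C atoms. Returns origins and rotations."""
    n, ca, c = coords[:, 0], coords[:, 1], coords[:, 2]
    z = c - ca
    z = z / np.maximum(np.linalg.norm(z, axis=-1, keepdims=True), eps)
    y = n - ca
    y = y - (y * z).sum(-1, keepdims=True) * z
    y = y / np.maximum(np.linalg.norm(y, axis=-1, keepdims=True), eps)
    x = np.cross(y, z)
    return ca, np.stack([x, y, z], axis=-1)    # columns are x, y, z


def local_directions(origins, rotations, eps=1e-8):
    """Unit vectors from residue i to residue j, expressed in frame i."""
    diff = origins[None, :, :] - origins[:, None, :]     # diff[i, j] = ca_j - ca_i
    dist = np.maximum(np.linalg.norm(diff, axis=-1, keepdims=True), eps)
    unit = diff / dist
    # out[i, j, k] = sum_l R[i, l, k] * unit[i, j, l] = (R_i^T @ unit_ij)[k]
    return np.einsum('ilk,ijl->ijk', rotations, unit)


rng = np.random.default_rng(0)
coords = rng.normal(size=(5, 3, 3))

# Random proper rotation (det = +1) plus a translation.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(q) < 0:
    q[:, 0] *= -1.0
moved = coords @ q.T + np.array([1.0, -2.0, 3.0])

before = local_directions(*frames(coords))
after = local_directions(*frames(moved))
```

Under a global rotation Q both the frames and the displacement vectors pick up a factor Q, and R_i^T Q^T Q u = R_i^T u, so the two arrays should agree to numerical precision. The same check can be run against any other contraction to see whether it preserves the property.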
258
+
259
+
260
+ # ── Full differentiable interface graph builder ──────────────────────────────
261
+
262
+ def build_differentiable_interface_graph(
263
+ rec_coords, rec_mask, rec_aa_idx, rec_chi,
264
+ binder_coords, binder_mask, binder_aa_idx, binder_chi,
265
+ cutoff=8.0, max_nodes=128
266
+ ):
267
+ """
268
+ Build interface graph with differentiable features w.r.t. binder_coords.
269
+ Receptor coords are treated as constants (detached).
270
+
271
+ Args:
272
+ rec_coords: [N_rec, 4, 3] — receptor backbone coords (constant, no grad)
273
+ rec_mask: [N_rec] bool
274
+ rec_aa_idx: [N_rec] int — amino acid indices (constant)
275
+ rec_chi: [N_rec, 4] — chi1/chi2 sin/cos (constant)
276
+ binder_coords: [N_binder, 4, 3] — binder backbone coords (requires_grad)
277
+ binder_mask: [N_binder] bool
278
+ binder_aa_idx: [N_binder] int — amino acid indices (constant, UNK for designed)
279
+ binder_chi: [N_binder, 4] — chi1/chi2 sin/cos (zeros for backbone-only)
280
+ cutoff: interface distance cutoff (Å)
281
+ max_nodes: maximum total nodes in the graph (each chain is truncated to max_nodes // 2)
282
+
283
+ Returns (as a dict):
284
+ node_feats: [1, N_total, 32] tensor
285
+ edge_feats: [1, N_total, N_total, 37] tensor
286
+ node_mask: [1, N_total] bool tensor
287
+ n_rec: int
288
+ n_binder: int
289
+ or None if no interface
290
+ """
291
+ device = binder_coords.device
292
+ dtype = binder_coords.dtype
293
+ NUM_AA = 21
294
+
295
+ # ── Find interface residues (differentiable distances but hard threshold) ──
296
+ rec_ca = rec_coords[:, 1, :] # [N_rec, 3]
297
+ binder_ca = binder_coords[:, 1, :] # [N_binder, 3]
298
+
299
+ # Pairwise CA distances
300
+ dist_mat = torch.cdist(rec_ca.unsqueeze(0), binder_ca.unsqueeze(0)).squeeze(0) # [N_rec, N_binder]
301
+ # Mask invalid residues
302
+ dist_mat = dist_mat.clone()
303
+ dist_mat[~rec_mask, :] = float('inf')
304
+ dist_mat[:, ~binder_mask] = float('inf')
305
+
306
+ rec_iface = (dist_mat < cutoff).any(dim=1) # [N_rec]
307
+ binder_iface = (dist_mat < cutoff).any(dim=0) # [N_binder]
308
+
309
+ rec_iface_idx = torch.where(rec_iface)[0]
310
+ binder_iface_idx = torch.where(binder_iface)[0]
311
+
312
+ # Truncate if too many
313
+ if len(rec_iface_idx) > max_nodes // 2:
314
+ rec_iface_idx = rec_iface_idx[:max_nodes // 2]
315
+ if len(binder_iface_idx) > max_nodes // 2:
316
+ binder_iface_idx = binder_iface_idx[:max_nodes // 2]
317
+
318
+ n_rec = len(rec_iface_idx)
319
+ n_binder = len(binder_iface_idx)
320
+ n_total = n_rec + n_binder
321
+
322
+ if n_total == 0:
323
+ return None
324
+
325
+ # ── Extract interface subsets ──
326
+ rec_iface_coords = rec_coords[rec_iface_idx] # [n_rec, 4, 3]
327
+ binder_iface_coords = binder_coords[binder_iface_idx] # [n_binder, 4, 3]
328
+ rec_iface_mask = rec_mask[rec_iface_idx]
329
+ binder_iface_mask = binder_mask[binder_iface_idx]
330
+
331
+ # ── Compute backbone frames (differentiable) ──
332
+ rec_origins, rec_rotations = compute_backbone_frames_torch(rec_iface_coords, rec_iface_mask)
333
+ binder_origins, binder_rotations = compute_backbone_frames_torch(binder_iface_coords, binder_iface_mask)
334
+
335
+ # ── Compute torsion angles (differentiable) ──
336
+ rec_torsion = compute_torsion_angles_torch(rec_iface_coords, rec_iface_mask) # [n_rec, 6]
337
+ binder_torsion = compute_torsion_angles_torch(binder_iface_coords, binder_iface_mask) # [n_binder, 6]
338
+
339
+ # ── Node features ──
340
+ # AA one-hot (non-differentiable constant)
341
+ rec_aa_onehot = F.one_hot(rec_aa_idx[rec_iface_idx].long(), NUM_AA).float() # [n_rec, 21]
342
+ binder_aa_onehot = F.one_hot(binder_aa_idx[binder_iface_idx].long(), NUM_AA).float() # [n_binder, 21]
343
+
344
+ # Chi angles (constant for receptor, zeros for backbone-only binder)
345
+ rec_chi_iface = rec_chi[rec_iface_idx] # [n_rec, 4]
346
+ binder_chi_iface = binder_chi[binder_iface_idx] # [n_binder, 4]
347
+
348
+ # Chain indicator
349
+ rec_chain_feat = torch.zeros(n_rec, 1, device=device, dtype=dtype)
350
+ binder_chain_feat = torch.ones(n_binder, 1, device=device, dtype=dtype)
351
+
352
+ # Concatenate node features: [AA(21) + torsions(6) + chi(4) + chain(1)] = 32
353
+ rec_node = torch.cat([rec_aa_onehot, rec_torsion, rec_chi_iface, rec_chain_feat], dim=-1)
354
+ binder_node = torch.cat([binder_aa_onehot, binder_torsion, binder_chi_iface, binder_chain_feat], dim=-1)
355
+ node_feats = torch.cat([rec_node, binder_node], dim=0) # [N_total, 32]
356
+ node_mask_flat = torch.cat([rec_iface_mask, binder_iface_mask], dim=0) # [N_total]
357
+
358
+ # ── Edge features (differentiable) ──
359
+ all_origins = torch.cat([rec_origins, binder_origins], dim=0) # [N_total, 3]
360
+ all_rotations = torch.cat([rec_rotations, binder_rotations], dim=0) # [N_total, 3, 3]
361
+
362
+ # Sequence indices
363
+ rec_seq_idx = rec_iface_idx
364
+ binder_seq_idx = binder_iface_idx + rec_coords.shape[0]
365
+ all_seq_idx = torch.cat([rec_seq_idx, binder_seq_idx], dim=0)
366
+
367
+ # Chain IDs
368
+ all_chain_ids = torch.cat([
369
+ torch.zeros(n_rec, device=device, dtype=torch.long),
370
+ torch.ones(n_binder, device=device, dtype=torch.long)
371
+ ], dim=0)
372
+
373
+ edge_feats = compute_edge_features_torch(
374
+ all_origins, all_rotations, all_seq_idx, all_chain_ids, node_mask_flat
375
+ ) # [N_total, N_total, 37]
376
+
377
+ # Add batch dimension
378
+ return {
379
+ 'node_feats': node_feats.unsqueeze(0), # [1, N, 32]
380
+ 'edge_feats': edge_feats.unsqueeze(0), # [1, N, N, 37]
381
+ 'node_mask': node_mask_flat.unsqueeze(0), # [1, N]
382
+ 'n_rec': n_rec,
383
+ 'n_binder': n_binder,
384
+ }
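The interface-selection step at the start of build_differentiable_interface_graph uses a hard CA-CA distance threshold followed by truncation. Here is a minimal NumPy sketch of that selection, assuming all residues are valid (the masking with infinite distances is omitted); `interface_indices` and the coordinates are hypothetical.

```python
import numpy as np


def interface_indices(rec_ca, binder_ca, cutoff=8.0, max_nodes=128):
    """Indices of residues with at least one cross-chain CA within the cutoff."""
    dist = np.linalg.norm(rec_ca[:, None, :] - binder_ca[None, :, :], axis=-1)
    close = dist < cutoff
    # Each chain keeps at most half of the node budget, in sequence order.
    rec_idx = np.where(close.any(axis=1))[0][: max_nodes // 2]
    binder_idx = np.where(close.any(axis=0))[0][: max_nodes // 2]
    return rec_idx, binder_idx


# Toy CA coordinates in Å: two receptor residues near the first binder residue.
rec_ca = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [30.0, 0.0, 0.0]])
binder_ca = np.array([[0.0, 6.0, 0.0], [0.0, 40.0, 0.0]])
rec_idx, binder_idx = interface_indices(rec_ca, binder_ca)
```

The selection itself carries no gradient: the threshold only decides which residues enter the graph, while the features computed for the selected residues remain differentiable with respect to the binder coordinates.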
385
+
386
+
387
+ # ── Differentiable Q_theta scoring function ──────────────────────────────────
388
+
389
+ class DifferentiableQTheta:
390
+ """
391
+ Wraps the Q_theta scorer for differentiable scoring w.r.t. binder backbone
392
+ coordinates. Receptor structures are pre-loaded and cached.
393
+
394
+ Usage:
395
+ dq = DifferentiableQTheta(checkpoint_path, device)
396
+ dq.load_receptor(holo_pdb, chain='A', label='holo')
397
+ dq.load_receptor(apo_pdb, chain='A', label='apo')
398
+
399
+ binder_coords = torch.tensor(...) # [N_binder, 4, 3], requires_grad=True
400
+ score_holo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='holo')
400
+ score_apo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='apo')
402
+ selectivity = score_holo - score_apo
403
+ selectivity.backward()
404
+ # binder_coords.grad now contains ∂S/∂coords
405
+ """
406
+
407
+ def __init__(self, checkpoint_path, device='cuda:0', esm_dir=None):
408
+ import sys, os
409
+ _code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
410
+ if _code_dir not in sys.path:
411
+ sys.path.insert(0, _code_dir)
412
+ from models.scorer import build_model
413
+
414
+ self.device = torch.device(device)
415
+ ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
416
+ self.config = ckpt['config']
417
+ self.model = build_model(self.config)
418
+ self.model.load_state_dict(ckpt['model_state'])
419
+ self.model = self.model.to(self.device)
420
+ self.model.eval()
421
+
422
+ # ESM feature support
423
+ self.use_esm = self.config.get('esm_dim', 0) > 0
424
+ self.esm_dim = self.config.get('esm_dim', 0)
425
+ self.esm_dir = esm_dir or os.path.join(os.environ.get('ALLOGEN_ROOT', '.'), 'data/esm2_embeddings')
426
+
427
+ # Cache receptor data
428
+ self.receptors = {} # label -> {coords, mask, aa_idx, chi, esm_emb?}
429
+
430
+ def load_receptor(self, pdb_path, chain='A', label='holo',
431
+ esm_target=None, esm_key=None):
432
+ """Pre-load and cache receptor structure, optionally with ESM embeddings.
433
+
434
+ Args:
435
+ pdb_path: path to receptor PDB
436
+ chain: chain ID
437
+ label: cache key
438
+ esm_target: target name for ESM dir (e.g., 'abl' for data/esm2_embeddings/abl/)
439
+ esm_key: ESM embedding file key (e.g., '6XR7_A'). If None, auto-derived.
440
+ """
441
+ import sys, os
442
+ _code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
443
+ if _code_dir not in sys.path:
444
+ sys.path.insert(0, _code_dir)
445
+ from utils.pdb_utils import (
446
+ load_structure, get_residues, get_backbone_coords,
447
+ get_aa_indices, compute_chi_angles
448
+ )
449
+
450
+ model = load_structure(pdb_path)
451
+ chain_obj = model[chain]
452
+ residues = get_residues(chain_obj)
453
+ coords, mask = get_backbone_coords(residues)
454
+ aa_idx = get_aa_indices(residues)
455
+ chi = compute_chi_angles(residues, mask)
456
+
457
+ rec_data = {
458
+ 'coords': torch.from_numpy(coords).float().to(self.device),
459
+ 'mask': torch.from_numpy(mask).bool().to(self.device),
460
+ 'aa_idx': torch.from_numpy(aa_idx).long().to(self.device),
461
+ 'chi': torch.from_numpy(chi).float().to(self.device),
462
+ 'residues': residues,
463
+ }
464
+
465
+ # Load ESM embeddings if model uses ESM
466
+ if self.use_esm and esm_target:
467
+ pdb_id = os.path.basename(pdb_path).replace('.pdb', '')
468
+ if esm_key is None:
469
+ esm_key = f'{pdb_id}_{chain}'
470
+ esm_path = os.path.join(self.esm_dir, esm_target, f'{esm_key}.pt')
471
+ if os.path.exists(esm_path):
472
+ esm_emb = torch.load(esm_path, map_location=self.device, weights_only=True)
473
+ # Truncate/pad to match residue count
474
+ n_res = len(residues)
475
+ if esm_emb.shape[0] > n_res:
476
+ esm_emb = esm_emb[:n_res]
477
+ elif esm_emb.shape[0] < n_res:
478
+ pad = torch.zeros(n_res - esm_emb.shape[0], esm_emb.shape[1],
479
+ device=self.device)
480
+ esm_emb = torch.cat([esm_emb, pad], dim=0)
481
+ rec_data['esm_emb'] = esm_emb.float()
482
+ else:
483
+ rec_data['esm_emb'] = torch.zeros(len(residues), self.esm_dim,
484
+ device=self.device)
485
+
486
+ self.receptors[label] = rec_data
487
+
488
+ def load_receptor_from_coords(self, coords, mask, aa_idx=None, chi=None,
489
+ label='path'):
490
+ """
491
+ Load a receptor from raw backbone coords (not from PDB file).
492
+
493
+ Used for interpolated path frames that don't have PDB files.
494
+ If aa_idx is None, uses all-ALA (index 0). If chi is None, uses zeros.
495
+
496
+ Args:
497
+ coords: [N, 4, 3] numpy or torch backbone coords (N, CA, C, O)
498
+ mask: [N] numpy or torch bool
499
+ aa_idx: [N] numpy or torch int (default: all-ALA = 0)
500
+ chi: [N, 4] numpy or torch float (default: zeros)
501
+ label: str key for caching
502
+ """
503
+ import numpy as np
504
+
505
+ # Convert numpy to torch if needed
506
+ if isinstance(coords, np.ndarray):
507
+ coords = torch.from_numpy(coords).float()
508
+ if isinstance(mask, np.ndarray):
509
+ mask = torch.from_numpy(mask).bool()
510
+
511
+ N = coords.shape[0]
512
+
513
+ if aa_idx is None:
514
+ aa_idx = torch.zeros(N, dtype=torch.long) # all-ALA
515
+ elif isinstance(aa_idx, np.ndarray):
516
+ aa_idx = torch.from_numpy(aa_idx).long()
517
+
518
+ if chi is None:
519
+ chi = torch.zeros(N, 4, dtype=coords.dtype)
520
+ elif isinstance(chi, np.ndarray):
521
+ chi = torch.from_numpy(chi).float()
522
+
523
+ self.receptors[label] = {
524
+ 'coords': coords.to(self.device),
525
+ 'mask': mask.to(self.device),
526
+ 'aa_idx': aa_idx.to(self.device),
527
+ 'chi': chi.to(self.device),
528
+ }
529
+
530
+ def score(self, binder_coords, binder_mask, binder_aa_idx=None,
531
+ binder_chi=None, receptor_label='holo', cutoff=8.0):
532
+ """
533
+ Score binder against a cached receptor. Differentiable w.r.t. binder_coords.
534
+
535
+ Args:
536
+ binder_coords: [N_binder, 4, 3] tensor (can have requires_grad=True)
537
+ binder_mask: [N_binder] bool tensor
538
+ binder_aa_idx: [N_binder] int tensor (default: all UNK)
539
+ binder_chi: [N_binder, 4] tensor (default: zeros)
540
+ receptor_label: key into cached receptors
541
+ cutoff: interface distance cutoff
542
+
543
+ Returns:
544
+ score: scalar tensor in (0, 1), differentiable w.r.t. binder_coords
545
+ """
546
+ rec = self.receptors[receptor_label]
547
+ N_binder = binder_coords.shape[0]
548
+
549
+ if binder_aa_idx is None:
550
+ binder_aa_idx = torch.full((N_binder,), 20, device=self.device, dtype=torch.long) # UNK
551
+ if binder_chi is None:
552
+ binder_chi = torch.zeros(N_binder, 4, device=self.device, dtype=binder_coords.dtype)
553
+
554
+ graph = build_differentiable_interface_graph(
555
+ rec_coords=rec['coords'],
556
+ rec_mask=rec['mask'],
557
+ rec_aa_idx=rec['aa_idx'],
558
+ rec_chi=rec['chi'],
559
+ binder_coords=binder_coords,
560
+ binder_mask=binder_mask,
561
+ binder_aa_idx=binder_aa_idx,
562
+ binder_chi=binder_chi,
563
+ cutoff=cutoff,
564
+ )
565
+
566
+ if graph is None:
567
+ # No interface: return a zero score. This is a fresh leaf tensor, so backward() does not propagate to binder_coords.
568
+ return torch.zeros(1, device=self.device, dtype=binder_coords.dtype, requires_grad=True).squeeze()
569
+
570
+ # Build ESM features if model uses ESM
571
+ esm_feats = None
572
+ if self.use_esm:
573
+ n_rec = graph['n_rec']
574
+ n_binder = graph['n_binder']
575
+ n_total = n_rec + n_binder
576
+ # Receptor ESM: use cached if available, else zeros
577
+ if 'esm_emb' in rec:
578
+ rec_esm = rec['esm_emb']
579
+ # Keep the cached per-residue receptor embedding for the whole chain here;
580
+ # the interface rows are selected further down, once the interface
581
+ # indices have been rebuilt with the same cutoff that was used for the graph.
582
+ # Binder rows stay zero (designed backbone, no sequence embedding).
583
+ rec_esm_full = rec_esm # [N_rec_total, 1280]
584
+ else:
585
+ rec_esm_full = torch.zeros(rec['coords'].shape[0], self.esm_dim,
586
+ device=self.device)
587
+ # Binder ESM: zeros (designed backbone, no sequence)
588
+ binder_esm = torch.zeros(binder_coords.shape[0], self.esm_dim,
589
+ device=self.device)
590
+ # We need interface indices to select — rebuild them
591
+ rec_ca = rec['coords'][:, 1, :]
592
+ binder_ca = binder_coords[:, 1, :]
593
+ dist_mat = torch.cdist(rec_ca.unsqueeze(0), binder_ca.unsqueeze(0)).squeeze(0)
594
+ dist_mat_c = dist_mat.clone()
595
+ dist_mat_c[~rec['mask'], :] = float('inf')
596
+ dist_mat_c[:, ~binder_mask] = float('inf')
597
+ rec_iface = (dist_mat_c < cutoff).any(dim=1)
598
+ binder_iface = (dist_mat_c < cutoff).any(dim=0)
599
+ rec_iface_idx = torch.where(rec_iface)[0][:n_rec]
600
+ binder_iface_idx = torch.where(binder_iface)[0][:n_binder]
601
+
602
+ rec_esm_iface = rec_esm_full[rec_iface_idx] # [n_rec, 1280]
603
+ binder_esm_iface = binder_esm[binder_iface_idx] # [n_binder, 1280]
604
+ esm_combined = torch.cat([rec_esm_iface, binder_esm_iface], dim=0) # [n_total, 1280]
605
+ esm_feats = esm_combined.unsqueeze(0) # [1, n_total, 1280]
606
+
607
+ score = self.model(graph['node_feats'], graph['edge_feats'], graph['node_mask'],
608
+ esm_feats=esm_feats)
609
+ return score.squeeze() # scalar
610
+
611
+ def selectivity_margin(self, binder_coords, binder_mask,
612
+ binder_aa_idx=None, binder_chi=None,
613
+ holo_label='holo', apo_label='apo', cutoff=8.0):
614
+ """
615
+ Compute selectivity margin S = Q(holo, Y) - Q(apo, Y).
616
+ Differentiable w.r.t. binder_coords.
617
+ """
618
+ q_holo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
619
+ holo_label, cutoff)
620
+ q_apo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
621
+ apo_label, cutoff)
622
+ return q_holo - q_apo, q_holo, q_apo
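
The `score` method above re-derives the interface residue indices from CA-CA distances so that the ESM rows line up with the graph nodes. The selection step can be sketched in isolation as follows. This is a minimal NumPy illustration, not the repository's code: `interface_indices` is a hypothetical helper name, NumPy broadcasting stands in for `torch.cdist`, and the inputs are assumed to be CA coordinates already sliced from index 1 of the backbone atom axis, as in the source.

```python
import numpy as np


def interface_indices(rec_ca, binder_ca, rec_mask, binder_mask, cutoff=8.0):
    """Return (rec_idx, binder_idx): residues with a partner CA within `cutoff`.

    rec_ca:      [N_rec, 3] float array of receptor CA coordinates
    binder_ca:   [N_binder, 3] float array of binder CA coordinates
    rec_mask:    [N_rec] bool, True where the residue is resolved
    binder_mask: [N_binder] bool, True where the residue is resolved
    """
    # Pairwise CA-CA distances, shape [N_rec, N_binder]
    diff = rec_ca[:, None, :] - binder_ca[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))

    # Masked residues can never be part of the interface
    dist[~rec_mask, :] = np.inf
    dist[:, ~binder_mask] = np.inf

    close = dist < cutoff
    rec_idx = np.where(close.any(axis=1))[0]     # receptor rows with any contact
    binder_idx = np.where(close.any(axis=0))[0]  # binder columns with any contact
    return rec_idx, binder_idx


# Toy input: two receptor residues near the first binder residue, one far away
rec_ca = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [30.0, 0.0, 0.0]])
binder_ca = np.array([[0.0, 0.0, 6.0], [100.0, 0.0, 0.0]])
rec_mask = np.array([True, True, True])
binder_mask = np.array([True, True])

rec_idx, binder_idx = interface_indices(rec_ca, binder_ca, rec_mask, binder_mask)
```

The source additionally truncates both index lists to the first `n_rec` and `n_binder` entries so that the ESM rows match the graph size. That alignment holds only if the graph builder selects residues by the same rule and in the same order.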
code/models/features.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ SE(3)-invariant feature extraction for interface graphs.
3
+ Node and edge features used by the Q_theta scorer.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import numpy as np
9
+
10
+ # Ensure utils is importable (for both direct and package imports)
11
+ _CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+ if _CODE_DIR not in sys.path:
13
+ sys.path.insert(0, _CODE_DIR)
14
+
15
+ from utils.pdb_utils import (
16
+ rbf_encode, compute_backbone_frames, compute_torsion_angles,
17
+ get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA
18
+ )
19
+
20
+ # Feature dimensions
21
+ # one-hot AA (21) + backbone torsions (6) + chi1 sin/cos (2) + chi2 sin/cos (2) + chain indicator (1) = 32
22
+ NODE_DIM = NUM_AA + 6 + 4 + 1 # = 32
23
+ EDGE_DIM = 16 + 3 + 9 + 8 + 1 # RBF dist (16) + direction (3) + rel rotation (9) + seq sep (8) + same chain (1) = 37
24
+
25
+ MAX_SEQ_SEP = 32 # clip range for sequence separation (the bin count is n_bins in seq_sep_encode)
26
+
27
+
28
+ def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
29
+ """One-hot encode a signed sequence separation into n_bins equal-width bins over [-max_sep, max_sep]."""
30
+ bins = np.linspace(-max_sep, max_sep, n_bins + 1)
31
+ sep_clipped = np.clip(sep, -max_sep, max_sep)
32
+ encoded = np.zeros(n_bins, dtype=np.float32)
33
+ bin_idx = np.digitize(sep_clipped, bins) - 1
34
+ bin_idx = np.clip(bin_idx, 0, n_bins - 1)
35
+ encoded[bin_idx] = 1.0
36
+ return encoded
37
+
38
+
39
+ def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id):
40
+ """
41
+ Compute per-residue node features.
42
+
43
+ Args:
44
+ residues: list of Bio.PDB residues
45
+ coords: [N, 4, 3] backbone coords
46
+ mask: [N] bool
47
+ torsion_angles: [N, 6] sin/cos of phi, psi, omega
48
+ chi_angles: [N, 4] sin/cos of chi1, chi2
49
+ chain_id: 0 = receptor, 1 = binder
50
+
51
+ Returns:
52
+ node_feats: [N, NODE_DIM] (NODE_DIM = 32)
53
+ """
54
+ N = len(residues)
55
+ aa_idx = get_aa_indices(residues)
56
+
57
+ # One-hot amino acid
58
+ aa_onehot = np.zeros((N, NUM_AA), dtype=np.float32)
59
+ for i in range(N):
60
+ if mask[i]:
61
+ aa_onehot[i, aa_idx[i]] = 1.0
62
+
63
+ # Chain indicator
64
+ chain_feat = np.full((N, 1), chain_id, dtype=np.float32)
65
+
66
+ # Concatenate
67
+ node_feats = np.concatenate([
68
+ aa_onehot, # [N, 21]
69
+ torsion_angles, # [N, 6]
70
+ chi_angles, # [N, 4]
71
+ chain_feat, # [N, 1]
72
+ ], axis=-1)
73
+
74
+ return node_feats # [N, 32]
75
+
76
+
77
+ def extract_edge_features(coords_i, frames_i, coords_j, frames_j,
78
+ seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j):
79
+ """
80
+ Compute SE(3)-invariant edge features between residue sets i and j.
81
+ Vectorized over all pairs.
82
+
83
+ Args:
84
+ coords_i: [N_i, 4, 3] backbone coords of set i (full interface)
85
+ frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3])
86
+ coords_j: [N_j, 4, 3]
87
+ frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3])
88
+ seq_idx_i: [N_i] integer sequence indices (for sequence separation)
89
+ seq_idx_j: [N_j] integer sequence indices
90
+ chain_i: int (0 or 1)
91
+ chain_j: int (0 or 1)
92
+ mask_i: [N_i] bool
93
+ mask_j: [N_j] bool
94
+
95
+ Returns:
96
+ edge_feats: [N_i, N_j, EDGE_DIM]
97
+ """
98
+ N_i, N_j = len(coords_i), len(coords_j)
99
+ origins_i, rotations_i = frames_i
100
+ origins_j, rotations_j = frames_j
101
+
102
+ ca_i = origins_i # [N_i, 3]
103
+ ca_j = origins_j # [N_j, 3]
104
+
105
+ # --- Distance features ---
106
+ diff = ca_j[None, :, :] - ca_i[:, None, :] # [N_i, N_j, 3]
107
+ dist = np.sqrt((diff ** 2).sum(axis=-1)) # [N_i, N_j]
108
+ dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16) # [N_i, N_j, 16]
109
+
110
+ # --- Direction in local frame of i ---
111
+ # unit vector from i to j in global frame
112
+ unit_diff = diff / (dist[..., None] + 1e-8) # [N_i, N_j, 3]
113
+ # express the unit vector in the local frame of residue i
114
+ # rotations_i: [N_i, 3, 3], unit_diff: [N_i, N_j, 3]
115
+ # as written: local_dir[i,j,k] = sum_l R_i[k,l] * u_ij[l], i.e. R_i @ u_ij (R_i^T @ u_ij would be 'ilk,ijl->ijk')
116
+ local_dir = np.einsum('ikl,ijl->ijk', rotations_i, unit_diff) # [N_i, N_j, 3]
117
+
118
+ # --- Relative rotation: R_i^T R_j ---
119
+ # rotations_i: [N_i, 3, 3], rotations_j: [N_j, 3, 3]
120
+ # as written: rel_rot[i,j] = R_i @ R_j (R_i^T @ R_j would be 'ilk,jlm->ijkm') -> [N_i, N_j, 3, 3] -> flatten to [N_i, N_j, 9]
121
+ rel_rot = np.einsum('ikl,jlm->ijkm', rotations_i, rotations_j) # [N_i, N_j, 3, 3]
122
+ rel_rot_flat = rel_rot.reshape(N_i, N_j, 9) # [N_i, N_j, 9]
123
+
124
+ # --- Sequence separation ---
125
+ sep = seq_idx_j[None, :] - seq_idx_i[:, None] # [N_i, N_j]
126
+ # Encode each pair with a Python-level loop over the flattened [N_i*N_j] separations (not vectorized)
127
+ sep_flat = sep.reshape(-1)
128
+ sep_enc = np.array([seq_sep_encode(s) for s in sep_flat]) # [N_i*N_j, 8]
129
+ sep_enc = sep_enc.reshape(N_i, N_j, 8)
130
+
131
+ # If sets i and j are different chains, zero the whole separation encoding (all 8 bins)
132
+ if chain_i != chain_j:
133
+ sep_enc[:] = 0.0
134
+
135
+ # --- Same chain indicator ---
136
+ same_chain = float(chain_i == chain_j)
137
+ same_chain_feat = np.full((N_i, N_j, 1), same_chain, dtype=np.float32)
138
+
139
+ # --- Concatenate ---
140
+ edge_feats = np.concatenate([
141
+ dist_rbf, # [N_i, N_j, 16]
142
+ local_dir, # [N_i, N_j, 3]
143
+ rel_rot_flat, # [N_i, N_j, 9]
144
+ sep_enc, # [N_i, N_j, 8]
145
+ same_chain_feat # [N_i, N_j, 1]
146
+ ], axis=-1) # [N_i, N_j, 37]
147
+
148
+ # Zero out edges involving masked residues
149
+ edge_feats[~mask_i, :, :] = 0.0
150
+ edge_feats[:, ~mask_j, :] = 0.0
151
+
152
+ return edge_feats.astype(np.float32)
153
+
154
+
155
+ def build_interface_graph(rec_residues, rec_coords, rec_mask,
156
+ binder_residues, binder_coords, binder_mask,
157
+ rec_interface_mask, binder_interface_mask,
158
+ max_nodes: int = 128):
159
+ """
160
+ Build a joint interface graph combining receptor and binder interface residues.
161
+
162
+ Returns a dict with:
163
+ node_feats: [N_total, NODE_DIM]
164
+ edge_feats: [N_total, N_total, EDGE_DIM]
165
+ node_mask: [N_total] bool
166
+ n_rec: int (number of receptor interface nodes)
167
+ n_binder: int (number of binder interface nodes)
168
+ """
169
+ # Select interface residues
170
+ rec_iface_idx = np.where(rec_interface_mask)[0]
171
+ binder_iface_idx = np.where(binder_interface_mask)[0]
172
+
173
+ # Truncate if too many
174
+ if len(rec_iface_idx) > max_nodes // 2:
175
+ rec_iface_idx = rec_iface_idx[:max_nodes // 2]
176
+ if len(binder_iface_idx) > max_nodes // 2:
177
+ binder_iface_idx = binder_iface_idx[:max_nodes // 2]
178
+
179
+ n_rec = len(rec_iface_idx)
180
+ n_binder = len(binder_iface_idx)
181
+ n_total = n_rec + n_binder
182
+
183
+ if n_total == 0:
184
+ return None
185
+
186
+ # Extract coords for interface residues
187
+ rec_iface_coords = rec_coords[rec_iface_idx] # [n_rec, 4, 3]
188
+ binder_iface_coords = binder_coords[binder_iface_idx] # [n_binder, 4, 3]
189
+ rec_iface_mask = rec_mask[rec_iface_idx]
190
+ binder_iface_mask = binder_mask[binder_iface_idx]
191
+
192
+ # Compute backbone frames
193
+ rec_origins, rec_rotations = compute_backbone_frames(rec_iface_coords, rec_iface_mask)
194
+ binder_origins, binder_rotations = compute_backbone_frames(binder_iface_coords, binder_iface_mask)
195
+
196
+ # Compute torsion angles
197
+ # Proper phi/psi need the full-chain neighbours; torsions here are computed on the interface subset only, as a local approximation
198
+ rec_torsion = compute_torsion_angles(rec_iface_coords, rec_iface_mask)
199
+ binder_torsion = compute_torsion_angles(binder_iface_coords, binder_iface_mask)
200
+
201
+ # Extract residues
202
+ rec_iface_residues = [rec_residues[i] for i in rec_iface_idx]
203
+ binder_iface_residues = [binder_residues[i] for i in binder_iface_idx]
204
+
205
+ # Compute sidechain chi1/chi2 angles
206
+ rec_chi = compute_chi_angles(rec_iface_residues, rec_iface_mask)
207
+ binder_chi = compute_chi_angles(binder_iface_residues, binder_iface_mask)
208
+
209
+ # Node features
210
+ rec_node_feats = extract_node_features(
211
+ rec_iface_residues, rec_iface_coords, rec_iface_mask, rec_torsion, rec_chi, chain_id=0
212
+ ) # [n_rec, NODE_DIM]
213
+ binder_node_feats = extract_node_features(
214
+ binder_iface_residues, binder_iface_coords, binder_iface_mask, binder_torsion, binder_chi, chain_id=1
215
+ ) # [n_binder, NODE_DIM]
216
+
217
+ node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0) # [N, NODE_DIM]
218
+ node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0)
219
+
220
+ # Edge features (4 blocks: RR, RB, BR, BB)
221
+ all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0)
222
+ all_mask = node_mask
223
+ all_origins = np.concatenate([rec_origins, binder_origins], axis=0)
224
+ all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0)
225
+ all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0)
226
+ all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)
227
+
228
+ # Compute full NxN edge features
229
+ frames_all = (all_origins, all_rotations)
230
+ edge_feats = extract_edge_features(
231
+ all_coords, frames_all,
232
+ all_coords, frames_all,
233
+ all_seq_idx, all_seq_idx,
234
+ -1, -1, # chain handled via all_chain array below
235
+ all_mask, all_mask
236
+ ) # [N, N, EDGE_DIM]
237
+
238
+ # Patch same_chain feature (last dim) using actual chain IDs
239
+ same_chain_feat = (all_chain[:, None] == all_chain[None, :]).astype(np.float32)
240
+ edge_feats[:, :, -1] = same_chain_feat
241
+
242
+ return {
243
+ 'node_feats': node_feats.astype(np.float32), # [N, NODE_DIM]
244
+ 'edge_feats': edge_feats.astype(np.float32), # [N, N, EDGE_DIM]
245
+ 'node_mask': node_mask, # [N]
246
+ 'n_rec': n_rec,
247
+ 'n_binder': n_binder,
248
+ 'rec_iface_idx': rec_iface_idx, # [n_rec] original residue indices
249
+ 'binder_iface_idx': binder_iface_idx, # [n_binder] original residue indices
250
+ }
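
`extract_edge_features` above encodes sequence separation one pair at a time, calling `seq_sep_encode` inside a list comprehension over all `N_i * N_j` pairs. A batched equivalent is sketched below. `seq_sep_encode_batched` is a hypothetical name that does not exist in the repository; `seq_sep_encode_loop` restates the source function so the two can be compared on the same input.

```python
import numpy as np

MAX_SEQ_SEP = 32  # clip range, as in the source


def seq_sep_encode_loop(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
    """Scalar version, restating seq_sep_encode from the source."""
    bins = np.linspace(-max_sep, max_sep, n_bins + 1)
    sep_clipped = np.clip(sep, -max_sep, max_sep)
    encoded = np.zeros(n_bins, dtype=np.float32)
    bin_idx = np.clip(np.digitize(sep_clipped, bins) - 1, 0, n_bins - 1)
    encoded[bin_idx] = 1.0
    return encoded


def seq_sep_encode_batched(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
    """Array version: sep of any shape [...] -> one-hot of shape [..., n_bins]."""
    sep = np.asarray(sep)
    bins = np.linspace(-max_sep, max_sep, n_bins + 1)
    sep_clipped = np.clip(sep, -max_sep, max_sep)
    # np.digitize works elementwise on arrays, so one call covers every pair
    bin_idx = np.clip(np.digitize(sep_clipped, bins) - 1, 0, n_bins - 1)
    # Indexing an identity matrix with the bin indices yields the one-hot rows
    return np.eye(n_bins, dtype=np.float32)[bin_idx]


# Pairwise separations for a small set of sequence indices
seq_idx = np.array([3, 4, 10, 50])
sep = seq_idx[None, :] - seq_idx[:, None]  # [4, 4]

batched = seq_sep_encode_batched(sep)  # [4, 4, 8]
looped = np.array([seq_sep_encode_loop(s) for s in sep.reshape(-1)]).reshape(4, 4, 8)
```

Both versions share the same bin edges and the same clipping, so they should agree elementwise; the batched form only removes the per-pair Python call overhead.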
code/models/scorer.py ADDED
@@ -0,0 +1,585 @@
1
+ """
2
+ Q_theta: State-selectivity scorer for Allo-Designer.
3
+
4
+ Architecture: Dense Edge-Biased Graph Transformer
5
+ - Input: padded interface graph (node feats + pairwise edge feats)
6
+ - SE(3)-invariant features (all features from distances/angles in backbone frames)
7
+ - Output: Q_theta(X, Y) in (0,1) = probability-like compatibility/selectivity score
8
+
9
+ No torch_geometric dependency: uses dense attention with edge biases.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import math
16
+
17
+
18
+ class RBFLayer(nn.Module):
19
+ """Learnable RBF embedding for edge distances."""
20
+ def __init__(self, n_bins: int = 16, d_min: float = 0., d_max: float = 20.):
21
+ super().__init__()
22
+ centers = torch.linspace(d_min, d_max, n_bins)
23
+ self.register_buffer('centers', centers)
24
+ self.log_sigma = nn.Parameter(torch.zeros(1))
25
+
26
+ def forward(self, dist):
27
+ # dist: [...] -> [..., n_bins]
28
+ sigma = torch.exp(self.log_sigma)
29
+ return torch.exp(-((dist.unsqueeze(-1) - self.centers) ** 2) / (2 * sigma ** 2))
30
+
31
+
32
+ class EdgeBiasedMHA(nn.Module):
33
+ """
34
+ Multi-Head Self-Attention with additive edge biases.
35
+ Implements the core equation:
36
+ A_ij = (Q_i K_j^T / sqrt(d)) + b_ij
37
+ where b_ij is computed from edge features.
38
+ """
39
+ def __init__(self, d_model: int, n_heads: int, d_edge: int, dropout: float = 0.1):
40
+ super().__init__()
41
+ assert d_model % n_heads == 0
42
+ self.n_heads = n_heads
43
+ self.d_head = d_model // n_heads
44
+ self.scale = math.sqrt(self.d_head)
45
+
46
+ self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
47
+ self.out_proj = nn.Linear(d_model, d_model)
48
+ self.edge_proj = nn.Linear(d_edge, n_heads) # edge features -> per-head bias
49
+ self.dropout = nn.Dropout(dropout)
50
+
51
+ def forward(self, x, edge_feats, mask=None):
52
+ """
53
+ x: [B, N, d_model]
54
+ edge_feats: [B, N, N, d_edge]
55
+ mask: [B, N] bool (True = valid residue)
56
+ """
57
+ B, N, D = x.shape
58
+ H = self.n_heads
59
+
60
+ # QKV projection
61
+ qkv = self.qkv_proj(x).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
62
+ q, k, v = qkv.unbind(0) # each [B, H, N, d_head]
63
+
64
+ # Scaled dot-product attention logits
65
+ attn_logits = (q @ k.transpose(-2, -1)) / self.scale # [B, H, N, N]
66
+
67
+ # Edge bias: [B, N, N, H] -> [B, H, N, N]
68
+ edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2) # [B, H, N, N]
69
+ attn_logits = attn_logits + edge_bias
70
+
71
+ # Padding mask: mask out padded positions
72
+ if mask is not None:
73
+ # mask: [B, N] True=valid; padding=False
74
+ padding = ~mask # [B, N] True=padding
75
+ attn_logits = attn_logits.masked_fill(
76
+ padding[:, None, None, :], # [B, 1, 1, N]
77
+ float('-inf')
78
+ )
79
+
80
+ attn_weights = self.dropout(F.softmax(attn_logits, dim=-1))
81
+
82
+ # Handle all-padding rows (NaN -> 0)
83
+ attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
84
+
85
+ out = (attn_weights @ v) # [B, H, N, d_head]
86
+ out = out.transpose(1, 2).reshape(B, N, D) # [B, N, D]
87
+ return self.out_proj(out)
88
+
89
+
90
+ class InterfaceTransformerLayer(nn.Module):
91
+ """Single layer of edge-biased transformer with pre-norm."""
92
+ def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
93
+ super().__init__()
94
+ self.attn = EdgeBiasedMHA(d_model, n_heads, d_edge, dropout)
95
+ self.ff = nn.Sequential(
96
+ nn.Linear(d_model, d_model * ff_mult),
97
+ nn.GELU(),
98
+ nn.Dropout(dropout),
99
+ nn.Linear(d_model * ff_mult, d_model),
100
+ )
101
+ self.norm1 = nn.LayerNorm(d_model)
102
+ self.norm2 = nn.LayerNorm(d_model)
103
+ self.drop = nn.Dropout(dropout)
104
+
105
+ def forward(self, x, edge_feats, mask=None):
106
+ x = x + self.drop(self.attn(self.norm1(x), edge_feats, mask))
107
+ x = x + self.drop(self.ff(self.norm2(x)))
108
+ return x
109
+
110
+
111
+ class GATLayer(nn.Module):
112
+ """Multi-head GAT layer with pre-norm. No edge features in attention."""
113
+ def __init__(self, d_model: int, n_heads: int, ff_mult: int = 4, dropout: float = 0.1):
114
+ super().__init__()
115
+ assert d_model % n_heads == 0
116
+ self.n_heads = n_heads
117
+ self.d_head = d_model // n_heads
118
+
119
+ self.W = nn.Linear(d_model, d_model, bias=False)
120
+ self.a_l = nn.Parameter(torch.randn(n_heads, self.d_head))
121
+ self.a_r = nn.Parameter(torch.randn(n_heads, self.d_head))
122
+ nn.init.xavier_uniform_(self.a_l.unsqueeze(0))
123
+ nn.init.xavier_uniform_(self.a_r.unsqueeze(0))
124
+ self.out_proj = nn.Linear(d_model, d_model)
125
+ self.leaky_relu = nn.LeakyReLU(0.2)
126
+ self.attn_drop = nn.Dropout(dropout)
127
+
128
+ self.ff = nn.Sequential(
129
+ nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
130
+ nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
131
+ )
132
+ self.norm1 = nn.LayerNorm(d_model)
133
+ self.norm2 = nn.LayerNorm(d_model)
134
+ self.drop = nn.Dropout(dropout)
135
+
136
+ def forward(self, x, edge_feats, mask=None):
137
+ B, N, D = x.shape
138
+ H = self.n_heads
139
+
140
+ h = self.norm1(x)
141
+ Wh = self.W(h).view(B, N, H, self.d_head) # [B, N, H, d_head]
142
+ e_l = (Wh * self.a_l).sum(-1) # [B, N, H]
143
+ e_r = (Wh * self.a_r).sum(-1) # [B, N, H]
144
+ attn = self.leaky_relu(e_l.unsqueeze(2) + e_r.unsqueeze(1)) # [B, N, N, H]
145
+ attn = attn.permute(0, 3, 1, 2) # [B, H, N, N]
146
+
147
+ if mask is not None:
148
+ attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
149
+
150
+ attn = self.attn_drop(F.softmax(attn, dim=-1))
151
+ attn = torch.nan_to_num(attn, nan=0.0)
152
+
153
+ out = torch.einsum('bhnm,bmhd->bnhd', attn, Wh)
154
+ out = out.reshape(B, N, D)
155
+ x = x + self.drop(self.out_proj(out))
156
+ x = x + self.drop(self.ff(self.norm2(x)))
157
+ return x
158
+
159
+
160
+ class GCNLayer(nn.Module):
161
+ """GCN layer with edge-weighted message passing and pre-norm."""
162
+ def __init__(self, d_model: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
163
+ super().__init__()
164
+ self.msg_proj = nn.Linear(d_model, d_model, bias=False)
165
+ self.edge_weight = nn.Linear(d_edge, 1)
166
+
167
+ self.ff = nn.Sequential(
168
+ nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
169
+ nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
170
+ )
171
+ self.norm1 = nn.LayerNorm(d_model)
172
+ self.norm2 = nn.LayerNorm(d_model)
173
+ self.drop = nn.Dropout(dropout)
174
+
175
+ def forward(self, x, edge_feats, mask=None):
176
+ B, N, D = x.shape
177
+ h = self.norm1(x)
178
+ msg = self.msg_proj(h) # [B, N, D]
179
+
180
+ w = self.edge_weight(edge_feats).squeeze(-1) # [B, N, N]
181
+ if mask is not None:
182
+ w = w.masked_fill(~mask[:, None, :], float('-inf'))
183
+ w = F.softmax(w, dim=-1)
184
+ w = torch.nan_to_num(w, nan=0.0)
185
+
186
+ agg = torch.bmm(w, msg) # [B, N, D]
187
+ x = x + self.drop(agg)
188
+ x = x + self.drop(self.ff(self.norm2(x)))
189
+ return x
190
+
191
+
192
+ class CrossChainTransformerLayer(nn.Module):
193
+ """Cross-chain attention: each node attends only to nodes from the other chain."""
194
+ def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
195
+ super().__init__()
196
+ assert d_model % n_heads == 0
197
+ self.n_heads = n_heads
198
+ self.d_head = d_model // n_heads
199
+ self.scale = math.sqrt(self.d_head)
200
+
201
+ self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
202
+ self.out_proj = nn.Linear(d_model, d_model)
203
+ self.edge_proj = nn.Linear(d_edge, n_heads)
204
+ self.attn_drop = nn.Dropout(dropout)
205
+
206
+ self.ff = nn.Sequential(
207
+ nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
208
+ nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
209
+ )
210
+ self.norm1 = nn.LayerNorm(d_model)
211
+ self.norm2 = nn.LayerNorm(d_model)
212
+ self.drop = nn.Dropout(dropout)
213
+
214
+ def forward(self, x, edge_feats, mask=None, chain_mask=None):
215
+ """
216
+ x: [B, N, d_model]
217
+ edge_feats: [B, N, N, d_edge]
218
+ mask: [B, N] bool (True = valid)
219
+ chain_mask: [B, N] float (0=receptor, 1=binder)
220
+ """
221
+ B, N, D = x.shape
222
+ H = self.n_heads
223
+
224
+ h = self.norm1(x)
225
+ qkv = self.qkv_proj(h).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
226
+ q, k, v = qkv.unbind(0) # each [B, H, N, d_head]
227
+
228
+ attn_logits = (q @ k.transpose(-2, -1)) / self.scale # [B, H, N, N]
229
+ edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2) # [B, H, N, N]
230
+ attn_logits = attn_logits + edge_bias
231
+
232
+ # Mask padding
233
+ if mask is not None:
234
+ attn_logits = attn_logits.masked_fill(~mask[:, None, None, :], float('-inf'))
235
+
236
+ # Cross-chain mask: block same-chain attention
237
+ if chain_mask is not None:
238
+ same_chain = (chain_mask.unsqueeze(1) == chain_mask.unsqueeze(2)) # [B, N, N]
239
+ attn_logits = attn_logits.masked_fill(same_chain[:, None, :, :], float('-inf'))
240
+
241
+ attn_weights = self.attn_drop(F.softmax(attn_logits, dim=-1))
242
+ attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
243
+
244
+ out = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)
245
+ x = x + self.drop(self.out_proj(out))
246
+ x = x + self.drop(self.ff(self.norm2(x)))
247
+ return x
248
+
249
+
250
+ class EdgeUpdateLayer(nn.Module):
251
+ """Updates edge features using node representations each layer.
252
+ Memory-efficient: projects nodes to low-dim before outer product."""
253
+ def __init__(self, d_model: int, d_edge: int, dropout: float = 0.1):
254
+ super().__init__()
255
+ d_proj = min(32, d_model // 4) # Low-dim projection to save memory
256
+ self.proj_i = nn.Linear(d_model, d_proj, bias=False)
257
+ self.proj_j = nn.Linear(d_model, d_proj, bias=False)
258
+ self.edge_mlp = nn.Sequential(
259
+ nn.Linear(2 * d_proj + d_edge, d_edge),
260
+ nn.GELU(),
261
+ nn.Dropout(dropout),
262
+ nn.Linear(d_edge, d_edge),
263
+ )
264
+ self.norm = nn.LayerNorm(d_edge)
265
+
266
+ def forward(self, h, e, mask=None):
267
+ B, N, D = h.shape
268
+ hi = self.proj_i(h).unsqueeze(2).expand(-1, -1, N, -1) # [B, N, N, d_proj]
269
+ hj = self.proj_j(h).unsqueeze(1).expand(-1, N, -1, -1) # [B, N, N, d_proj]
270
+ inp = torch.cat([hi, hj, self.norm(e)], dim=-1)
271
+ e = e + self.edge_mlp(inp)
272
+ return e
273
+
274
+
275
+ class InterfaceGNN(nn.Module):
276
+ """
277
+ Q_theta scorer: SE(3)-invariant dense graph transformer for interface scoring.
278
+
279
+ Input:
280
+ node_feats: [B, N, node_dim] per-residue features
281
+ edge_feats: [B, N, N, edge_dim] pairwise edge features
282
+ mask: [B, N] bool (True = valid residue, False = padding)
283
+
284
+ Output:
285
+ scores: [B] in (0, 1) = Q_theta(X, Y)
286
+ """
287
+ def __init__(
288
+ self,
289
+ node_dim: int = 28,
290
+ edge_dim: int = 37,
291
+ hidden_dim: int = 128,
292
+ n_layers: int = 4,
293
+ n_heads: int = 8,
294
+ ff_mult: int = 4,
295
+ dropout: float = 0.1,
296
+ backbone: str = 'transformer',
297
+ pooling: str = 'meanmax', # 'meanmax' or 'attention'
298
+ edge_update: bool = False,
299
+ esm_dim: int = 0, # 0 = no ESM; >0 = ESM embedding dim to project
300
+ esm_proj_dim: int = 128, # projection dim for ESM features
301
+ esm_dropout: float = 0.0, # dropout on ESM projection
302
+ ):
303
+ super().__init__()
304
+ actual_node_dim = node_dim + (esm_proj_dim if esm_dim > 0 else 0)
305
+ self.esm_dim = esm_dim
306
+ if esm_dim > 0:
307
+ layers = [
308
+ nn.Linear(esm_dim, esm_proj_dim),
309
+ nn.LayerNorm(esm_proj_dim),
310
+ nn.GELU(),
311
+ ]
312
+ if esm_dropout > 0:
313
+ layers.append(nn.Dropout(esm_dropout))
314
+ self.esm_proj = nn.Sequential(*layers)
315
+ self.node_embed = nn.Sequential(
316
+ nn.Linear(actual_node_dim, hidden_dim),
317
+ nn.LayerNorm(hidden_dim),
318
+ nn.GELU(),
319
+ )
320
+ self.edge_embed = nn.Sequential(
321
+ nn.Linear(edge_dim, hidden_dim),
322
+ nn.GELU(),
323
+ nn.Linear(hidden_dim, hidden_dim // 2),
324
+ )
325
+ d_edge_hidden = hidden_dim // 2
326
+
327
+ if backbone == 'transformer':
328
+ self.layers = nn.ModuleList([
329
+ InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout)
330
+ for _ in range(n_layers)
331
+ ])
332
+ elif backbone == 'gat':
333
+ self.layers = nn.ModuleList([
334
+ GATLayer(hidden_dim, n_heads, ff_mult, dropout)
335
+ for _ in range(n_layers)
336
+ ])
337
+ elif backbone == 'gcn':
338
+ self.layers = nn.ModuleList([
339
+ GCNLayer(hidden_dim, d_edge_hidden, ff_mult, dropout)
340
+ for _ in range(n_layers)
341
+ ])
342
+ elif backbone == 'crosschain':
343
+ # Interleave self-attention and cross-chain attention
344
+ layers = []
345
+ for i in range(n_layers):
346
+ if i % 2 == 0:
347
+ layers.append(InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
348
+ else:
349
+ layers.append(CrossChainTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
350
+ self.layers = nn.ModuleList(layers)
351
+ else:
352
+ raise ValueError(f"Unknown backbone: {backbone}")
353
+
354
+ self.norm_out = nn.LayerNorm(hidden_dim)
355
+
356
+ # Edge update layers (optional)
357
+ self.edge_update = edge_update
358
+ if edge_update:
359
+ self.edge_update_layers = nn.ModuleList([
360
+ EdgeUpdateLayer(hidden_dim, d_edge_hidden, dropout)
361
+ for _ in range(n_layers)
362
+ ])
363
+
364
+ # Pooling
365
+ self.pooling = pooling
366
+ if pooling == 'attention':
367
+ self.attn_pool = nn.Sequential(
368
+ nn.Linear(hidden_dim, hidden_dim // 2),
369
+ nn.Tanh(),
370
+ nn.Linear(hidden_dim // 2, 1),
371
+ )
372
+ pool_dim = hidden_dim
373
+ else:
374
+ pool_dim = 2 * hidden_dim
375
+
376
+ # Scoring head
377
+ self.head = nn.Sequential(
378
+ nn.Linear(pool_dim, hidden_dim),
379
+ nn.GELU(),
380
+ nn.Dropout(dropout),
381
+ nn.Linear(hidden_dim, hidden_dim // 2),
382
+ nn.GELU(),
383
+ nn.Linear(hidden_dim // 2, 1),
384
+ )
385
+
386
+ def forward(self, node_feats, edge_feats, mask, esm_feats=None):
387
+ """
388
+ node_feats: [B, N, node_dim]
389
+ edge_feats: [B, N, N, edge_dim]
390
+ mask: [B, N] bool
391
+ esm_feats: [B, N, esm_dim] optional ESM-2 embeddings
392
+ Returns: scores [B] in (0, 1)
393
+ """
394
+ B, N, _ = node_feats.shape
395
+
396
+ # Extract chain mask for cross-chain attention (last dim = chain indicator)
397
+ chain_mask = node_feats[:, :, -1] # [B, N] float: 0=receptor, 1=binder
398
+
399
+ # Optionally concatenate projected ESM features
400
+ if self.esm_dim > 0 and esm_feats is not None:
401
+ esm_proj = self.esm_proj(esm_feats) # [B, N, esm_proj_dim]
402
+ node_feats = torch.cat([node_feats, esm_proj], dim=-1)
403
+
404
+ # Embed nodes and edges
405
+ h = self.node_embed(node_feats) # [B, N, hidden_dim]
406
+ e = self.edge_embed(edge_feats) # [B, N, N, hidden_dim//2]
407
+
408
+ # Graph transformer layers (with optional edge updates)
409
+ for i, layer in enumerate(self.layers):
410
+ if isinstance(layer, CrossChainTransformerLayer):
411
+ h = layer(h, e, mask, chain_mask=chain_mask)
412
+ else:
413
+ h = layer(h, e, mask)
414
+ if self.edge_update:
415
+ e = self.edge_update_layers[i](h, e, mask)
416
+
417
+ h = self.norm_out(h) # [B, N, hidden_dim]
418
+
419
+ # Pooling
420
+ mask_f = mask.float().unsqueeze(-1) # [B, N, 1]
421
+
422
+ if self.pooling == 'attention':
423
+ # Learned attention pooling
424
+ attn_logits = self.attn_pool(h).squeeze(-1) # [B, N]
425
+ attn_logits = attn_logits.masked_fill(~mask, float('-inf'))
426
+ attn_weights = F.softmax(attn_logits, dim=-1).unsqueeze(-1) # [B, N, 1]
427
+ attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
428
+ h_pool = (h * attn_weights).sum(dim=1) # [B, hidden_dim]
429
+ else:
430
+ # Mean + max pooling
431
+ h_masked = h * mask_f
432
+ h_mean = h_masked.sum(dim=1) / (mask_f.sum(dim=1) + 1e-8)
433
+ h_max_input = h_masked + (1 - mask_f) * (-1e9)
434
+ h_max = h_max_input.max(dim=1).values
435
+ h_pool = torch.cat([h_mean, h_max], dim=-1) # [B, 2*hidden_dim]
436
+
437
+ # Score
438
+ logits = self.head(h_pool).squeeze(-1) # [B]
439
+ scores = torch.sigmoid(logits) # [B] in (0, 1)
440
+ return scores
441
+
442
+
443
+ class AlloDesignerScorer(nn.Module):
444
+ """
445
+ Full Q_theta model wrapper with loss computation.
446
+
447
+ Implements the two-stage training objective:
448
+ Phase 1: DockQ regression (MSE loss)
449
+ Phase 2: Selectivity margin ranking (contrastive loss)
450
+
451
+ The selectivity margin from the paper (Eq. 3):
452
+ S_theta(Y; X+, N) = logit(Q(X+, Y)) - log sum_X- exp(logit(Q(X-, Y)))
453
+ """
454
+ def __init__(self, node_dim=28, edge_dim=37, hidden_dim=128,
455
+ n_layers=4, n_heads=8, dropout=0.1, backbone='transformer',
456
+ pooling='meanmax', edge_update=False, esm_dim=0,
457
+ esm_proj_dim=128, esm_dropout=0.0):
458
+ super().__init__()
459
+ self.gnn = InterfaceGNN(node_dim, edge_dim, hidden_dim, n_layers, n_heads,
460
+ dropout=dropout, backbone=backbone,
461
+ pooling=pooling, edge_update=edge_update,
462
+ esm_dim=esm_dim, esm_proj_dim=esm_proj_dim,
463
+ esm_dropout=esm_dropout)
464
+
465
+ def forward(self, node_feats, edge_feats, mask, esm_feats=None):
466
+ return self.gnn(node_feats, edge_feats, mask, esm_feats=esm_feats)
467
+
468
+ def compute_dockq_loss(self, scores, dockq_labels):
469
+ """Phase 1: MSE regression loss against DockQ labels."""
470
+ return F.mse_loss(scores, dockq_labels.float())
471
+
472
+ def compute_selectivity_loss(self, pos_scores, neg_scores_list, margin: float = 0.2):
473
+ """
474
+ Phase 2: Selectivity margin loss.
475
+
476
+ For each binder Y:
477
+ pos_score = Q(X+, Y)
478
+ neg_scores = [Q(X-, Y) for X- in N]
479
+
480
+ Loss = -mean(S_theta) where
481
+ S_theta = logit(pos_score) - log sum exp(logit(neg_scores))
482
+
483
+ Also computes a soft margin loss:
484
+ L_margin = mean(max(0, margin - (pos_score - neg_score)))
485
+ """
486
+ # logit = log(p / (1-p))
487
+ eps = 1e-6
488
+ pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
489
+
490
+ # neg_scores_list: list of [B] tensors
491
+ neg_logits = torch.stack([
492
+ torch.log(s.clamp(eps, 1 - eps) / (1 - s).clamp(eps))
493
+ for s in neg_scores_list
494
+ ], dim=-1) # [B, n_neg]
495
+
496
+ # InfoNCE-style selectivity margin
497
+ log_denom = torch.logsumexp(neg_logits, dim=-1) # [B]
498
+ selectivity = pos_logit - log_denom # [B]
499
+ selectivity_loss = -selectivity.mean()
500
+
501
+ # Soft margin loss (averaged over all negatives)
502
+ margin_losses = []
503
+ for neg_scores in neg_scores_list:
504
+ margin_losses.append(F.relu(margin - (pos_scores - neg_scores)))
505
+ margin_loss = torch.stack(margin_losses, dim=-1).mean()
506
+
507
+ return selectivity_loss + margin_loss
508
+
509
+ def compute_path_selectivity_loss(self, pos_scores, neg_scores_list,
510
+ path_scores_list, path_taus,
511
+ margin=0.2, path_lambda=0.5):
512
+ """
513
+ Extended selectivity loss with path monotonicity regularization.
514
+
515
+ Args:
516
+ pos_scores: [B] Q(X1, Y) -- goal state scores
517
+ neg_scores_list: list of [B] -- Q(X0, Y), Q(X_cryptic, Y), etc.
518
+ path_scores_list: list of [B] -- Q(X_tau, Y) for each path frame
519
+ path_taus: list of float -- tau values for each path frame (sorted)
520
+ margin: margin for ranking loss
521
+ path_lambda: weight for path monotonicity loss
522
+
523
+ Returns:
524
+ total_loss: selectivity loss + path_lambda * monotonicity loss
525
+ loss_dict: breakdown of loss components
526
+ """
527
+ # Standard selectivity loss (unchanged)
528
+ select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)
529
+
530
+ # Path monotonicity loss: ensure Q increases with tau
531
+ loss_monotone = torch.tensor(0.0, device=pos_scores.device)
532
+ if path_scores_list and path_lambda > 0:
533
+ small_margin = 0.05
534
+ # Consecutive path frames should be monotonically increasing
535
+ for i in range(len(path_scores_list) - 1):
536
+ loss_monotone = loss_monotone + F.relu(
537
+ path_scores_list[i] - path_scores_list[i + 1] + small_margin
538
+ ).mean()
539
+ # Last path frame should be less than positive (holo) score
540
+ loss_monotone = loss_monotone + F.relu(
541
+ path_scores_list[-1] - pos_scores + margin
542
+ ).mean()
543
+ # First path frame should be greater than negative (apo) score
544
+ if neg_scores_list:
545
+ loss_monotone = loss_monotone + F.relu(
546
+ neg_scores_list[0] - path_scores_list[0] + small_margin
547
+ ).mean()
548
+
549
+ total = select_loss + path_lambda * loss_monotone
550
+ return total, {
551
+ 'loss_selectivity': select_loss.item(),
552
+ 'loss_path_monotone': loss_monotone.item(),
553
+ }
554
+
555
+ def compute_combined_loss(self, pos_scores, neg_scores_list, dockq_labels,
556
+ lambda_rank: float = 1.0):
557
+ """Combined Phase 1 + Phase 2 loss."""
558
+ # Regression loss on the positive scores only; negatives enter through the selectivity term below
559
+ dockq_loss = self.compute_dockq_loss(pos_scores, dockq_labels)
560
+
561
+ # Selectivity loss
562
+ select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list)
563
+
564
+ return dockq_loss + lambda_rank * select_loss, {
565
+ 'loss_dockq': dockq_loss.item(),
566
+ 'loss_selectivity': select_loss.item(),
567
+ }
568
+
569
+
570
+ def build_model(config: dict) -> AlloDesignerScorer:
571
+ """Build the Q_theta scorer from a config dict."""
572
+ return AlloDesignerScorer(
573
+ node_dim=config.get('node_dim', 32), # NOTE: differs from the class default (28); set 'node_dim' explicitly in the config
574
+ edge_dim=config.get('edge_dim', 37),
575
+ hidden_dim=config.get('hidden_dim', 128),
576
+ n_layers=config.get('n_layers', 4),
577
+ n_heads=config.get('n_heads', 8),
578
+ dropout=config.get('dropout', 0.1),
579
+ backbone=config.get('backbone', 'transformer'),
580
+ pooling=config.get('pooling', 'meanmax'),
581
+ edge_update=config.get('edge_update', False),
582
+ esm_dim=config.get('esm_dim', 0),
583
+ esm_proj_dim=config.get('esm_proj_dim', 128),
584
+ esm_dropout=config.get('esm_dropout', 0.0),
585
+ )
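The selectivity margin quoted in the `AlloDesignerScorer` docstring (Eq. 3) can be checked by hand on a tiny batch. The following is a minimal NumPy sketch, not part of the released code: `logit` and `selectivity_margin` are hypothetical helper names, and the scores are made-up numbers chosen so the result can be verified analytically.

```python
import numpy as np

def logit(p, eps=1e-6):
    """log(p / (1 - p)) with clipping, mirroring the eps clamp in the loss."""
    p = np.clip(p, eps, 1 - eps)
    return np.log(p) - np.log1p(-p)

def selectivity_margin(pos, negs):
    """S = logit(Q(X+, Y)) - logsumexp over negatives of logit(Q(X-, Y)).

    pos:  [B]         scores for the goal state
    negs: [B, n_neg]  scores for the off-target states
    """
    neg_logits = logit(negs)
    # Numerically stable log-sum-exp along the negatives axis
    m = neg_logits.max(axis=-1, keepdims=True)
    lse = m.squeeze(-1) + np.log(np.exp(neg_logits - m).sum(axis=-1))
    return logit(pos) - lse

# Hypothetical scores: two binders, two negatives each
pos = np.array([0.9, 0.5])
negs = np.array([[0.1, 0.1],
                 [0.5, 0.5]])
s = selectivity_margin(pos, negs)
print(s)
```

With more than one negative, the log-sum-exp term grows by `log(n_neg)` even when all scores are equal, so a binder that scores 0.5 everywhere gets a negative margin rather than zero.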
code/requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ # Core
2
+ torch>=2.0.0
3
+ numpy>=1.24.0
4
+
5
+ # Protein structure
6
+ biopython>=1.80
7
+
8
+ # ML utilities
9
+ scipy>=1.10.0
10
+ scikit-learn>=1.3.0
11
+
12
+ # Experiment tracking
13
+ wandb>=0.12.0
14
+
15
+ # Config
16
+ pyyaml>=6.0
17
+
18
+ # Visualization
19
+ matplotlib>=3.7.0
20
+
21
+ # Optional accelerations
22
+ einops>=0.6.0
code/scripts/README.md ADDED
@@ -0,0 +1,55 @@
1
+ # `code/scripts/` — entry points
2
+
3
+ This public release ships only the inference and sampling code for Q_θ.
4
+
5
+ | File / dir | Purpose |
6
+ |---|---|
7
+ | `evaluate.py` | Score binders in a pre-built `*.pkl` test set with a Q_θ checkpoint; reports Spearman ρ, AUC, selectivity gap. |
8
+ | `rescore.py` | Re-score raw PDB designs (binder + holo + apo) with Q_θ. |
9
+ | `pxdesign_guidance/` | PXDesign-prior guidance with Q_θ (Langevin / SMC / TDS / classifier). |
10
+
11
+ Training, baseline scoring (ProteinMPNN / ESM-IF / Rosetta / DFIRE / energy panel), guidance for RFdiffusion / Proteina-ComplexA, and paper-figure aggregation are **not** shipped; the inference path above is the only supported surface for the public release.
12
+
13
+ ---
14
+
15
+ ## Deploying Q_θ with other base models
16
+
17
+ Q_θ provides two interfaces:
18
+
19
+ 1. **Re-ranker (best-of-K).** Given K candidate binders from any prior, score each with `S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y)` and pick the top. No gradient signal needed; the prior is unmodified.
20
+ 2. **Gradient signal for guidance.** Compute `∇_Y S(Y)` via `DifferentiableQTheta` (in `code/models/differentiable_features.py`) and inject into the prior's sampler (Langevin step, SMC weight, TDS twist, classifier guidance score).
21
+
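Interface (1) can be sketched in a few lines. This is an illustrative example only: `rerank_best_of_k` is a hypothetical helper, and the score arrays are placeholder numbers standing in for Q_θ outputs on the holo (X¹) and apo (X⁰) states.

```python
import numpy as np

def rerank_best_of_k(q_holo, q_apo, top=1):
    """Rank K candidates by S(Y) = Q(X1, Y) - Q(X0, Y), highest first.

    q_holo, q_apo: length-K sequences of scores in (0, 1).
    Returns the indices of the `top` best candidates and all S values.
    """
    q_holo = np.asarray(q_holo, dtype=float)
    q_apo = np.asarray(q_apo, dtype=float)
    selectivity = q_holo - q_apo
    # Stable sort on the negated values gives descending order
    order = np.argsort(-selectivity, kind="stable")
    return order[:top], selectivity

# Placeholder scores for K = 4 candidate binders
q_holo = [0.62, 0.91, 0.55, 0.80]
q_apo = [0.60, 0.40, 0.10, 0.75]
best, s = rerank_best_of_k(q_holo, q_apo, top=2)
print(best, s)
```

Note that candidate 3 has a high holo score but is ranked low because it scores almost as well on the apo state; the difference, not the raw score, drives the ranking.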
22
+ The `pxdesign_guidance/` subdir is a worked example of interface (2) wrapping PXDesign. To plug Q_θ into another prior, mirror that pattern:
23
+
24
+ ### RFdiffusion
25
+
26
+ 1. Clone RFdiffusion: <https://github.com/RosettaCommons/RFdiffusion>.
27
+ 2. Follow its install + checkpoint download.
28
+ 3. In RFdiffusion's diffusion loop, after each denoising step, materialize the predicted backbone, build the holo/apo graph inputs expected by `DifferentiableQTheta`, and either:
29
+ - Apply a Langevin nudge: `x ← x + η · ∇_x S(x)`.
30
+ - Add a classifier-guidance term to the denoiser's mean for `x_{t-1}`: `μ' = μ + s · σ² · ∇_x log p(y|x)`, where `log p(y|x) ≈ S(x)` (the selectivity score S is treated as the log-likelihood of "is a good binder").
31
+ 4. Reference template: `pxdesign_guidance/guided_pxdesign.py`.
32
+
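The Langevin nudge in step 3 can be sketched with PyTorch autograd. This is a toy illustration under stated assumptions: `toy_selectivity` is a hypothetical differentiable stand-in for `S(x)` (the real score would come from `DifferentiableQTheta`), and `langevin_nudge`, `eta`, and `noise_scale` are illustrative names, not part of RFdiffusion or this repository.

```python
import torch

def toy_selectivity(coords, holo_site, apo_site):
    """Hypothetical stand-in for S(x) = Q(holo, x) - Q(apo, x).

    Higher when the coordinates sit closer to the holo site than the apo site.
    """
    d_holo = (coords - holo_site).pow(2).sum(dim=-1).mean()
    d_apo = (coords - apo_site).pow(2).sum(dim=-1).mean()
    return d_apo - d_holo

def langevin_nudge(coords, score_fn, eta=0.05, noise_scale=0.0):
    """One step x <- x + eta * grad_x S(x), with optional Gaussian noise."""
    coords = coords.detach().requires_grad_(True)
    score = score_fn(coords)
    (grad,) = torch.autograd.grad(score, coords)
    with torch.no_grad():
        new_coords = coords + eta * grad
        if noise_scale > 0:
            new_coords = new_coords + noise_scale * torch.randn_like(new_coords)
    return new_coords.detach(), score.item()

torch.manual_seed(0)
x = torch.randn(8, 3)                         # 8 pseudo-residues, xyz
holo_site = torch.tensor([1.0, 0.0, 0.0])
apo_site = torch.tensor([-1.0, 0.0, 0.0])
score_fn = lambda c: toy_selectivity(c, holo_site, apo_site)

before = score_fn(x).item()
for _ in range(10):
    x, _ = langevin_nudge(x, score_fn)
after = score_fn(x).item()
print(before, after)
```

In a real sampler the nudge would be applied to the denoised backbone at each step, and the step size would need tuning against the noise schedule so that guidance does not push samples off the prior's manifold.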
33
+ ### Proteina-ComplexA
34
+
35
+ 1. Clone Proteina: <https://github.com/proteinabio/proteina-complexa> (or the released artifact).
36
+ 2. Use its ComplexA mode that emits binder coords conditioned on a receptor.
37
+ 3. Same plug pattern as RFdiffusion — wrap the sampler with `DifferentiableQTheta` for guidance, or run unguided and re-rank with `evaluate.py` / `rescore.py`.
38
+
39
+ ### Any backbone prior
40
+
41
+ The only contract Q_θ enforces:
42
+
43
+ - Receptor input is a pair of PDB structures, one holo and one apo (passed as `holo_path` and `apo_path`).
44
+ - Binder input is a PDB (or coords) with chain id distinct from receptor's.
45
+ - For guidance, expose differentiable Cα + backbone coordinates so `∇_x S(x)` flows.
46
+
47
+ See `code/models/differentiable_features.py:DifferentiableQTheta` for the exact interface (`load_receptor(holo_path, apo_path, …)`, `score(design_path, binder_chain, state)`, `differentiable_score(coords, …)`).
48
+
49
+ ---
50
+
51
+ ## Why other guidance scripts aren't shipped
52
+
53
+ The RFdiffusion / Proteina guidance variants in our internal tree depend on those projects' unreleased CIF formats and patched samplers; we don't want to ship modified third-party code. The PXDesign variants we do ship use only PXDesign's public API and are self-contained.
54
+
55
+ For citation / reproduction context, see the paper §4 (guidance methods).
code/scripts/evaluate.py ADDED
@@ -0,0 +1,332 @@
1
+ """
2
+ Evaluation script for the trained Q_theta scorer.
3
+
4
+ Computes:
5
+ 1. Selectivity metrics (gap, ranking accuracy, AUC)
6
+ 2. DockQ correlation (Spearman/Pearson)
7
+ 3. Score distributions (violin plots)
8
+ 4. Best-of-K analysis (as function of K)
9
+ 5. Per-target breakdown
10
+
11
+ Usage:
12
+ python code/scripts/evaluate.py \
13
+ --target cam \
14
+ --checkpoint checkpoints/Q_theta_phase2.pt \
15
+ --data_dir data/processed \
16
+ --gpu 7
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import argparse
22
+ import logging
23
+ import json
24
+ import numpy as np
25
+ import torch
26
+ import matplotlib
27
+ matplotlib.use('Agg')
28
+ import matplotlib.pyplot as plt
29
+ from scipy.stats import spearmanr, pearsonr
30
+ from sklearn.metrics import roc_auc_score, roc_curve
31
+
32
+ _CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
33
+ if _CODE_DIR not in sys.path:
34
+ sys.path.insert(0, _CODE_DIR)
35
+
36
+ from models.scorer import build_model
37
+ from data.dataset import TwoStateComplexDataset, collate_fn
38
+ from torch.utils.data import DataLoader
39
+
40
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def compute_best_of_k(pos_scores, K_values=None, threshold=0.7):
45
+ """
46
+ Simulate best-of-K selection: what fraction of draws contain at least one good binder?
47
+ Assumes pos_scores are from a distribution of candidate binders for goal state X+.
48
+ """
49
+ if K_values is None:
50
+ K_values = [1, 2, 5, 10, 20, 50, 100]
51
+ results = {}
52
+ n = len(pos_scores)
53
+ n_trials = 1000
54
+
55
+ for K in K_values:
56
+ successes = 0
57
+ for _ in range(n_trials):
58
+ idxs = np.random.choice(n, size=min(K, n), replace=False)
59
+ best_score = pos_scores[idxs].max()
60
+ if best_score >= threshold:
61
+ successes += 1
62
+ results[K] = successes / n_trials
63
+
64
+ return results
65
+
66
+
67
+ def compute_selectivity_margin(pos_scores, neg_scores):
68
+ """Compute per-sample selectivity margin S_theta."""
69
+ eps = 1e-6
70
+ pos_logit = np.log(pos_scores.clip(eps, 1-eps) / (1-pos_scores).clip(eps))
71
+ neg_logit = np.log(neg_scores.clip(eps, 1-eps) / (1-neg_scores).clip(eps))
72
+ selectivity = pos_logit - np.log(np.exp(neg_logit) + 1e-8)
73
+ return selectivity
74
+
75
+
76
+ def plot_score_distributions(pos_scores, neg_scores, decoy_scores=None,
77
+ title='Score Distributions', outpath=None):
78
+ """Violin plot of score distributions for different complex types."""
79
+ fig, ax = plt.subplots(figsize=(8, 6))
80
+
81
+ data = [pos_scores, neg_scores]
82
+ labels = ['Positive\n(X+, Y)', 'Negative\n(X0, Y)']
83
+ colors = ['#2196F3', '#F44336']
84
+
85
+ if decoy_scores is not None and len(decoy_scores) > 0:
86
+ data.append(decoy_scores)
87
+ labels.append('Decoys\n(X+, Y~)')
88
+ colors.append('#FF9800')
89
+
90
+ parts = ax.violinplot(data, positions=range(len(data)), showmedians=True)
91
+ for i, (pc, c) in enumerate(zip(parts['bodies'], colors)):
92
+ pc.set_facecolor(c)
93
+ pc.set_alpha(0.7)
94
+
95
+ ax.set_xticks(range(len(data)))
96
+ ax.set_xticklabels(labels)
97
+ ax.set_ylabel('Q_theta Score', fontsize=12)
98
+ ax.set_title(title, fontsize=14)
99
+ ax.set_ylim(0, 1)
100
+ ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Decision boundary')
101
+ ax.legend()
102
+
103
+ # Add mean + std annotations
104
+ for i, (d, c) in enumerate(zip(data, colors)):
105
+ ax.text(i, 0.02, f'μ={d.mean():.2f}\nσ={d.std():.2f}',
106
+ ha='center', fontsize=9, color=c)
107
+
108
+ plt.tight_layout()
109
+ if outpath:
110
+ plt.savefig(outpath, dpi=150, bbox_inches='tight')
111
+ logger.info(f"Saved plot to {outpath}")
112
+ plt.close()
113
+
114
+
115
+ def plot_roc_curve(labels, scores, title='ROC Curve', outpath=None):
116
+ """Plot ROC curve for positive vs negative classification."""
117
+ fpr, tpr, _ = roc_curve(labels, scores)
118
+ auc = roc_auc_score(labels, scores)
119
+
120
+ fig, ax = plt.subplots(figsize=(6, 6))
121
+ ax.plot(fpr, tpr, 'b-', lw=2, label=f'AUC = {auc:.3f}')
122
+ ax.plot([0, 1], [0, 1], 'k--', lw=1)
123
+ ax.set_xlabel('False Positive Rate')
124
+ ax.set_ylabel('True Positive Rate')
125
+ ax.set_title(title)
126
+ ax.legend()
127
+ plt.tight_layout()
128
+ if outpath:
129
+ plt.savefig(outpath, dpi=150, bbox_inches='tight')
130
+ plt.close()
131
+ return auc
132
+
133
+
134
+ def plot_best_of_k(results, outpath=None):
135
+ """Plot best-of-K success rate as a function of K."""
136
+ Ks = sorted(results.keys())
137
+ success_rates = [results[K] for K in Ks]
138
+
139
+ fig, ax = plt.subplots(figsize=(8, 5))
140
+ ax.semilogx(Ks, success_rates, 'b-o', lw=2, markersize=8)
141
+ ax.set_xlabel('K (number of candidates)', fontsize=12)
142
+ ax.set_ylabel('Success rate (best score ≥ threshold)', fontsize=12)
143
+ ax.set_title('Best-of-K Analysis', fontsize=14)
144
+ ax.set_ylim(0, 1.05)
145
+ ax.grid(True, alpha=0.3)
146
+ ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='80% success')
147
+ ax.legend()
148
+ plt.tight_layout()
149
+ if outpath:
150
+ plt.savefig(outpath, dpi=150, bbox_inches='tight')
151
+ plt.close()
152
+
153
+
154
+ @torch.no_grad()
155
+ def evaluate(model, loader, device):
156
+ """Run model on a dataset and collect all predictions."""
157
+ model.eval()
158
+ all_scores, all_labels, all_types, all_pdbs = [], [], [], []
159
+
160
+ for batch in loader:
161
+ esm_feats = batch['esm_feats'].to(device) if 'esm_feats' in batch else None
162
+ scores = model(
163
+ batch['node_feats'].to(device),
164
+ batch['edge_feats'].to(device),
165
+ batch['node_mask'].to(device),
166
+ esm_feats=esm_feats,
167
+ )
168
+ all_scores.extend(scores.cpu().numpy().tolist())
169
+ all_labels.extend(batch['label'].numpy().tolist())
170
+ all_types.extend(batch['type'])
171
+ all_pdbs.extend(batch['pdb'])
172
+
173
+ return (np.array(all_scores), np.array(all_labels),
174
+ np.array(all_types), np.array(all_pdbs))
175
+
176
+
177
+ def main():
178
+ parser = argparse.ArgumentParser(description='Evaluate Allo-Designer Q_theta scorer')
179
+ parser.add_argument('--target', default='cam',
180
+ help='Target name (cam, abl, era, or any custom target with data in data/processed/)')
181
+ parser.add_argument('--all_targets', action='store_true',
182
+ help='Evaluate on all available targets and produce aggregated results (accepted but not used by this script)')
183
+ parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
184
+ parser.add_argument('--data_dir', default='data/processed')
185
+ parser.add_argument('--split', choices=['val', 'test'], default='test')
186
+ parser.add_argument('--batch_size', type=int, default=32)
187
+ parser.add_argument('--gpu', type=int, default=0)
188
+ parser.add_argument('--outdir', default='results')
189
+ parser.add_argument('--bok_threshold', type=float, default=0.7,
190
+ help='Score threshold for best-of-K (default 0.7; use per-target value for calibrated results)')
191
+ parser.add_argument('--esm_dir', default=None,
192
+ help='Path to ESM-2 embedding cache (auto-detected at <data_dir>/esm2_embeddings if omitted)')
193
+ parser.add_argument('--no_wandb', action='store_true', help='(ignored; here for CLI compatibility)')
194
+ args = parser.parse_args()
195
+
196
+ # Auto-detect ESM dir under data_dir
197
+ if args.esm_dir is None:
198
+ cand = os.path.join(args.data_dir, 'esm2_embeddings')
199
+ if os.path.isdir(cand):
200
+ args.esm_dir = cand
201
+
202
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
203
+ os.makedirs(args.outdir, exist_ok=True)
204
+ os.makedirs(f'{args.outdir}/figures', exist_ok=True)
205
+ os.makedirs(f'{args.outdir}/tables', exist_ok=True)
206
+
207
+ # Load model
208
+ state = torch.load(args.checkpoint, map_location=device)
209
+ config = state.get('config', {})
210
+ model = build_model(config).to(device)
211
+ model.load_state_dict(state['model_state'])
212
+ logger.info(f"Loaded model from {args.checkpoint}")
213
+
214
+ # Load dataset
215
+ data_path = os.path.join(args.data_dir, args.target, f'{args.split}.pkl')
216
+ if not os.path.exists(data_path):
217
+ logger.error(f"Data not found: {data_path}")
218
+ sys.exit(1)
219
+
220
+ dataset = TwoStateComplexDataset(data_path, max_nodes=128,
221
+ esm_dir=args.esm_dir, target_name=args.target)
222
+ loader = DataLoader(
223
+ dataset, batch_size=args.batch_size, shuffle=False,
224
+ num_workers=2, collate_fn=collate_fn
225
+ )
226
+
227
+ # Run evaluation
228
+ logger.info(f"Evaluating on {len(dataset)} samples...")
229
+ scores, labels, types, pdbs = evaluate(model, loader, device)
230
+
231
+ # Separate by type
232
+ pos_mask = (types == 'positive')
233
+ neg_apo_mask = (types == 'negative_apo')
234
+ decoy_mask = np.array(['decoy' in t for t in types])
235
+
236
+ pos_scores = scores[pos_mask]
237
+ neg_scores = scores[neg_apo_mask]
238
+ decoy_scores = scores[decoy_mask]
239
+
240
+ logger.info(f"\n{'='*50}")
241
+ logger.info(f"Results for {args.target} ({args.split})")
242
+ logger.info(f"{'='*50}")
243
+ logger.info(f"Positive samples: {pos_mask.sum()}")
244
+ logger.info(f"Negative (apo) samples: {neg_apo_mask.sum()}")
245
+ logger.info(f"Decoy samples: {decoy_mask.sum()}")
246
+
247
+ # --- Core metrics ---
248
+ metrics = {}
249
+
250
+ # 1. Spearman correlation with DockQ labels
251
+ sp, p_val = spearmanr(scores, labels)
252
+ metrics['spearman_all'] = float(sp)
253
+ metrics['spearman_pval'] = float(p_val)
254
+ logger.info(f"\nSpearman(Q_theta, DockQ): {sp:.3f} (p={p_val:.3e})")
255
+
256
+ # 2. Selectivity gap (positive vs negative_apo)
257
+ if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0:
258
+ gap = float(pos_scores.mean() - neg_scores.mean())
259
+ ranking_acc = float((pos_scores.mean() > neg_scores).mean() if len(neg_scores) > 0 else 0.5)
260
+ metrics['selectivity_gap'] = gap
+ metrics['ranking_acc'] = ranking_acc
261
+ metrics['pos_score_mean'] = float(pos_scores.mean())
262
+ metrics['neg_score_mean'] = float(neg_scores.mean())
263
+ metrics['pos_score_std'] = float(pos_scores.std())
264
+ metrics['neg_score_std'] = float(neg_scores.std())
265
+ logger.info(f"Selectivity gap (pos - neg): {gap:.3f}")
266
+ logger.info(f" Pos: {pos_scores.mean():.3f} ± {pos_scores.std():.3f}")
267
+ logger.info(f" Neg: {neg_scores.mean():.3f} ± {neg_scores.std():.3f}")
268
+
269
+ # 3. AUC for positive vs negative
270
+ if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0:
271
+ pn_scores = np.concatenate([pos_scores, neg_scores])
272
+ pn_labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
273
+ auc = roc_auc_score(pn_labels, pn_scores)
274
+ metrics['auc_pos_vs_neg'] = float(auc)
275
+ logger.info(f"AUC (pos vs neg_apo): {auc:.3f}")
276
+
277
+ # ROC curve
278
+ plot_roc_curve(
279
+ pn_labels, pn_scores,
280
+ title=f'ROC: Positive vs Negative Apo ({args.target.upper()})',
281
+ outpath=f'{args.outdir}/figures/roc_{args.target}_{args.split}.png'
282
+ )
283
+
284
+ # 4. AUC for quality classification (DockQ > 0.5)
285
+ binary = (labels > 0.5).astype(int)
286
+ if binary.sum() > 0 and binary.sum() < len(binary):
287
+ auc_quality = roc_auc_score(binary, scores)
288
+ metrics['auc_quality'] = float(auc_quality)
289
+ logger.info(f"AUC (quality>0.5): {auc_quality:.3f}")
290
+
291
+ # 5. Best-of-K analysis
292
+ if len(pos_scores) > 0:
293
+ bok_results = compute_best_of_k(pos_scores, K_values=[1, 2, 5, 10, 20, 50],
294
+ threshold=args.bok_threshold)
295
+ metrics['best_of_k'] = {str(K): float(v) for K, v in bok_results.items()}
296
+ logger.info(f"\nBest-of-K success rates:")
297
+ for K, rate in bok_results.items():
298
+ logger.info(f" K={K:3d}: {rate:.3f}")
299
+ plot_best_of_k(
300
+ bok_results,
301
+ outpath=f'{args.outdir}/figures/best_of_k_{args.target}_{args.split}.png'
302
+ )
303
+
304
+ # 6. Score distributions plot
305
+ plot_score_distributions(
306
+ pos_scores if len(pos_scores) > 0 else np.array([]),
307
+ neg_scores if len(neg_scores) > 0 else np.array([]),
308
+ decoy_scores if len(decoy_scores) > 0 else None,
309
+ title=f'Q_theta Score Distributions ({args.target.upper()})',
310
+ outpath=f'{args.outdir}/figures/score_dist_{args.target}_{args.split}.png'
311
+ )
312
+
313
+ # Save metrics
314
+ out_json = f'{args.outdir}/tables/eval_{args.target}_{args.split}.json'
315
+ with open(out_json, 'w') as f:
316
+ json.dump(metrics, f, indent=2)
317
+ logger.info(f"\nSaved metrics to {out_json}")
318
+
319
+ # Print summary table
320
+ logger.info(f"\n{'='*50}")
321
+ logger.info("SUMMARY TABLE")
322
+ logger.info(f"{'='*50}")
323
+ logger.info(f"{'Metric':<30} {'Value':>10}")
324
+ logger.info(f"{'-'*42}")
325
+ for k, v in metrics.items():
326
+ if isinstance(v, float):
327
+ logger.info(f"{k:<30} {v:>10.4f}")
328
+ logger.info(f"{'='*50}")
329
+
330
+
331
+ if __name__ == '__main__':
332
+ main()
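`compute_best_of_k` estimates the success rate by Monte Carlo over 1000 random draws without replacement. For a fixed pool of scores the same quantity has a closed form (a hypergeometric tail), which is handy as a sanity check. The sketch below is not part of `evaluate.py`; `exact_best_of_k` is a hypothetical helper and the scores are made-up values.

```python
from math import comb

def exact_best_of_k(scores, k, threshold=0.7):
    """P(at least one of k draws without replacement has score >= threshold).

    Equals 1 - C(n_bad, k) / C(n, k), where n_bad counts scores below threshold.
    k is capped at the pool size, matching the min(K, n) used in the sampler.
    """
    n = len(scores)
    k = min(k, n)
    n_bad = sum(1 for s in scores if s < threshold)
    # comb returns 0 when k > n_bad, so success is then certain
    return 1.0 - comb(n_bad, k) / comb(n, k)

# Made-up pool: 2 of 5 candidates clear the 0.7 threshold
pool = [0.9, 0.8, 0.3, 0.2, 0.1]
p1 = exact_best_of_k(pool, 1)
p2 = exact_best_of_k(pool, 2)
p4 = exact_best_of_k(pool, 4)
print(p1, p2, p4)
```

The Monte Carlo estimate should agree with these values up to sampling noise of roughly `1 / sqrt(n_trials)`.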
code/scripts/pxdesign_guidance/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # PXDesign + Q_theta guidance integration
code/scripts/pxdesign_guidance/convert_cif_to_pdb.py ADDED
@@ -0,0 +1,132 @@
1
+ """
2
+ Convert PXDesign CIF outputs to PDB format for evaluation pipeline.
3
+
4
+ PXDesign outputs .cif files with:
5
+ - Chain IDs like A0/B0 (multi-char, not PDB-compatible)
6
+ - Non-standard residue name 'xpb' for designed binder residues
7
+
8
+ This script converts them to PDB format with:
9
+ - Single-char chain IDs (A, B)
10
+ - Residue name 'xpb' rewritten to GLY (backbone-only binder residues)
11
+
12
+ Usage:
13
+ python code/scripts/pxdesign_guidance/convert_cif_to_pdb.py
14
+ """
15
+ import os
16
+ import sys
17
+ from glob import glob
18
+
19
+ from Bio.PDB import MMCIFParser, PDBIO, Select
20
+
21
+ _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
22
+ _PROJECT_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '../../..'))
23
+
24
+
25
+ class ChainRenamer(Select):
26
+ """Rename multi-char chain IDs to single-char for PDB format."""
27
+ def __init__(self, chain_map):
28
+ self.chain_map = chain_map
29
+
30
+ def accept_chain(self, chain):
31
+ return 1
32
+
33
+ def accept_residue(self, residue):
34
+ return 1
35
+
36
+ def accept_atom(self, atom):
37
+ return 1
38
+
39
+
40
+ def convert_cif_to_pdb(cif_path, pdb_path):
41
+ """Convert a single CIF file to PDB format."""
42
+ parser = MMCIFParser(QUIET=True)
43
+ structure = parser.get_structure('s', cif_path)
44
+ model = structure[0]
45
+
46
+ # Build chain ID mapping (A0->A, B0->B, etc.)
47
+ chain_map = {}
48
+ used_ids = set()
49
+ for chain in model.get_chains():
50
+ old_id = chain.id
51
+ # Use first character
52
+ new_id = old_id[0] if old_id else 'A'
53
+ # Avoid duplicates
54
+ while new_id in used_ids:
55
+ new_id = chr(ord(new_id) + 1)
56
+ used_ids.add(new_id)
57
+ chain_map[old_id] = new_id
58
+
59
+ # Rename chains and fix non-standard residue names
60
+ chains_to_rename = list(model.get_chains())
61
+ for chain in chains_to_rename:
62
+ old_id = chain.id
63
+ new_id = chain_map.get(old_id, old_id)
64
+ if old_id != new_id:
65
+ chain.id = new_id
66
+ # Rename 'xpb' residues to 'GLY' (backbone-only binder residues)
67
+ for residue in chain.get_residues():
68
+ if residue.resname.strip().lower() == 'xpb':
69
+ residue.resname = 'GLY'
70
+
71
+ # Write PDB
72
+ io = PDBIO()
73
+ io.set_structure(structure)
74
+ io.save(pdb_path)
75
+ return True
76
+
77
+
78
+def convert_directory(src_dir, method_name):
+    """Convert every CIF file whose name contains 'sample' under src_dir to PDB."""
+    cif_files = sorted(glob(os.path.join(src_dir, '**/*.cif'), recursive=True))
+    cif_files = [f for f in cif_files if 'sample' in os.path.basename(f).lower()]
+
+    if not cif_files:
+        print(f" No CIF files found in {src_dir}")
+        return 0
+
+    # Create converted_pdbs directory
+    converted_dir = os.path.join(src_dir, 'converted_pdbs')
+    os.makedirs(converted_dir, exist_ok=True)
+
+    n_converted = 0
+    for cif_path in cif_files:
+        basename = os.path.basename(cif_path).replace('.cif', '.pdb')
+        # For TDS/SMC with round subdirs, include round info
+        rel_path = os.path.relpath(cif_path, src_dir)
+        parts = rel_path.split(os.sep)
+        if any(p.startswith('round_') for p in parts):
+            round_part = [p for p in parts if p.startswith('round_')][0]
+            basename = f"{round_part}_{basename}"
+
+        pdb_path = os.path.join(converted_dir, basename)
+        try:
+            convert_cif_to_pdb(cif_path, pdb_path)
+            n_converted += 1
+        except Exception as e:
+            print(f" Failed {cif_path}: {e}")
+
+    print(f" Converted {n_converted}/{len(cif_files)} CIF -> PDB in {converted_dir}")
+    return n_converted
+
+
+def main():
+    methods = {
+        'pxdesign_guided': os.path.join(_PROJECT_DIR, 'results/pxdesign_guided'),
+        'pxdesign_tds': os.path.join(_PROJECT_DIR, 'results/pxdesign_tds'),
+        'pxdesign_smc': os.path.join(_PROJECT_DIR, 'results/pxdesign_smc'),
+    }
+    # Langevin outputs are already PDB (post-hoc refinement)
+
+    total = 0
+    for name, src_dir in methods.items():
+        print(f"\n{name}:")
+        if os.path.exists(src_dir):
+            total += convert_directory(src_dir, name)
+        else:
+            print(f" Directory not found: {src_dir}")
+
+    print(f"\nTotal converted: {total}")
+
+
+if __name__ == '__main__':
+    main()
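The chain-renaming step in `convert_cif_to_pdb.py` maps each CIF chain ID to a single character and bumps it with `chr(ord(new_id) + 1)` on collision, which can run past `'Z'` into punctuation. A minimal sketch of an alternative follows, assuming a fixed pool of alphanumeric IDs is acceptable; `remap_chain_ids` and `_CANDIDATE_IDS` are hypothetical names for illustration and are not part of the repository.

```python
import string

# Hypothetical pool of single-character chain IDs (assumption, not from the repo).
_CANDIDATE_IDS = string.ascii_uppercase + string.ascii_lowercase + string.digits


def remap_chain_ids(old_ids):
    """Map arbitrary chain IDs to unique single-character IDs."""
    chain_map = {}
    used = set()
    for old_id in old_ids:
        # Prefer the first character of the original ID, as the script does.
        preferred = old_id[0] if old_id else 'A'
        if preferred in _CANDIDATE_IDS and preferred not in used:
            new_id = preferred
        else:
            # Otherwise take the first unused ID from the pool.
            new_id = next((c for c in _CANDIDATE_IDS if c not in used), None)
            if new_id is None:
                raise ValueError("more chains than available single-character IDs")
        used.add(new_id)
        chain_map[old_id] = new_id
    return chain_map


print(remap_chain_ids(['A0', 'B0', 'A1']))
```

Drawing from a fixed pool keeps every ID alphanumeric and fails loudly once the pool is exhausted, instead of silently emitting characters that downstream PDB parsers may reject.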
code/scripts/pxdesign_guidance/guided_pxdesign.py ADDED
@@ -0,0 +1,408 @@
+"""
+PXDesign + Q_theta Classifier Guidance.
+
+Monkey-patches PXDesign's diffusion sampling loop to inject the Q_theta selectivity
+gradient after each denoising step. This steers the diffusion trajectory toward
+binder backbones that are conformationally selective.
+
+The patched diffusion loop:
+    x_denoised = denoise_net(x_noisy, t_hat, ...)
+    grad = ∇_{x_denoised}[Q(holo,Y) - Q(apo,Y)]  # <-- INJECTED
+    x_denoised = x_denoised + scale(t) * grad    # <-- INJECTED
+    delta = (x_noisy - x_denoised) / t_hat
+    x_l = x_noisy + eta * dt * delta
+
+Usage:
+    python code/scripts/pxdesign_guidance/guided_pxdesign.py \
+        --input experiments/pxdesign_cam/output/cam_binder.json \
+        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
+        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
+        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
+        --guidance_scale 1.0 \
+        --N_sample 50 --N_step 400 \
+        --gpu 0
+"""
+
+import os
+import sys
+import argparse
+import json
+import logging
+import time
+import shutil
+from typing import Callable, Optional, Union
+from functools import partial
+
+import numpy as np
+import torch
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+logger = logging.getLogger(__name__)
+
+# ── Paths ────────────────────────────────────────────────────────────────────
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
+_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
+_PXDESIGN_DIR = os.environ.get('PXDESIGN_DIR', '')
+
+if _ALLO_CODE_DIR not in sys.path:
+    sys.path.insert(0, _ALLO_CODE_DIR)
+# Only add PXDESIGN_DIR when it is set; an empty string would put the
+# current working directory at the front of sys.path.
+if _PXDESIGN_DIR and _PXDESIGN_DIR not in sys.path:
+    sys.path.insert(0, _PXDESIGN_DIR)
+
+
+def guided_sample_diffusion(
+    denoise_net: Callable,
+    input_feature_dict: dict,
+    s_inputs: torch.Tensor,
+    s_trunk: torch.Tensor,
+    z_trunk: torch.Tensor,
+    noise_schedule: torch.Tensor,
+    N_sample: int = 1,
+    gamma0: float = 0.8,
+    gamma_min: float = 1.0,
+    noise_scale_lambda: float = 1.003,
+    step_scale_eta: Union[float, dict] = {"type": "const", "min": 1.5, "max": 1.5},
+    diffusion_chunk_size: Optional[int] = None,
+    inplace_safe: bool = False,
+    attn_chunk_size: Optional[int] = None,
+    # Guidance parameters (injected via partial)
+    guidance_module=None,
+    guidance_scale: float = 1.0,
+    guidance_start: float = 0.8,
+    guidance_end: float = 0.1,
+) -> torch.Tensor:
+    """
+    Modified PXDesign sample_diffusion with Q_theta classifier guidance.
+
+    Same as original generator.sample_diffusion but with gradient injection
+    after each denoising step. Guidance is active only while the remaining
+    step fraction (1 - progress) lies in [guidance_end, guidance_start], and
+    its strength is proportional to that fraction, so it is stronger at high
+    noise levels (early steps).
+    """
+    from protenix.model.utils import centre_random_augmentation
+
+    N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
+    batch_shape = s_inputs.shape[:-2]
+    device = s_inputs.device
+    dtype = s_inputs.dtype
+
+    logger.info(f"Guided sampling: scale={guidance_scale}, "
+                f"window=[{guidance_end:.1f}, {guidance_start:.1f}]")
+
+    def _chunk_sample_diffusion_guided(chunk_n_sample, inplace_safe):
+        x_l = noise_schedule[0] * torch.randn(
+            size=(*batch_shape, chunk_n_sample, N_atom, 3),
+            device=device, dtype=dtype
+        )
+        T = len(noise_schedule)
+
+        for step_t, (c_tau_last, c_tau) in enumerate(
+            zip(noise_schedule[:-1], noise_schedule[1:])
+        ):
+            # Centre random augmentation
+            x_l = (
+                centre_random_augmentation(x_input_coords=x_l, N_sample=1)
+                .squeeze(dim=-3)
+                .to(dtype)
+            )
+
+            # Predictor step: add noise
+            gamma = float(gamma0) if c_tau > gamma_min else 0
+            t_hat = c_tau_last * (gamma + 1)
+            delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
+            x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
+                size=x_l.shape, device=device, dtype=dtype
+            )
+
+            # Reshape t_hat for network
+            t_hat_tensor = (
+                t_hat.reshape((1,) * (len(batch_shape) + 1))
+                .expand(*batch_shape, chunk_n_sample)
+                .to(dtype)
+            )
+
+            # Denoise
+            x_denoised = denoise_net(
+                x_noisy=x_noisy,
+                t_hat_noise_level=t_hat_tensor,
+                input_feature_dict=input_feature_dict,
+                s_inputs=s_inputs,
+                s_trunk=s_trunk,
+                z_trunk=z_trunk,
+                chunk_size=attn_chunk_size,
+                inplace_safe=inplace_safe,
+            )
+
+            # ── Q_theta guidance injection ──────────────────────────────
+            if guidance_module is not None:
+                # Compute progress fraction (0=start/high noise, 1=end/low noise)
+                progress = step_t / (T - 1) if T > 1 else 1.0
+
+                # Apply guidance only within the specified window
+                if guidance_end <= (1.0 - progress) <= guidance_start:
+                    # Handle batch dimensions
+                    x_for_grad = x_denoised
+                    if x_for_grad.dim() > 3:
+                        x_for_grad = x_for_grad.squeeze(0)
+
+                    # Scale: stronger at high noise, weaker near convergence
+                    noise_fraction = 1.0 - progress
+                    scale = guidance_scale * noise_fraction
+
+                    try:
+                        # Compute gradients for at most the first 4 samples
+                        n_guide = min(chunk_n_sample, 4)
+                        grad_accum = torch.zeros_like(x_for_grad)
+
+                        for si in range(n_guide):
+                            grad, margin = guidance_module.compute_guidance_gradient(
+                                x_for_grad, input_feature_dict,
+                                t_hat=t_hat, sample_idx=si
+                            )
+                            grad_accum[si] = grad[si] if grad.shape[0] > si else grad[0]
+
+                        # Broadcast the mean gradient to the remaining samples
+                        if n_guide < chunk_n_sample and n_guide > 0:
+                            avg_grad = grad_accum[:n_guide].mean(dim=0, keepdim=True)
+                            grad_accum[n_guide:] = avg_grad.expand(
+                                chunk_n_sample - n_guide, -1, -1)
+
+                        # Normalize gradient per atom to prevent explosion
+                        grad_norm = grad_accum.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+                        grad_normalized = grad_accum / grad_norm
+                        avg_norm = grad_norm.mean().item()
+
+                        # Apply guidance
+                        if avg_norm > 1e-6:
+                            # Scale by average gradient magnitude to keep step size reasonable
+                            x_denoised = x_denoised + scale * avg_norm * grad_normalized
+
+                        if step_t % 50 == 0:
+                            logger.info(
+                                f" Step {step_t}/{T}: margin={margin:.3f}, "
+                                f"grad_norm={avg_norm:.4f}, scale={scale:.3f}")
+                    except Exception as e:
+                        # Logged at WARNING: with the INFO-level logging configured
+                        # above, a DEBUG message would hide guidance failures and the
+                        # run would silently fall back to unguided sampling.
+                        if step_t % 100 == 0:
+                            logger.warning(f" Step {step_t}: guidance failed: {e}")
+            # ── End guidance ────────────────────────────────────────────
+
+            # Euler step
+            delta = (x_noisy - x_denoised) / t_hat_tensor[..., None, None]
+            dt = c_tau - t_hat_tensor
+            if isinstance(step_scale_eta, float):
+                eta = step_scale_eta
+            elif step_scale_eta["type"] == "const":
+                assert step_scale_eta["min"] == step_scale_eta["max"]
+                eta = step_scale_eta["min"]
+            else:
+                eta_min, eta_max = step_scale_eta["min"], step_scale_eta["max"]
+                if step_scale_eta["type"] == "linear":
+                    eta = eta_min + (eta_max - eta_min) * (step_t / T)
+                elif step_scale_eta["type"] == "poly":
+                    eta = eta_min + (eta_max - eta_min) * (step_t / T) ** 2
+                elif step_scale_eta["type"] == "cos":
+                    eta = eta_min + 0.5 * (eta_max - eta_min) * (
+                        1 - np.cos(np.pi * step_t / T))
+                elif step_scale_eta["type"] == "piecewise":
+                    eta = eta_min if step_t / T < 0.5 else eta_max
+                elif step_scale_eta["type"] == "piecewise_65":
+                    eta = eta_min if step_t / T < 0.65 else eta_max
+                elif step_scale_eta["type"] == "piecewise_70":
+                    eta = eta_min if step_t / T < 0.70 else eta_max
+                else:
+                    raise ValueError("Unsupported eta schedule!")
+            x_l = x_noisy + eta * dt[..., None, None] * delta
+
+        return x_l
+
+    # Chunked sampling
+    if diffusion_chunk_size is None:
+        x_l = _chunk_sample_diffusion_guided(N_sample, inplace_safe=inplace_safe)
+    else:
+        x_l = []
+        no_chunks = N_sample // diffusion_chunk_size + (
+            N_sample % diffusion_chunk_size != 0)
+        for i in range(no_chunks):
+            chunk_n_sample = (
+                diffusion_chunk_size
+                if i < no_chunks - 1
+                else N_sample - i * diffusion_chunk_size
+            )
+            chunk_x_l = _chunk_sample_diffusion_guided(
+                chunk_n_sample, inplace_safe=inplace_safe)
+            x_l.append(chunk_x_l)
+        x_l = torch.cat(x_l, -3)
+
+    return x_l
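The guidance window and linear scale used inside `guided_sample_diffusion` can be isolated as a pure function of the step index. The sketch below restates that schedule without any tensors; `guidance_scale_at` and `n_levels` are hypothetical names introduced for illustration, with `n_levels` standing in for `len(noise_schedule)`.

```python
def guidance_scale_at(step_t, n_levels, guidance_scale=1.0,
                      guidance_start=0.8, guidance_end=0.1):
    """Return the guidance multiplier for a step, or 0.0 outside the window."""
    # progress runs from 0 (first, noisiest step) toward 1 (last step).
    progress = step_t / (n_levels - 1) if n_levels > 1 else 1.0
    noise_fraction = 1.0 - progress
    # Guidance is applied only while noise_fraction is inside the window,
    # and its strength decays linearly with noise_fraction.
    if guidance_end <= noise_fraction <= guidance_start:
        return guidance_scale * noise_fraction
    return 0.0


# With 11 noise levels: no guidance at step 0, half strength mid-way,
# and none once the remaining fraction drops below guidance_end.
for step in (0, 5, 10):
    print(step, guidance_scale_at(step, n_levels=11))
```

With the defaults, the first 20% and last 10% of the trajectory are left unguided, so the earliest steps and the final fine-grained denoising are governed by the base model alone.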
+
+
+def run_guided_pxdesign(args):
+    """Run PXDesign with Q_theta classifier guidance."""
+    if 'CUDA_VISIBLE_DEVICES' not in os.environ:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
+
+    # Import PXDesign components
+    from pxdesign.runner.inference import InferenceRunner, main as pxdesign_main
+    from pxdesign.utils.infer import (
+        get_configs, convert_to_bioassembly_dict, download_inference_cache, derive_seed
+    )
+    from pxdesign.utils.inputs import process_input_file
+    from protenix.config import save_config
+    from protenix.utils.seed import seed_everything
+    from protenix.utils.torch_utils import autocasting_disable_decorator
+
+    from qtheta_pxdesign import QThetaPXDesignGuidance
+
+    # Set up output directory
+    outdir = args.outdir if os.path.isabs(args.outdir) else os.path.join(_ALLO_ROOT, args.outdir)
+    os.makedirs(outdir, exist_ok=True)
+
+    # Build PXDesign CLI arguments
+    pxdesign_argv = [
+        '--dump_dir', outdir,
+        '--input', args.input,
+        '--dtype', 'bf16',
+        '--N_sample', str(args.N_sample),
+        '--N_step', str(args.N_step),
+    ]
+
+    configs = get_configs(pxdesign_argv)
+    configs.input_json_path = process_input_file(
+        configs.input_json_path, out_dir=outdir)
+    download_inference_cache(configs)
+
+    # Convert inputs
+    save_config(configs, os.path.join(outdir, "config.yaml"))
+    with open(configs.input_json_path, "r") as f:
+        orig_inputs = json.load(f)
+    for x in orig_inputs:
+        convert_to_bioassembly_dict(x, outdir)
+    configs.input_json_path = os.path.join(outdir, "input_tasks.json")
+    with open(configs.input_json_path, "w") as f:
+        json.dump(orig_inputs, f, indent=4)
+
+    # Create runner
+    runner = InferenceRunner(configs)
+
+    # Initialize Q_theta guidance
+    guidance = QThetaPXDesignGuidance(
+        checkpoint=args.qtheta_checkpoint if os.path.isabs(args.qtheta_checkpoint) else os.path.join(_ALLO_ROOT, args.qtheta_checkpoint),
+        ref_holo=args.ref_holo if os.path.isabs(args.ref_holo) else os.path.join(_ALLO_ROOT, args.ref_holo),
+        ref_apo=args.ref_apo if os.path.isabs(args.ref_apo) else os.path.join(_ALLO_ROOT, args.ref_apo),
+        ref_chain=args.ref_chain,
+        device='cuda:0',  # After CUDA_VISIBLE_DEVICES remapping
+        esm_target=args.esm_target,
+    )
+
+    # Monkey-patch the sample_diffusion function
+    from pxdesign.model import generator as pxdesign_generator
+    import pxdesign.model.pxdesign as pxdesign_model
+
+    # Create guided version with guidance params bound
+    guided_fn = partial(
+        guided_sample_diffusion,
+        guidance_module=guidance,
+        guidance_scale=args.guidance_scale,
+        guidance_start=args.guidance_start,
+        guidance_end=args.guidance_end,
+    )
+
+    # Patch the module-level function in generator.py
+    pxdesign_generator.sample_diffusion = guided_fn
+
+    # CRITICAL: pxdesign.py does `from pxdesign.model.generator import sample_diffusion`
+    # which creates a local binding in pxdesign.model.pxdesign namespace.
+    # We must patch that local binding too, otherwise the ProtenixDesign.sample_diffusion()
+    # method will still call the original unpatched function.
+    pxdesign_model.sample_diffusion = guided_fn
+
+    logger.info("PXDesign diffusion loop patched with Q_theta guidance")
+
+    # Run inference
+    seeds = [derive_seed(time.time_ns())] if not configs.seeds else configs.seeds
+    for seed in seeds:
+        logger.info(f"Running guided inference with seed {seed}")
+        seed_everything(seed=seed, deterministic=False)
+        runner._inference(seed)
+
+    # Score all generated designs
+    logger.info("Scoring generated designs...")
+    from glob import glob
+
+    pdb_dir = outdir
+    pdbs = []
+    for ext in ('*.pdb', '*.cif'):
+        pdbs.extend(glob(os.path.join(pdb_dir, '**/' + ext), recursive=True))
+    pdbs = sorted([p for p in pdbs if 'sample' in os.path.basename(p).lower()])
+
+    results = []
+    for i, pdb_path in enumerate(pdbs):
+        design_id = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
+        result = guidance.score_design(pdb_path)
+        if result is not None:
+            result['design_id'] = design_id
+            result['pdb_path'] = pdb_path
+            results.append(result)
+            logger.info(
+                f"[{i+1}/{len(pdbs)}] {design_id}: "
+                f"Q+={result['q_holo']:.3f} Q-={result['q_apo']:.3f} "
+                f"S={result['margin']:+.3f}")
+
+    # Save results
+    if results:
+        results.sort(key=lambda x: x['margin'], reverse=True)
+        margins = np.array([r['margin'] for r in results])
+
+        summary = {
+            'method': 'PXDesign + Classifier Guidance',
+            'n_designs': len(results),
+            'guidance_scale': args.guidance_scale,
+            'guidance_window': [args.guidance_end, args.guidance_start],
+            'margin_mean': float(margins.mean()),
+            'margin_std': float(margins.std()),
+            'frac_positive': float((margins > 0).mean()),
+            'q_holo_mean': float(np.mean([r['q_holo'] for r in results])),
+            'q_apo_mean': float(np.mean([r['q_apo'] for r in results])),
+        }
+
+        with open(os.path.join(outdir, 'guided_scores.json'), 'w') as f:
+            json.dump(results, f, indent=2)
+        with open(os.path.join(outdir, 'guided_summary.json'), 'w') as f:
+            json.dump(summary, f, indent=2)
+
+        logger.info(f"\n{'='*60}")
+        logger.info(f"PXDesign + Classifier Guidance Results ({len(results)} designs)")
+        logger.info(f" Margin: {margins.mean():.3f} ± {margins.std():.3f}")
+        logger.info(f" Fraction S > 0: {(margins > 0).mean():.1%}")
+        logger.info(f" Q(holo) mean: {summary['q_holo_mean']:.3f}")
+        logger.info(f"{'='*60}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='PXDesign + Q_theta Classifier Guidance')
+    parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
+                        help='PXDesign input JSON')
+    parser.add_argument('--qtheta_checkpoint',
+                        default='results/checkpoints_cam_v3/best_phase2.pt')
+    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
+    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
+    parser.add_argument('--ref_chain', default='A')
+    parser.add_argument('--guidance_scale', type=float, default=1.0,
+                        help='Guidance gradient scale')
+    parser.add_argument('--guidance_start', type=float, default=0.8,
+                        help='Start guidance at this noise fraction (high noise)')
+    parser.add_argument('--guidance_end', type=float, default=0.1,
+                        help='Stop guidance at this noise fraction (low noise)')
+    parser.add_argument('--N_sample', type=int, default=50)
+    parser.add_argument('--N_step', type=int, default=400)
+    parser.add_argument('--gpu', type=int, default=0)
+    parser.add_argument('--outdir', default='results/pxdesign_guided')
+    parser.add_argument('--esm_target', default='cam',
+                        help='Subdir under data/esm2_embeddings (e.g., adk, cam)')
+    args = parser.parse_args()
+
+    run_guided_pxdesign(args)
+
+
+if __name__ == '__main__':
+    main()
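The CRITICAL comment in `run_guided_pxdesign` relies on a general Python behaviour: `from module import name` copies the binding into the importing module, so patching only the defining module leaves the importer pointing at the original function. The toy below reproduces this with two throwaway modules; `toy_generator` and `toy_model` are hypothetical stand-ins and do not touch PXDesign itself.

```python
import sys
import types

# Stand-in for the module that defines sample_diffusion.
generator = types.ModuleType("toy_generator")
exec("def sample_diffusion():\n    return 'original'", generator.__dict__)
sys.modules["toy_generator"] = generator

# Stand-in for the module that does `from ... import sample_diffusion`.
model = types.ModuleType("toy_model")
exec(
    "from toy_generator import sample_diffusion\n"
    "def run():\n"
    "    return sample_diffusion()",
    model.__dict__,
)


def guided():
    return 'guided'


# Patching only the defining module does not affect the importer's binding.
generator.sample_diffusion = guided
after_first_patch = model.run()

# Patching the importer's own binding is what changes the call.
model.sample_diffusion = guided
after_second_patch = model.run()

print(after_first_patch, after_second_patch)
```

This is why the script assigns `guided_fn` in both `pxdesign.model.generator` and `pxdesign.model.pxdesign`.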
code/scripts/pxdesign_guidance/iterative_refinement.py ADDED
@@ -0,0 +1,338 @@
+"""
+Iterative Refinement via Langevin Noise-Refine Cycles.
+
+Inspired by ProDifEvo (Uehara et al., ICML 2025): repeatedly perturb and
+refine structures through Q_theta gradient ascent. Each cycle adds noise
+for diversity, then refines with Langevin dynamics toward higher selectivity.
+
+This allows designs to escape local optima and explore better selectivity
+regions that single-shot generation cannot reach.
+
+Pipeline:
+    1. Start from existing PXDesign outputs (seed structures)
+    2. Align binder to reference receptor frames
+    3. Run Langevin refinement with Q_theta gradient
+    4. Score the refined output
+    5. Repeat for K iterations, keeping best designs
+
+Usage:
+    python code/scripts/pxdesign_guidance/iterative_refinement.py \
+        --input_dir results/pxdesign_guided/converted_pdbs \
+        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
+        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
+        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
+        --n_iterations 3 --n_designs 10 \
+        --gpu 6
+"""
+import os
+import sys
+import json
+import logging
+import numpy as np
+import torch
+from glob import glob
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+logger = logging.getLogger(__name__)
+
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
+_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
+
+if _ALLO_CODE_DIR not in sys.path:
+    sys.path.insert(0, _ALLO_CODE_DIR)
+
+
+def score_designs(pdb_paths, guidance):
+    """Score a list of PDB paths with Q_theta."""
+    results = []
+    for pdb_path in pdb_paths:
+        result = guidance.score_design(pdb_path)
+        if result is not None:
+            result['pdb_path'] = pdb_path
+            result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
+            results.append(result)
+    return results
+
+
+def run_langevin_cycle(pdb_paths, guidance, n_steps=50, step_size=0.005,
+                       iteration=0, outdir='results/iterative_refinement'):
+    """Run Langevin refinement cycle on binder backbone coords using Q_theta.
+
+    Uses guidance.dq (DifferentiableQTheta) for differentiable scoring.
+    Aligns binder to holo/apo reference frames for dual-state scoring.
+    """
+    from utils.pdb_utils import (load_structure, get_residues, get_backbone_coords,
+                                 get_aa_indices, align_structures)
+
+    refined_results = []
+    os.makedirs(outdir, exist_ok=True)
+
+    for pdb_path in pdb_paths:
+        try:
+            model = load_structure(pdb_path)
+            chains = {c.id: c for c in model.get_chains()}
+
+            # The receptor is expected on chain 'A'; the first other chain is the binder
+            binder_chain = None
+            for cid in sorted(chains.keys()):
+                if cid != 'A':
+                    binder_chain = cid
+                    break
+            if binder_chain is None:
+                continue
+
+            rec_res = get_residues(chains['A'])
+            if not rec_res:
+                rec_res = get_residues(chains['A'], only_standard=False)
+            binder_res = get_residues(chains[binder_chain])
+            if not binder_res:
+                binder_res = get_residues(chains[binder_chain], only_standard=False)
+            if len(binder_res) < 5:
+                continue
+
+            binder_coords, binder_mask = get_backbone_coords(binder_res)
+            rec_coords, _ = get_backbone_coords(rec_res)
+
+            try:
+                aa_idx = get_aa_indices(binder_res)
+            except Exception:
+                aa_idx = np.zeros(len(binder_res), dtype=np.int64)
+
+            # Compute alignment transforms
+            rec_ca = rec_coords[:, 1, :]
+            ref_holo_ca = guidance.ref_holo_ca.cpu().numpy()
+            ref_apo_ca = guidance.ref_apo_ca.cpu().numpy()
+            n_h = min(len(rec_ca), len(ref_holo_ca))
+            n_a = min(len(rec_ca), len(ref_apo_ca))
+            if n_h < 5 or n_a < 5:
+                continue
+
+            _, R_h = align_structures(rec_ca[:n_h], ref_holo_ca[:n_h])
+            center_h = rec_ca[:n_h].mean(0)
+            ref_center_h = ref_holo_ca[:n_h].mean(0)
+            aligned_holo = (binder_coords.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
+            aligned_holo = aligned_holo.reshape(-1, 4, 3)
+
+            _, R_a = align_structures(rec_ca[:n_a], ref_apo_ca[:n_a])
+            center_a = rec_ca[:n_a].mean(0)
+            ref_center_a = ref_apo_ca[:n_a].mean(0)
+
+            device = guidance.device
+            dq = guidance.dq
+
+            # Precompute alignment tensors (detached constants)
+            R_h_t = torch.from_numpy(R_h).float().to(device)
+            R_a_t = torch.from_numpy(R_a).float().to(device)
+            center_h_t = torch.from_numpy(center_h).float().to(device)
+            ref_center_h_t = torch.from_numpy(ref_center_h).float().to(device)
+            center_a_t = torch.from_numpy(center_a).float().to(device)
+            ref_center_a_t = torch.from_numpy(ref_center_a).float().to(device)
+
+            # Work in holo-aligned frame
+            coords_t = torch.from_numpy(aligned_holo.copy()).float().to(device)
+            mask_t = torch.from_numpy(binder_mask).bool().to(device)
+            aa_t = torch.from_numpy(aa_idx).long().to(device)
+
+            # Add noise for diversity (constant, small)
+            noise = torch.randn_like(coords_t) * 0.05
+            coords_t = coords_t + noise
+
+            best_margin = -float('inf')
+            best_coords = coords_t.clone()
+
+            def project_bond_lengths(coords, target_dist=3.8, n_iters=5):
+                """Project CA-CA distances to target_dist via SHAKE-like iteration."""
+                with torch.no_grad():
+                    for _ in range(n_iters):
+                        # CA positions are snapshotted once per sweep
+                        ca = coords[:, 1, :].clone()
+                        for i in range(len(ca) - 1):
+                            delta = ca[i+1] - ca[i]
+                            d = delta.norm()
+                            if d < 1e-6:
+                                continue
+                            # Shift both residues rigidly, each by half the length error
+                            correction = 0.5 * (d - target_dist) / d * delta
+                            coords[i, :, :] += correction.unsqueeze(0)
+                            coords[i+1, :, :] -= correction.unsqueeze(0)
+                return coords
+
+            for step in range(n_steps):
+                coords_t = coords_t.detach().requires_grad_(True)
+
+                with torch.enable_grad():
+                    q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
+                                      receptor_label='holo')
+
+                    # Transform holo-frame → original → apo-frame
+                    flat_t = coords_t.reshape(-1, 3)
+                    original = (flat_t - ref_center_h_t) @ R_h_t + center_h_t
+                    apo_aligned = (original - center_a_t) @ R_a_t.T + ref_center_a_t
+                    coords_apo = apo_aligned.reshape(-1, 4, 3)
+
+                    q_apo = dq.score(coords_apo, mask_t, binder_aa_idx=aa_t,
+                                     receptor_label='apo')
+                    margin = q_holo - q_apo
+                    margin.backward()
+
+                grad = coords_t.grad
+                if grad is None or torch.isnan(grad).any():
+                    continue
+
+                grad_norm = grad.norm().clamp(min=1e-8)
+
+                if margin.item() > best_margin:
+                    best_margin = margin.item()
+                    best_coords = coords_t.detach().clone()
+
+                if step % 10 == 0:
+                    logger.info(f" [{os.path.basename(pdb_path)}] Step {step}: "
+                                f"Q+={q_holo.item():.3f} Q-={q_apo.item():.3f} "
+                                f"S={margin.item():.3f} |g|={grad_norm.item():.4f}")
+
+                with torch.no_grad():
+                    # Normalized gradient ascent step on the margin
+                    coords_t = coords_t + step_size * grad / grad_norm
+                    # Annealed Langevin noise (small)
+                    noise_scale = step_size * 0.05 * (1 - step / n_steps)
+                    coords_t = coords_t + noise_scale * torch.randn_like(coords_t)
+                    # Hard projection: enforce CA-CA = 3.8A
+                    coords_t = project_bond_lengths(coords_t)
+
+            # Write refined backbone PDB.
+            # NOTE: only the binder is written (chain B, all residues as ALA),
+            # in the holo-aligned frame; the receptor chain is not included.
+            final_coords = best_coords.detach().cpu().numpy()
+            basename = os.path.basename(pdb_path).replace('.pdb', '')
+            out_path = os.path.join(outdir, f'{basename}_iter{iteration}.pdb')
+
+            # Fixed-column PDB ATOM records: 4-character atom names
+            atom_names = [' N  ', ' CA ', ' C  ', ' O  ']
+            elements = ['N', 'C', 'C', 'O']
+            with open(out_path, 'w') as f:
+                atom_num = 1
+                for i in range(len(final_coords)):
+                    if not binder_mask[i]:
+                        continue
+                    for j, (aname, elem) in enumerate(zip(atom_names, elements)):
+                        x, y, z = final_coords[i, j]
+                        f.write(f"ATOM  {atom_num:5d} {aname} ALA B{i+1:4d}    "
+                                f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  0.00          {elem:>2s}\n")
+                        atom_num += 1
+                f.write("END\n")
+
+            # Score refined design
+            result = guidance.score_design(out_path)
+            if result is not None:
+                result['pdb_path'] = out_path
+                result['iteration'] = iteration
+                result['best_margin_during_opt'] = best_margin
+                refined_results.append(result)
+                logger.info(f" -> Refined: S={result['margin']:.3f} "
+                            f"(best during opt: {best_margin:.3f})")
+
+        except Exception as e:
+            logger.warning(f"Failed to refine {pdb_path}: {e}")
+            import traceback
+            traceback.print_exc()
+
+    return refined_results
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_dir',
+                        default='results/pxdesign_guided/converted_pdbs')
+    parser.add_argument('--qtheta_checkpoint',
+                        default='results/checkpoints_cam_v3/best_phase2.pt')
+    parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
+    parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
+    parser.add_argument('--ref_chain', default='A')
+    parser.add_argument('--n_iterations', type=int, default=4,
+                        help='Number of refine cycles')
+    parser.add_argument('--n_designs', type=int, default=20,
+                        help='Number of designs to refine')
+    parser.add_argument('--n_steps', type=int, default=50,
+                        help='Langevin steps per iteration')
+    parser.add_argument('--step_size', type=float, default=0.005)
+    parser.add_argument('--gpu', type=int, default=6)
+    parser.add_argument('--outdir', default='results/iterative_refinement')
+    args = parser.parse_args()
+
+    # Relative paths below are resolved against the repository root
+    os.chdir(_ALLO_ROOT)
+
+    from scripts.pxdesign_guidance.qtheta_pxdesign import QThetaPXDesignGuidance
+
+    outdir = args.outdir
+    os.makedirs(outdir, exist_ok=True)
+
+    # Initialize scorer
+    guidance = QThetaPXDesignGuidance(
+        checkpoint=args.qtheta_checkpoint,
+        ref_holo=args.ref_holo,
+        ref_apo=args.ref_apo,
+        ref_chain=args.ref_chain,
+        device=f'cuda:{args.gpu}',
+    )
+    guidance._lazy_init()
+
+    # Collect input designs
+    input_pdbs = sorted(glob(os.path.join(args.input_dir, '*.pdb')))[:args.n_designs]
+    logger.info(f"Selected {len(input_pdbs)} designs for iterative refinement")
+
+    # Score initial designs
+    logger.info("Scoring initial designs...")
+    initial_results = score_designs(input_pdbs, guidance)
+    initial_margins = [r['margin'] for r in initial_results]
+    logger.info(f"Initial: S={np.mean(initial_margins):.3f}\u00b1{np.std(initial_margins):.3f}")
+
+    all_iteration_results = {'initial': initial_results}
+
+    # Iterative refinement
+    current_pdbs = input_pdbs
+    for iteration in range(args.n_iterations):
+        logger.info(f"\n{'='*50}")
+        logger.info(f"Iteration {iteration + 1}/{args.n_iterations}")
+        logger.info(f"{'='*50}")
+
+        iter_results = run_langevin_cycle(
+            current_pdbs, guidance,
+            n_steps=args.n_steps,
+            step_size=args.step_size,
+            iteration=iteration,
+            outdir=outdir,
+        )
+
+        if iter_results:
+            margins = [r['margin'] for r in iter_results]
+            # Log with the same 1-based index as the banner above
+            logger.info(f"Iteration {iteration + 1}: S={np.mean(margins):.3f}\u00b1{np.std(margins):.3f}")
+            all_iteration_results[f'iteration_{iteration}'] = iter_results
+
+            # Use refined designs as input for next iteration
+            current_pdbs = [r['pdb_path'] for r in iter_results]
+
+    # Summary
+    logger.info(f"\n{'='*60}")
+    logger.info("Iterative Refinement Summary")
+    logger.info(f"{'='*60}")
+    for key, results in all_iteration_results.items():
+        if results:
+            margins = [r['margin'] for r in results]
+            logger.info(f"{key:15s}: S={np.mean(margins):.3f}\u00b1{np.std(margins):.3f}, "
+                        f"N={len(results)}, S>0={100*np.mean([m>0 for m in margins]):.0f}%")
+
+    # Save results
+    out_path = os.path.join(outdir, 'iterative_refinement_summary.json')
+    summary = {}
+    for key, results in all_iteration_results.items():
+        if results:
+            margins = [r['margin'] for r in results]
+            summary[key] = {
+                'n': len(results),
+                'margin_mean': float(np.mean(margins)),
+                'margin_std': float(np.std(margins)),
+                'margin_max': float(np.max(margins)),
+                'frac_positive': float(np.mean([m > 0 for m in margins])),
+            }
+    with open(out_path, 'w') as f:
+        json.dump(summary, f, indent=2)
+    logger.info(f"\nSaved to {out_path}")
+
+
+if __name__ == '__main__':
+    main()
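The hard projection in `run_langevin_cycle` pulls consecutive CA atoms back toward 3.8 Å after every Langevin step. The sketch below shows the same idea on plain Python points, under two stated assumptions: it acts on bare points rather than full backbone residues, and it updates positions in place during a sweep (Gauss-Seidel style), whereas the script snapshots the CA positions once per sweep. `project_ca_distances` is a hypothetical name.

```python
import math


def project_ca_distances(ca, target=3.8, n_iters=50):
    """Iteratively pull consecutive points toward a fixed spacing (SHAKE-like)."""
    pts = [list(p) for p in ca]
    for _ in range(n_iters):
        for i in range(len(pts) - 1):
            delta = [b - a for a, b in zip(pts[i], pts[i + 1])]
            d = math.sqrt(sum(c * c for c in delta))
            if d < 1e-6:
                continue
            # Each endpoint absorbs half of the length error along the bond.
            shift = [0.5 * (d - target) / d * c for c in delta]
            pts[i] = [a + s for a, s in zip(pts[i], shift)]
            pts[i + 1] = [b - s for b, s in zip(pts[i + 1], shift)]
    return pts


# Three collinear points with spacings 3.0 and 4.0.
projected = project_ca_distances([(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (7.0, 0.0, 0.0)])
spacings = [math.dist(projected[i], projected[i + 1]) for i in range(2)]
print(spacings)
```

Because neighbouring constraints share an atom, a single sweep cannot satisfy all of them at once, which is why both this sketch and the script repeat the sweep several times.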
code/scripts/pxdesign_guidance/langevin_pxdesign.py ADDED
@@ -0,0 +1,374 @@
+"""
+PXDesign + Langevin Refinement.
+
+Post-hoc gradient ascent on existing PXDesign binder backbones using Q_theta
+selectivity gradient:
+    x_{t+1} = x_t + η · ∇_x[Q(holo,Y) - Q(apo,Y)] + √(2η) · ε
+
+Takes PXDesign outputs (which have full sidechains), extracts backbone coords,
+refines them via Langevin dynamics, and outputs refined backbone-only PDBs.
+
+Usage:
+    python code/scripts/pxdesign_guidance/langevin_pxdesign.py \
+        --designs_dir experiments/pxdesign_cam/output/ \
+        --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
+        --ref_holo data/pdbs/cam_holo/3CLN.pdb \
+        --ref_apo data/pdbs/cam_apo/1CFD.pdb \
+        --n_steps 100 --step_size 0.01 \
+        --gpu 0
+"""
+
+import os
+import sys
+import argparse
+import json
+import logging
+import numpy as np
+import torch
+from glob import glob
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+logger = logging.getLogger(__name__)
+
+_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
+_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
+_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
+
+if _ALLO_CODE_DIR not in sys.path:
+    sys.path.insert(0, _ALLO_CODE_DIR)
+
+from utils.pdb_utils import (
+    load_structure, get_residues, get_backbone_coords,
+    get_aa_indices, align_structures
+)
+
+
+ def write_backbone_pdb(coords, mask, out_path, chain='B'):
47
+ """Write backbone PDB (N, CA, C, O) from [N, 4, 3] numpy coords."""
48
+ atom_names = [' N ', ' CA ', ' C ', ' O ']
49
+ elements = ['N', 'C', 'C', 'O']
50
+ with open(out_path, 'w') as f:
51
+ atom_idx = 1
52
+ for i in range(len(coords)):
53
+ if not mask[i]:
54
+ continue
55
+ for j, (aname, elem) in enumerate(zip(atom_names, elements)):
56
+ x, y, z = coords[i, j, :]
57
+ f.write(
58
+ f"ATOM {atom_idx:5d} {aname:4s} ALA {chain}{i+1:4d} "
59
+ f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 {elem}\n"
60
+ )
61
+ atom_idx += 1
62
+ f.write("END\n")
63
+
64
+
65
+ def find_pxdesign_pdbs(designs_dir):
66
+ """Find all PXDesign output PDB files."""
67
+ pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
68
+ pdbs = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
69
+ or 'design' in os.path.basename(p).lower()
70
+ or 'rank' in os.path.basename(p).lower()]
71
+ if not pdbs:
72
+ pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
73
+ return pdbs
74
+
75
+
76
+ def langevin_refine(dq, binder_coords_init, binder_mask, binder_aa_idx,
77
+ rec_coords, rec_mask, ref_holo_ca, ref_apo_ca,
78
+ n_steps=100, step_size=0.01, noise_scale=0.0,
79
+ device='cuda:0'):
80
+ """
81
+ Langevin refinement of binder backbone coordinates.
82
+
83
+ Args:
84
+ dq: DifferentiableQTheta scorer
85
+ binder_coords_init: [N_binder, 4, 3] numpy — initial binder backbone
86
+ binder_mask: [N_binder] numpy bool
87
+ binder_aa_idx: [N_binder] numpy int
88
+ rec_coords: [N_rec, 4, 3] numpy — receptor backbone
89
+ rec_mask: [N_rec] numpy bool
90
+ ref_holo_ca: [N_ref, 3] torch — holo reference CA
91
+ ref_apo_ca: [N_ref, 3] torch — apo reference CA
92
+ n_steps: int
93
+ step_size: float (η)
94
+ noise_scale: float (for stochastic Langevin, 0 = gradient ascent)
95
+ device: str
96
+
97
+ Returns:
98
+ best_coords: [N_binder, 4, 3] numpy — refined coords
99
+ trajectory: list of dicts with step info; best_margin, best_q_holo and best_q_apo (floats) are returned after it
100
+ """
101
+ device = torch.device(device)
102
+
103
+ # Convert to tensors
104
+ x = torch.from_numpy(binder_coords_init.copy()).float().to(device)
105
+ mask_t = torch.from_numpy(binder_mask).bool().to(device)
106
+ aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
107
+ rec_ca = torch.from_numpy(rec_coords[:, 1, :]).float().to(device)
108
+
109
+ best_margin = -float('inf')
110
+ best_coords = binder_coords_init.copy()
111
+ best_q_holo = 0.0
112
+ best_q_apo = 0.0
113
+ trajectory = []
114
+
115
+ for step in range(n_steps):
116
+ x_grad = x.clone().requires_grad_(True)
117
+
118
+ try:
119
+ with torch.enable_grad():
120
+ # Align to holo reference
121
+ n_align_h = min(len(rec_ca), len(ref_holo_ca))
122
+ if n_align_h < 5:
123
+ break
124
+ from qtheta_pxdesign import differentiable_kabsch
125
+ R_h, t_h = differentiable_kabsch(rec_ca[:n_align_h].detach(),
126
+ ref_holo_ca[:n_align_h].detach())
127
+ R_h, t_h = R_h.detach(), t_h.detach()
128
+ aligned_holo = x_grad.reshape(-1, 3) @ R_h.T + t_h
129
+ aligned_holo = aligned_holo.reshape(-1, 4, 3)
130
+
131
+ q_holo = dq.score(aligned_holo, mask_t, binder_aa_idx=aa_t,
132
+ receptor_label='holo')
133
+
134
+ # Align to apo reference
135
+ n_align_a = min(len(rec_ca), len(ref_apo_ca))
136
+ R_a, t_a = differentiable_kabsch(rec_ca[:n_align_a].detach(),
137
+ ref_apo_ca[:n_align_a].detach())
138
+ R_a, t_a = R_a.detach(), t_a.detach()
139
+ aligned_apo = x_grad.reshape(-1, 3) @ R_a.T + t_a
140
+ aligned_apo = aligned_apo.reshape(-1, 4, 3)
141
+
142
+ q_apo = dq.score(aligned_apo, mask_t, binder_aa_idx=aa_t,
143
+ receptor_label='apo')
144
+
145
+ margin = q_holo - q_apo
146
+ margin.backward()
147
+
148
+ grad = x_grad.grad
149
+ if grad is None or torch.isnan(grad).any():
150
+ continue
151
+
152
+ # Gradient ascent step
153
+ x = x + step_size * grad
154
+
155
+ # Optional noise for stochastic Langevin
156
+ if noise_scale > 0:
157
+ x = x + noise_scale * np.sqrt(2 * step_size) * torch.randn_like(x)
158
+
159
+ current_margin = margin.item()
160
+ step_info = {
161
+ 'step': step,
162
+ 'q_holo': q_holo.item(),
163
+ 'q_apo': q_apo.item(),
164
+ 'margin': current_margin,
165
+ 'grad_norm': grad.norm().item(),
166
+ }
167
+ trajectory.append(step_info)
168
+
169
+ if current_margin > best_margin:
170
+ best_margin = current_margin
171
+ best_coords = x_grad.detach().cpu().numpy()  # pre-update coords that produced best_margin
172
+ best_q_holo = q_holo.item()
173
+ best_q_apo = q_apo.item()
174
+
175
+ if step % 20 == 0:
176
+ logger.info(
177
+ f" Step {step:3d}: Q+={q_holo.item():.3f} Q-={q_apo.item():.3f} "
178
+ f"S={current_margin:+.3f} |∇|={grad.norm().item():.4f}")
179
+
180
+ except Exception as e:
181
+ logger.debug(f" Step {step}: {e}")
182
+ continue
183
+
184
+ return best_coords, trajectory, best_margin, best_q_holo, best_q_apo
185
+
186
+
187
+ def main():
188
+ parser = argparse.ArgumentParser(description='PXDesign + Langevin Refinement')
189
+ parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/')
190
+ parser.add_argument('--qtheta_checkpoint',
191
+ default='results/checkpoints_cam_v3/best_phase2.pt')
192
+ parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
193
+ parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
194
+ parser.add_argument('--ref_chain', default='A')
195
+ parser.add_argument('--n_steps', type=int, default=100)
196
+ parser.add_argument('--step_size', type=float, default=0.01)
197
+ parser.add_argument('--noise_scale', type=float, default=0.0,
198
+ help='Noise scale for stochastic Langevin (0=gradient ascent)')
199
+ parser.add_argument('--gpu', type=int, default=0)
200
+ parser.add_argument('--outdir', default='results/pxdesign_langevin')
201
+ args = parser.parse_args()
202
+
203
+ os.chdir(_ALLO_ROOT)
204
+
205
+ device = f'cuda:{args.gpu}'
206
+
207
+ from models.differentiable_features import DifferentiableQTheta
208
+
209
+ # Load scorer
210
+ dq = DifferentiableQTheta(args.qtheta_checkpoint, device=device)
211
+ dq.load_receptor(args.ref_holo, chain=args.ref_chain, label='holo')
212
+ dq.load_receptor(args.ref_apo, chain=args.ref_chain, label='apo')
213
+
214
+ # Load reference CA coords
215
+ holo_model = load_structure(args.ref_holo)
216
+ holo_res = get_residues(holo_model[args.ref_chain])
217
+ holo_coords, _ = get_backbone_coords(holo_res)
218
+ ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(device)
219
+
220
+ apo_model = load_structure(args.ref_apo)
221
+ apo_res = get_residues(apo_model[args.ref_chain])
222
+ apo_coords, _ = get_backbone_coords(apo_res)
223
+ ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(device)
224
+
225
+ # Find designs
226
+ pdbs = find_pxdesign_pdbs(args.designs_dir)
227
+ logger.info(f"Found {len(pdbs)} PXDesign outputs to refine")
228
+
229
+ outdir = args.outdir
230
+ os.makedirs(outdir, exist_ok=True)
231
+
232
+ all_results = []
233
+ for i, pdb_path in enumerate(pdbs):
234
+ design_id = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
235
+ logger.info(f"\n[{i+1}/{len(pdbs)}] Refining {design_id}...")
236
+
237
+ try:
238
+ model = load_structure(pdb_path)
239
+ chains = {c.get_id(): c for c in model.get_chains()}
240
+ chain_ids = sorted(chains.keys())
241
+
242
+ # Identify chains
243
+ ref_len = len(holo_res)
244
+ rec_chain_id, binder_chain_id = None, None
245
+ for cid in chain_ids:
246
+ cres = get_residues(chains[cid])
247
+ if abs(len(cres) - ref_len) < ref_len * 0.3:
248
+ rec_chain_id = cid
249
+ else:
250
+ binder_chain_id = cid
251
+
252
+ if rec_chain_id is None or binder_chain_id is None:
253
+ if len(chain_ids) >= 2:
254
+ rec_chain_id, binder_chain_id = chain_ids[0], chain_ids[1]
255
+ else:
256
+ logger.warning(f"Skipping {design_id}: cannot identify chains")
257
+ continue
258
+
259
+ rec_res = get_residues(chains[rec_chain_id])
260
+ binder_res = get_residues(chains[binder_chain_id])
261
+
262
+ rec_coords_np, rec_mask = get_backbone_coords(rec_res)
263
+ binder_coords_np, binder_mask = get_backbone_coords(binder_res)
264
+ aa_idx = get_aa_indices(binder_res)
265
+
266
+ # Score before refinement
267
+ rec_ca = rec_coords_np[:, 1, :]
268
+ n_align = min(len(rec_ca), len(holo_coords[:, 1, :]))
269
+ _, R_h = align_structures(rec_ca[:n_align], holo_coords[:n_align, 1, :])
270
+ center_h = rec_ca[:n_align].mean(0)
271
+ ref_center_h = holo_coords[:n_align, 1, :].mean(0)
272
+
273
+ aligned_init = (binder_coords_np.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
274
+ aligned_init = aligned_init.reshape(-1, 4, 3)
275
+ with torch.no_grad():
276
+ q_h_init = dq.score(
277
+ torch.from_numpy(aligned_init).float().to(device),
278
+ torch.from_numpy(binder_mask).bool().to(device),
279
+ binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
280
+ receptor_label='holo').item()
281
+
282
+ n_align_a = min(len(rec_ca), len(apo_coords[:, 1, :]))
283
+ _, R_a = align_structures(rec_ca[:n_align_a], apo_coords[:n_align_a, 1, :])
284
+ center_a = rec_ca[:n_align_a].mean(0)
285
+ ref_center_a = apo_coords[:n_align_a, 1, :].mean(0)
286
+ aligned_init_a = (binder_coords_np.reshape(-1, 3) - center_a) @ R_a.T + ref_center_a
287
+ aligned_init_a = aligned_init_a.reshape(-1, 4, 3)
288
+ with torch.no_grad():
289
+ q_a_init = dq.score(
290
+ torch.from_numpy(aligned_init_a).float().to(device),
291
+ torch.from_numpy(binder_mask).bool().to(device),
292
+ binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
293
+ receptor_label='apo').item()
294
+
295
+ margin_init = q_h_init - q_a_init
296
+
297
+ # Run Langevin refinement
298
+ refined_coords, trajectory, best_margin, best_qh, best_qa = langevin_refine(
299
+ dq, binder_coords_np, binder_mask, aa_idx,
300
+ rec_coords_np, rec_mask, ref_holo_ca, ref_apo_ca,
301
+ n_steps=args.n_steps, step_size=args.step_size,
302
+ noise_scale=args.noise_scale, device=device,
303
+ )
304
+
305
+ # Use best-margin values (matching the saved best_coords PDB)
306
+ margin_final = best_margin if trajectory else margin_init
307
+
308
+ # Save refined PDB
309
+ out_pdb = os.path.join(outdir, f'{design_id}_refined.pdb')
310
+ write_backbone_pdb(refined_coords, binder_mask, out_pdb)
311
+
312
+ result = {
313
+ 'design_id': design_id,
314
+ 'pdb_path': pdb_path,
315
+ 'refined_pdb': out_pdb,
316
+ 'q_holo_init': q_h_init,
317
+ 'q_apo_init': q_a_init,
318
+ 'margin_init': margin_init,
319
+ 'q_holo_final': best_qh if trajectory else q_h_init,
320
+ 'q_apo_final': best_qa if trajectory else q_a_init,
321
+ 'margin_final': margin_final,
322
+ 'margin_delta': margin_final - margin_init,
323
+ 'n_steps_converged': len(trajectory),
324
+ 'n_res': len(binder_res),
325
+ }
326
+ all_results.append(result)
327
+
328
+ logger.info(
329
+ f" {design_id}: S_init={margin_init:+.3f} -> S_final={margin_final:+.3f} "
330
+ f"(Δ={margin_final - margin_init:+.3f})")
331
+
332
+ except Exception as e:
333
+ logger.warning(f"Failed to refine {design_id}: {e}")
334
+ continue
335
+
336
+ # Summary
337
+ if all_results:
338
+ all_results.sort(key=lambda x: x['margin_final'], reverse=True)
339
+ margins_init = np.array([r['margin_init'] for r in all_results])
340
+ margins_final = np.array([r['margin_final'] for r in all_results])
341
+ deltas = margins_final - margins_init
342
+
343
+ summary = {
344
+ 'method': 'PXDesign + Langevin',
345
+ 'n_designs': len(all_results),
346
+ 'n_steps': args.n_steps,
347
+ 'step_size': args.step_size,
348
+ 'margin_init_mean': float(margins_init.mean()),
349
+ 'margin_final_mean': float(margins_final.mean()),
350
+ 'margin_delta_mean': float(deltas.mean()),
351
+ 'frac_improved': float((deltas > 0).mean()),
352
+ 'frac_positive_init': float((margins_init > 0).mean()),
353
+ 'frac_positive_final': float((margins_final > 0).mean()),
354
+ 'q_holo_final_mean': float(np.mean([r['q_holo_final'] for r in all_results])),
355
+ }
356
+
357
+ with open(os.path.join(outdir, 'langevin_scores.json'), 'w') as f:
358
+ json.dump(all_results, f, indent=2)
359
+ with open(os.path.join(outdir, 'langevin_summary.json'), 'w') as f:
360
+ json.dump(summary, f, indent=2)
361
+
362
+ logger.info(f"\n{'='*60}")
363
+ logger.info(f"PXDesign + Langevin Results ({len(all_results)} designs)")
364
+ logger.info(f" Margin init: {margins_init.mean():.3f} ± {margins_init.std():.3f}")
365
+ logger.info(f" Margin final: {margins_final.mean():.3f} ± {margins_final.std():.3f}")
366
+ logger.info(f" Δ margin: {deltas.mean():+.3f} ± {deltas.std():.3f}")
367
+ logger.info(f" % improved: {(deltas > 0).mean():.1%}")
368
+ logger.info(f" S>0 init/final: {(margins_init > 0).mean():.1%} / "
369
+ f"{(margins_final > 0).mean():.1%}")
370
+ logger.info(f"{'='*60}")
371
+
372
+
373
+ if __name__ == '__main__':
374
+ main()
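Review note on langevin_pxdesign.py: the refinement loop scores the current coordinates, takes a gradient step, and then records the best state. The sketch below is a minimal NumPy illustration of that update, assuming a toy concave objective in place of the Q_theta margin. `langevin_ascent` and `toy_margin` are hypothetical names used only for this illustration and are not part of the repository. It records the coordinates that were actually scored, so the returned best state and best score always refer to the same structure.

```python
import numpy as np


def langevin_ascent(score_and_grad, x0, n_steps=100, step_size=0.01,
                    noise_scale=0.0, seed=0):
    """Toy Langevin ascent that returns the best *scored* state."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    best_x, best_score = x.copy(), -np.inf
    for _ in range(n_steps):
        score, grad = score_and_grad(x)
        # Record the state that was scored, before it is updated.
        if score > best_score:
            best_score, best_x = score, x.copy()
        # Gradient ascent step.
        x = x + step_size * grad
        # Optional noise; noise_scale = 0 reduces to plain gradient ascent.
        if noise_scale > 0:
            x = x + noise_scale * np.sqrt(2 * step_size) * rng.standard_normal(x.shape)
    return best_x, best_score


def toy_margin(x):
    """Hypothetical stand-in for Q(holo) - Q(apo): concave, maximum at x = 1."""
    diff = x - 1.0
    return -float(np.sum(diff ** 2)), -2.0 * diff


best_x, best_score = langevin_ascent(toy_margin, np.zeros(3),
                                     n_steps=200, step_size=0.05)
```

Re-scoring `best_x` with `toy_margin` reproduces `best_score`, which is the property the saved PDB and the reported `margin_final` are meant to share.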
code/scripts/pxdesign_guidance/qtheta_pxdesign.py ADDED
@@ -0,0 +1,477 @@
1
+ """
2
+ Core Q_theta guidance module for PXDesign integration.
3
+
4
+ Provides differentiable Q_theta scoring for PXDesign's atom coordinate format.
5
+ Key responsibilities:
6
+ - Extract binder backbone (N, CA, C, O) from PXDesign's flat atom array
7
+ - Align binder to reference receptor frames via differentiable Kabsch
8
+ - Compute selectivity gradient ∇[Q(holo,Y) - Q(apo,Y)] w.r.t. atom coords
9
+ - Works in pxdesign env (PyTorch 2.3.1) using pure-PyTorch scorer (no e3nn)
10
+
11
+ Usage:
12
+ guidance = QThetaPXDesignGuidance(
13
+ checkpoint='results/checkpoints_cam_v3/best_phase2.pt',
14
+ ref_holo='data/pdbs/cam_holo/3CLN.pdb',
15
+ ref_apo='data/pdbs/cam_apo/1CFD.pdb',
16
+ ref_chain='A',
17
+ device='cuda:0',
18
+ )
19
+ # Inside PXDesign diffusion loop:
20
+ grad, margin = guidance.compute_guidance_gradient(x_denoised, input_feature_dict, t_hat)
21
+ x_denoised = x_denoised + scale * grad
22
+ """
23
+
24
+ import os
25
+ import sys
26
+ import logging
27
+ import numpy as np
28
+ import torch
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Add Allo-Designer code directory to path
33
+ _ALLO_CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
34
+ if _ALLO_CODE_DIR not in sys.path:
35
+ sys.path.insert(0, _ALLO_CODE_DIR)
36
+
37
+
38
+ def differentiable_kabsch(mobile, target):
39
+ """
40
+ Differentiable Kabsch alignment using SVD.
41
+
42
+ Args:
43
+ mobile: [N, 3] tensor (points to align FROM)
44
+ target: [N, 3] tensor (points to align TO)
45
+
46
+ Returns:
47
+ R: [3, 3] rotation matrix
48
+ t: [3] translation vector
49
+ Such that aligned = (mobile - mobile_center) @ R.T + target_center
50
+ """
51
+ mobile_center = mobile.mean(dim=0)
52
+ target_center = target.mean(dim=0)
53
+
54
+ mobile_centered = mobile - mobile_center
55
+ target_centered = target - target_center
56
+
57
+ H = mobile_centered.T @ target_centered # [3, 3]
58
+ U, S, Vh = torch.linalg.svd(H)
59
+ V = Vh.T
60
+
61
+ # Ensure proper rotation (det > 0)
62
+ d = torch.det(V @ U.T)
63
+ sign_matrix = torch.diag(torch.tensor([1.0, 1.0, torch.sign(d)],
64
+ device=mobile.device, dtype=mobile.dtype))
65
+ R = V @ sign_matrix @ U.T # [3, 3]
66
+ t = target_center - mobile_center @ R.T # [3]
67
+
68
+ return R, t
69
+
70
+
71
+ class QThetaPXDesignGuidance:
72
+ """
73
+ Q_theta guidance for PXDesign diffusion process.
74
+
75
+ Lazily initializes the scorer and reference structures on first use.
76
+ Handles extraction of binder backbone from PXDesign's flat atom array
77
+ and alignment to reference receptor frames.
78
+ """
79
+
80
+ def __init__(self, checkpoint, ref_holo, ref_apo, ref_chain='A',
81
+ device='cuda:0', cutoff=8.0, esm_target='cam'):
82
+ self.checkpoint = checkpoint
83
+ self.ref_holo = ref_holo
84
+ self.ref_apo = ref_apo
85
+ self.ref_chain = ref_chain
86
+ self.device = torch.device(device)
87
+ self.cutoff = cutoff
88
+ self.esm_target = esm_target
89
+
90
+ self._initialized = False
91
+ self.dq = None
92
+ self.ref_holo_ca = None
93
+ self.ref_apo_ca = None
94
+
95
+ def _lazy_init(self):
96
+ """Initialize Q_theta scorer and load reference structures."""
97
+ if self._initialized:
98
+ return
99
+
100
+ from models.differentiable_features import DifferentiableQTheta
101
+ from utils.pdb_utils import load_structure, get_residues, get_backbone_coords
102
+
103
+ logger.info(f"Loading Q_theta checkpoint: {self.checkpoint}")
104
+ self.dq = DifferentiableQTheta(self.checkpoint, device=str(self.device))
105
+ self.dq.load_receptor(self.ref_holo, chain=self.ref_chain, label='holo',
106
+ esm_target=self.esm_target)
107
+ self.dq.load_receptor(self.ref_apo, chain=self.ref_chain, label='apo',
108
+ esm_target=self.esm_target)
109
+
110
+ # Cache reference CA coords for alignment
111
+ holo_model = load_structure(self.ref_holo)
112
+ holo_res = get_residues(holo_model[self.ref_chain])
113
+ holo_coords, _ = get_backbone_coords(holo_res)
114
+ self.ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(self.device)
115
+
116
+ apo_model = load_structure(self.ref_apo)
117
+ apo_res = get_residues(apo_model[self.ref_chain])
118
+ apo_coords, _ = get_backbone_coords(apo_res)
119
+ self.ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(self.device)
120
+
121
+ self._initialized = True
122
+ logger.info(f"Q_theta guidance initialized: holo={len(holo_res)} res, apo={len(apo_res)} res")
123
+
124
+ def extract_binder_backbone(self, x_coords, input_feature_dict):
125
+ """
126
+ Extract binder backbone atoms (N, CA, C, O) from PXDesign's flat atom array.
127
+
128
+ PXDesign stores all atoms in a flat [N_atom, 3] array. Entity annotations
129
+ identify which atoms belong to the designed binder (entity_id=2 typically,
130
+ or the last entity). We extract backbone atoms for each binder residue.
131
+
132
+ Args:
133
+ x_coords: [N_sample, N_atom, 3] — current coordinates from diffusion
134
+ input_feature_dict: dict with atom_to_token_idx, entity_id, etc.
135
+
136
+ Returns (as a dict keyed by these names, or None if no binder backbone is found):
137
+ binder_bb: [N_sample, N_binder_res, 4, 3] — backbone coords (N, CA, C, O)
138
+ binder_mask: [N_binder_res] — validity mask
139
+ rec_bb: [N_rec_res, 4, 3] — receptor backbone coords (from condition)
140
+ rec_mask: [N_rec_res] — receptor validity mask
141
+ binder_atom_indices: [N_binder_bb_atoms] — indices into flat atom array
142
+ """
143
+ atom_to_token = input_feature_dict['atom_to_token_idx'] # [N_atom]
144
+ if atom_to_token.dim() > 1:
145
+ atom_to_token = atom_to_token.squeeze(0)
146
+
147
+ # Identify binder vs receptor tokens
148
+ # In PXDesign: design_token_mask=True for binder tokens
149
+ design_token_mask = input_feature_dict.get('design_token_mask', None)
150
+ if design_token_mask is not None:
151
+ if design_token_mask.dim() > 1:
152
+ design_token_mask = design_token_mask.squeeze(0)
153
+ binder_tokens = torch.where(design_token_mask)[0]
154
+ rec_tokens = torch.where(~design_token_mask)[0]
155
+ else:
156
+ # Fallback: use entity_id (binder is typically entity_id=2, the last entity)
157
+ entity_id = input_feature_dict['entity_id']
158
+ if entity_id.dim() > 1:
159
+ entity_id = entity_id.squeeze(0)
160
+ max_entity = entity_id.max()
161
+ binder_tokens = torch.where(entity_id == max_entity)[0]
162
+ rec_tokens = torch.where(entity_id != max_entity)[0]
163
+
164
+ # Map tokens to atoms
165
+ # For standard amino acids, atom order within each token is:
166
+ # N(0), CA(1), C(2), O(3), CB(4), ...
167
+ # We need atoms 0-3 (N, CA, C, O) per token
168
+
169
+ # Get atom indices for each binder token
170
+ n_binder_res = len(binder_tokens)
171
+ if n_binder_res == 0:
172
+ return None
173
+
174
+ # Find atoms belonging to each binder residue
175
+ binder_bb_list = []
176
+ binder_atom_idx_list = []
177
+ for tok_idx in binder_tokens:
178
+ atom_indices = torch.where(atom_to_token == tok_idx.item())[0]
179
+ if len(atom_indices) >= 4:
180
+ # First 4 atoms are N, CA, C, O for standard amino acids
181
+ bb_atoms = atom_indices[:4]
182
+ binder_bb_list.append(bb_atoms)
183
+ binder_atom_idx_list.append(bb_atoms)
184
+
185
+ if not binder_bb_list:
186
+ return None
187
+
188
+ n_binder_res = len(binder_bb_list)
189
+ binder_bb_indices = torch.stack(binder_bb_list) # [N_binder, 4]
190
+ all_binder_atom_indices = torch.cat(binder_atom_idx_list) # [N_binder * 4]
191
+
192
+ # Extract binder backbone coords for all samples
193
+ # x_coords: [N_sample, N_atom, 3]
194
+ binder_bb = x_coords[:, binder_bb_indices, :] # [N_sample, N_binder, 4, 3]
195
+ binder_mask = torch.ones(n_binder_res, dtype=torch.bool, device=x_coords.device)
196
+
197
+ # Extract receptor backbone from condition_coordinate when available, else x_coords.
198
+ # PXDesign normally keeps condition_coordinate in label_dict rather than
199
+ # input_feature_dict, so both locations are checked below. If neither has it,
200
+ # the receptor backbone is read from x_coords, where receptor atoms are
201
+ # conditioned at their reference positions during diffusion.
202
+ cond_coords = input_feature_dict.get('condition_coordinate', None)
203
+ if cond_coords is None:
204
+ # Also try label_dict nesting
205
+ label_dict = input_feature_dict.get('label_dict', None)
206
+ if label_dict is not None:
207
+ cond_coords = label_dict.get('condition_coordinate', None)
208
+
209
+ rec_bb = None
210
+ rec_mask = None
211
+
212
+ # Get receptor backbone atoms
213
+ rec_bb_list = []
214
+ for tok_idx in rec_tokens:
215
+ atom_indices = torch.where(atom_to_token == tok_idx.item())[0]
216
+ if len(atom_indices) >= 4:
217
+ rec_bb_list.append(atom_indices[:4])
218
+
219
+ if rec_bb_list:
220
+ rec_bb_indices = torch.stack(rec_bb_list) # [N_rec, 4]
221
+
222
+ if cond_coords is not None:
223
+ if cond_coords.dim() > 2:
224
+ cond_coords = cond_coords.squeeze(0)
225
+ rec_bb = cond_coords[rec_bb_indices, :] # [N_rec, 4, 3]
226
+ else:
227
+ # Fallback: extract receptor coords from x_coords (sample 0)
228
+ # Receptor atoms are conditioned and constant across samples
229
+ rec_bb = x_coords[0, rec_bb_indices, :].detach() # [N_rec, 4, 3]
230
+
231
+ rec_mask = torch.ones(len(rec_bb_list), dtype=torch.bool,
232
+ device=x_coords.device)
233
+
234
+ return {
235
+ 'binder_bb': binder_bb, # [N_sample, N_binder, 4, 3]
236
+ 'binder_mask': binder_mask, # [N_binder]
237
+ 'rec_bb': rec_bb, # [N_rec, 4, 3] or None
238
+ 'rec_mask': rec_mask, # [N_rec] or None
239
+ 'binder_atom_indices': binder_bb_indices, # [N_binder, 4]
240
+ 'all_binder_atom_indices': all_binder_atom_indices, # [N_binder * 4]
241
+ }
242
+
243
+ def align_and_score(self, binder_bb, rec_bb, rec_mask, receptor_label):
244
+ """
245
+ Align binder to a reference receptor frame and score with Q_theta.
246
+
247
+ Uses the receptor chain from the design to compute Kabsch alignment
248
+ to the reference receptor, then transforms the binder accordingly.
249
+
250
+ Args:
251
+ binder_bb: [N_binder, 4, 3] — binder backbone coords (requires_grad)
252
+ rec_bb: [N_rec, 4, 3] — receptor backbone coords
253
+ rec_mask: [N_rec] bool
254
+ receptor_label: 'holo' or 'apo'
255
+
256
+ Returns:
257
+ score: scalar tensor, differentiable w.r.t. binder_bb
258
+ """
259
+ if receptor_label == 'holo':
260
+ ref_ca = self.ref_holo_ca
261
+ else:
262
+ ref_ca = self.ref_apo_ca
263
+
264
+ # Get CA atoms from receptor
265
+ rec_ca = rec_bb[:, 1, :] # [N_rec, 3]
266
+
267
+ # Use overlapping residues for alignment (take min length)
268
+ n_align = min(len(rec_ca), len(ref_ca))
269
+ if n_align < 5:
270
+ return torch.zeros(1, device=binder_bb.device, dtype=binder_bb.dtype,
271
+ requires_grad=True).squeeze()
272
+
273
+ mobile_ca = rec_ca[:n_align].detach()
274
+ target_ca = ref_ca[:n_align].detach()
275
+
276
+ # Compute Kabsch alignment (detached — no gradient through rotation)
277
+ R, t = differentiable_kabsch(mobile_ca, target_ca)
278
+ R = R.detach()
279
+ t = t.detach()
280
+
281
+ # Apply transform to binder (gradient flows through binder_bb)
282
+ binder_flat = binder_bb.reshape(-1, 3) # [N_binder*4, 3]
283
+ aligned = binder_flat @ R.T + t # [N_binder*4, 3]
284
+ aligned_bb = aligned.reshape(-1, 4, 3) # [N_binder, 4, 3]
285
+
286
+ # Score with Q_theta
287
+ binder_mask = torch.ones(aligned_bb.shape[0], dtype=torch.bool,
288
+ device=binder_bb.device)
289
+ score = self.dq.score(aligned_bb, binder_mask, receptor_label=receptor_label,
290
+ cutoff=self.cutoff)
291
+ return score
292
+
293
+ def compute_guidance_gradient(self, x_denoised, input_feature_dict, t_hat=None,
294
+ sample_idx=0):
295
+ """
296
+ Compute Q_theta selectivity gradient for guidance.
297
+
298
+ Args:
299
+ x_denoised: [N_sample, N_atom, 3] — denoised coordinates from diffusion net
300
+ input_feature_dict: PXDesign input features dict
301
+ t_hat: current noise level (for logging/scaling)
302
+ sample_idx: which sample to compute gradient for (or -1 for all)
303
+
304
+ Returns:
305
+ gradient: [N_sample, N_atom, 3] — gradient to add to x_denoised
306
+ (non-zero only at binder backbone atom positions)
307
+ margin: float, mean selectivity margin over the processed samples
308
+ """
309
+ self._lazy_init()
310
+
311
+ extraction = self.extract_binder_backbone(x_denoised.detach(), input_feature_dict)
312
+ if extraction is None:
313
+ return torch.zeros_like(x_denoised), 0.0
314
+
315
+ binder_bb = extraction['binder_bb'] # [N_sample, N_binder, 4, 3]
316
+ binder_mask = extraction['binder_mask'] # [N_binder]
317
+ rec_bb = extraction['rec_bb'] # [N_rec, 4, 3]
318
+ rec_mask = extraction['rec_mask'] # [N_rec]
319
+ binder_atom_indices = extraction['binder_atom_indices'] # [N_binder, 4]
320
+
321
+ if rec_bb is None:
322
+ return torch.zeros_like(x_denoised), 0.0
323
+
324
+ N_sample = x_denoised.shape[0]
325
+ gradient = torch.zeros_like(x_denoised)
326
+ margins = []
327
+
328
+ # Ensure receptor is float32 for Q_theta scoring
329
+ if rec_bb is not None:
330
+ rec_bb = rec_bb.float()
331
+
332
+ # Process each sample
333
+ indices = range(N_sample) if sample_idx == -1 else [sample_idx]
334
+ for si in indices:
335
+ # Make binder coords differentiable, cast to float32 for Q_theta
336
+ binder_si = binder_bb[si].clone().float().requires_grad_(True) # [N_binder, 4, 3]
337
+
338
+ try:
339
+ with torch.enable_grad():
340
+ q_holo = self.align_and_score(binder_si, rec_bb, rec_mask, 'holo')
341
+ q_apo = self.align_and_score(binder_si, rec_bb, rec_mask, 'apo')
342
+ margin = q_holo - q_apo
343
+ margin.backward()
344
+
345
+ if binder_si.grad is not None and not torch.isnan(binder_si.grad).any():
346
+ # Map gradient back to full atom array
347
+ grad_bb = binder_si.grad # [N_binder, 4, 3]
348
+ for ri in range(len(binder_atom_indices)):
349
+ for ai in range(4):
350
+ atom_idx = binder_atom_indices[ri, ai]
351
+ gradient[si, atom_idx] = grad_bb[ri, ai]
352
+ margins.append(margin.item())
353
+ else:
354
+ margins.append(0.0)
355
+ except Exception as e:
356
+ logger.debug(f"Gradient computation failed for sample {si}: {e}")
357
+ margins.append(0.0)
358
+
359
+ avg_margin = np.mean(margins) if margins else 0.0
360
+ return gradient, avg_margin
361
+
362
+ def score_design(self, pdb_path, rec_chain='A', binder_chain='B'):
363
+ """
364
+ Score a single PXDesign output PDB/CIF (post-hoc, no gradient).
365
+
366
+ Handles PXDesign CIF files which use chain IDs like 'A0'/'B0' and
367
+ non-standard residue name 'xpb' for designed binder residues.
368
+
369
+ Returns:
370
+ dict with q_holo, q_apo, margin, or None on failure
371
+ """
372
+ self._lazy_init()
373
+
374
+ from utils.pdb_utils import (
375
+ load_structure, get_residues, get_backbone_coords,
376
+ get_aa_indices, align_structures
377
+ )
378
+
379
+ try:
380
+ model = load_structure(pdb_path)
381
+ chains = {c.get_id(): c for c in model.get_chains()}
382
+
383
+ if len(chains) < 2:
384
+ return None
385
+
386
+ chain_ids = sorted(chains.keys())
387
+
388
+ # Identify receptor and binder
389
+ # PXDesign CIF uses chain IDs like 'A0', 'B0' instead of 'A', 'B'
390
+ rc, bc = None, None
391
+ if rec_chain in chains and binder_chain in chains:
392
+ rc, bc = rec_chain, binder_chain
393
+ else:
394
+ # Match by residue count: receptor matches reference length,
395
+ # binder is the other chain
396
+ ref_model = load_structure(self.ref_holo)
397
+ ref_res = get_residues(ref_model[self.ref_chain])
398
+ ref_len = len(ref_res)
399
+ for cid in chain_ids:
400
+ # Try standard residues first, then all residues
401
+ cres = get_residues(chains[cid])
402
+ if not cres:
403
+ cres = get_residues(chains[cid], only_standard=False)
404
+ n_res = len(cres)
405
+ if n_res > 0 and abs(n_res - ref_len) < ref_len * 0.3:
406
+ rc = cid
407
+ elif n_res > 0:
408
+ bc = cid
409
+ if rc is None or bc is None:
410
+ rc, bc = chain_ids[0], chain_ids[1]
411
+
412
+ rec_res = get_residues(chains[rc])
413
+ if not rec_res:
414
+ rec_res = get_residues(chains[rc], only_standard=False)
415
+
416
+ # For binder: PXDesign uses 'xpb' residue names (non-standard)
417
+ binder_res = get_residues(chains[bc])
418
+ if not binder_res:
419
+ binder_res = get_residues(chains[bc], only_standard=False)
420
+
421
+ if not rec_res or not binder_res:
422
+ return None
423
+
424
+ rec_coords, rec_mask = get_backbone_coords(rec_res)
425
+ binder_coords, binder_mask = get_backbone_coords(binder_res)
426
+
427
+ # Handle amino acid indices: use get_aa_indices for standard AAs,
428
+ # default to index 0 (ALA) for non-standard (PXDesign 'xpb')
429
+ try:
430
+ aa_idx = get_aa_indices(binder_res)
431
+ except Exception:
432
+ aa_idx = np.zeros(len(binder_res), dtype=np.int64) # default to ALA
433
+
434
+ device = self.device
435
+
436
+ # Align to holo
437
+ rec_ca = rec_coords[:, 1, :]
438
+ ref_holo_ca_np = self.ref_holo_ca.cpu().numpy()
439
+ n_align = min(len(rec_ca), len(ref_holo_ca_np))
440
+ if n_align < 5:
441
+ return None
442
+ _, R_h = align_structures(rec_ca[:n_align], ref_holo_ca_np[:n_align])
443
+ center_h = rec_ca[:n_align].mean(0)
444
+ ref_center_h = ref_holo_ca_np[:n_align].mean(0)
445
+ aligned_holo = (binder_coords.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
446
+ aligned_holo = aligned_holo.reshape(-1, 4, 3)
447
+
448
+ # Align to apo
449
+ ref_apo_ca_np = self.ref_apo_ca.cpu().numpy()
450
+ n_align_a = min(len(rec_ca), len(ref_apo_ca_np))
451
+ _, R_a = align_structures(rec_ca[:n_align_a], ref_apo_ca_np[:n_align_a])
452
+ center_a = rec_ca[:n_align_a].mean(0)
453
+ ref_center_a = ref_apo_ca_np[:n_align_a].mean(0)
454
+ aligned_apo = (binder_coords.reshape(-1, 3) - center_a) @ R_a.T + ref_center_a
455
+ aligned_apo = aligned_apo.reshape(-1, 4, 3)
456
+
457
+ with torch.no_grad():
458
+ coords_h = torch.from_numpy(aligned_holo).float().to(device)
459
+ coords_a = torch.from_numpy(aligned_apo).float().to(device)
460
+ mask_t = torch.from_numpy(binder_mask).bool().to(device)
461
+ aa_t = torch.from_numpy(aa_idx).long().to(device)
462
+
463
+ q_holo = self.dq.score(coords_h, mask_t, binder_aa_idx=aa_t,
464
+ receptor_label='holo').item()
465
+ q_apo = self.dq.score(coords_a, mask_t, binder_aa_idx=aa_t,
466
+ receptor_label='apo').item()
467
+
468
+ return {
469
+ 'q_holo': q_holo,
470
+ 'q_apo': q_apo,
471
+ 'margin': q_holo - q_apo,
472
+ 'n_res': len(binder_res),
473
+ }
474
+
475
+ except Exception as e:
476
+ logger.warning(f"Error scoring {pdb_path}: {e}")
477
+ return None
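Reviewer note: the scoring path above superimposes the receptor Cα atoms onto each reference and then carries the binder backbone along with the same rigid transform. The sketch below reproduces that step with a plain NumPy Kabsch fit so the geometry can be checked in isolation. `kabsch_rotation` and `transfer_binder` are hypothetical names; the assumption that `align_structures` returns a rotation applied as `x @ R.T` is inferred from the call sites, not from its source.

```python
import numpy as np


def kabsch_rotation(mobile, target):
    """Return R so that (mobile - centroid) @ R.T best fits (target - centroid)."""
    P = mobile - mobile.mean(axis=0)
    Q = target - target.mean(axis=0)
    H = P.T @ Q                                # 3x3 cross-covariance matrix
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))     # -1 would mean a reflection
    D = np.diag([1.0, 1.0, d])                 # force a proper rotation
    return Vt.T @ D @ U.T


def transfer_binder(binder_xyz, rec_ca, ref_ca):
    """Move binder backbone atoms, shape (N, 4, 3), into the reference frame.

    rec_ca and ref_ca are matched Ca coordinates, shape (M, 3) each.
    """
    R = kabsch_rotation(rec_ca, ref_ca)
    flat = binder_xyz.reshape(-1, 3) - rec_ca.mean(axis=0)
    return (flat @ R.T + ref_ca.mean(axis=0)).reshape(-1, 4, 3)
```

Truncating both Cα arrays to a common length, as the script does, is only meaningful when the two chains share residue ordering; matching by residue number is the more conservative choice when numbering is available.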
code/scripts/pxdesign_guidance/smc_pxdesign.py ADDED
@@ -0,0 +1,262 @@
1
+ """
2
+ PXDesign + SMC Reranking.
3
+
4
+ Post-hoc Sequential Monte Carlo: generate multiple batches of vanilla PXDesign
5
+ binders, score with Q_theta, and rank by selectivity margin. No modification
6
+ to the PXDesign diffusion process: this is a pure generate-score-rank pipeline.
7
+
8
+ This is the simplest Q_theta integration strategy: generate a large pool of
9
+ candidates and select the best ones by selectivity score.
10
+
11
+ Usage:
12
+ python code/scripts/pxdesign_guidance/smc_pxdesign.py \
13
+ --input experiments/pxdesign_cam/output/cam_binder.json \
14
+ --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
15
+ --ref_holo data/pdbs/cam_holo/3CLN.pdb \
16
+ --ref_apo data/pdbs/cam_apo/1CFD.pdb \
17
+ --n_particles 16 --n_rounds 4 \
18
+ --gpu 0
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import argparse
24
+ import json
25
+ import logging
26
+ import shutil
27
+ import subprocess
28
+ from glob import glob
29
+
30
+ import numpy as np
31
+ import torch
32
+
33
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
34
+ logger = logging.getLogger(__name__)
35
+
36
+ _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
37
+ _ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
38
+ _ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
39
+
40
+ if _ALLO_CODE_DIR not in sys.path:
41
+ sys.path.insert(0, _ALLO_CODE_DIR)
42
+
43
+
44
+ def run_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
45
+ """Run vanilla PXDesign via CLI subprocess."""
46
+ pxdesign_python = 'python'
47
+
48
+ # Use pxdesign CLI
49
+ cmd = [
50
+ pxdesign_python, '-m', 'pxdesign.runner.cli', 'infer',
51
+ '-o', outdir,
52
+ '-i', input_json,
53
+ '--dtype', 'bf16',
54
+ '--N_sample', str(n_sample),
55
+ '--N_step', str(n_step),
56
+ ]
57
+
58
+ env = os.environ.copy()
59
+ # Inherit CUDA_VISIBLE_DEVICES from parent
60
+
61
+ logger.info(f"Running PXDesign: {n_sample} samples -> {outdir}")
62
+ result = subprocess.run(cmd, capture_output=True, text=True, env=env,
63
+ timeout=7200)
64
+
65
+ if result.returncode != 0:
66
+ # Try alternative invocation via module
67
+ cmd_alt = [
68
+ pxdesign_python, '-m', 'pxdesign.runner.inference',
69
+ '--dump_dir', outdir,
70
+ '--input', input_json,
71
+ '--dtype', 'bf16',
72
+ '--N_sample', str(n_sample),
73
+ '--N_step', str(n_step),
74
+ ]
75
+ result = subprocess.run(cmd_alt, capture_output=True, text=True, env=env,
76
+ timeout=7200)
77
+ if result.returncode != 0:
78
+ logger.error(f"PXDesign failed:\nstdout: {result.stdout[-1000:]}\nstderr: {result.stderr[-1000:]}")
79
+ return False
80
+ return True
81
+
82
+
83
+ def collect_pdbs(outdir):
84
+ """Collect PDB/CIF files from PXDesign output."""
85
+ pdbs = []
86
+ for ext in ('*.pdb', '*.cif'):
87
+ pdbs.extend(glob(os.path.join(outdir, '**/' + ext), recursive=True))
88
+ pdbs = sorted(pdbs)
89
+ filtered = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
90
+ or 'design' in os.path.basename(p).lower()
91
+ or 'rank' in os.path.basename(p).lower()]
92
+ return filtered if filtered else pdbs
93
+
94
+
95
+ def smc_particle_filter(args):
96
+ """Run SMC reranking with PXDesign."""
97
+ os.chdir(_ALLO_ROOT)
98
+
99
+ from qtheta_pxdesign import QThetaPXDesignGuidance
100
+
101
+ outdir = args.outdir
102
+ os.makedirs(outdir, exist_ok=True)
103
+
104
+ # Initialize scorer
105
+ guidance = QThetaPXDesignGuidance(
106
+ checkpoint=args.qtheta_checkpoint,
107
+ ref_holo=args.ref_holo,
108
+ ref_apo=args.ref_apo,
109
+ ref_chain=args.ref_chain,
110
+ device=f'cuda:{args.gpu}',
111
+ )
112
+ guidance._lazy_init()
113
+
114
+ all_designs = []
115
+ round_summaries = []
116
+
117
+ for round_idx in range(args.n_rounds):
118
+ round_dir = os.path.join(outdir, f'round_{round_idx}')
119
+ os.makedirs(round_dir, exist_ok=True)
120
+
121
+ logger.info(f"\n{'='*60}")
122
+ logger.info(f"SMC Round {round_idx + 1}/{args.n_rounds}")
123
+ logger.info(f"{'='*60}")
124
+
125
+ # Generate particles via vanilla PXDesign
126
+ gen_dir = os.path.join(round_dir, 'generated')
127
+ success = run_pxdesign_batch(
128
+ input_json=args.input,
129
+ outdir=gen_dir,
130
+ n_sample=args.n_particles,
131
+ n_step=args.N_step,
132
+ gpu=args.gpu,
133
+ )
134
+
135
+ if not success:
136
+ # If subprocess fails, try using existing PXDesign outputs
137
+ logger.warning(f"Round {round_idx} generation failed. "
138
+ f"Checking for existing outputs...")
139
+ pdbs = collect_pdbs(args.designs_dir) if hasattr(args, 'designs_dir') else []
140
+ if not pdbs:
141
+ continue
142
+ else:
143
+ pdbs = collect_pdbs(gen_dir)
144
+
145
+ if not pdbs:
146
+ logger.warning(f"No PDBs found in round {round_idx}")
147
+ continue
148
+
149
+ # Score all particles
150
+ logger.info(f"Scoring {len(pdbs)} particles...")
151
+ round_results = []
152
+ for pdb_path in pdbs:
153
+ result = guidance.score_design(pdb_path)
154
+ if result is not None:
155
+ result['pdb_path'] = pdb_path
156
+ result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
157
+ result['round'] = round_idx
158
+ round_results.append(result)
159
+
160
+ if not round_results:
161
+ continue
162
+
163
+ margins = np.array([r['margin'] for r in round_results])
164
+
165
+ round_summary = {
166
+ 'round': round_idx,
167
+ 'n_particles': len(round_results),
168
+ 'margin_mean': float(margins.mean()),
169
+ 'margin_std': float(margins.std()),
170
+ 'margin_max': float(margins.max()),
171
+ 'frac_positive': float((margins > 0).mean()),
172
+ }
173
+ round_summaries.append(round_summary)
174
+
175
+ logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
176
+ f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}")
177
+
178
+ all_designs.extend(round_results)
179
+
180
+ # Final ranking and summary
181
+ if all_designs:
182
+ all_designs.sort(key=lambda x: x['margin'], reverse=True)
183
+ all_margins = np.array([d['margin'] for d in all_designs])
184
+ holo_scores = np.array([d['q_holo'] for d in all_designs])
185
+
186
+ # Best-of-K
187
+ bok = {}
188
+ for K in [1, 2, 5, 10]:
189
+ n_trials = 2000
190
+ n_avail = len(all_margins)
191
+ successes = sum(
192
+ 1 for _ in range(n_trials)
193
+ if all_margins[np.random.choice(n_avail, min(K, n_avail), replace=False)].max() > 0
194
+ )
195
+ bok[K] = successes / n_trials
196
+
197
+ summary = {
198
+ 'method': 'PXDesign + SMC',
199
+ 'n_rounds': args.n_rounds,
200
+ 'n_particles_per_round': args.n_particles,
201
+ 'total_designs': len(all_designs),
202
+ 'margin_mean': float(all_margins.mean()),
203
+ 'margin_std': float(all_margins.std()),
204
+ 'margin_max': float(all_margins.max()),
205
+ 'frac_positive': float((all_margins > 0).mean()),
206
+ 'q_holo_mean': float(holo_scores.mean()),
207
+ 'q_apo_mean': float(np.mean([d['q_apo'] for d in all_designs])),
208
+ 'best_of_k': {str(k): v for k, v in bok.items()},
209
+ 'round_summaries': round_summaries,
210
+ 'top5': all_designs[:5],
211
+ }
212
+
213
+ with open(os.path.join(outdir, 'smc_scores.json'), 'w') as f:
214
+ json.dump(all_designs, f, indent=2)
215
+ with open(os.path.join(outdir, 'smc_summary.json'), 'w') as f:
216
+ json.dump(summary, f, indent=2)
217
+
218
+ # Copy best designs
219
+ best_dir = os.path.join(outdir, 'best_designs')
220
+ os.makedirs(best_dir, exist_ok=True)
221
+ for i, d in enumerate(all_designs[:20]):
222
+ if os.path.exists(d['pdb_path']):
223
+ dest = os.path.join(best_dir, f'rank_{i:02d}_{d["design_id"]}{os.path.splitext(d["pdb_path"])[1]}')
224
+ shutil.copy2(d['pdb_path'], dest)
225
+
226
+ logger.info(f"\n{'='*60}")
227
+ logger.info(f"PXDesign + SMC Results ({len(all_designs)} total designs)")
228
+ logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
229
+ logger.info(f" Max margin: {all_margins.max():.3f}")
230
+ logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
231
+ logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
232
+ logger.info(f" Best-of-K:")
233
+ for k, v in sorted(bok.items()):
234
+ logger.info(f" K={k:3d}: {v:.3f}")
235
+ logger.info(f"{'='*60}")
236
+
237
+
238
+ def main():
239
+ parser = argparse.ArgumentParser(description='PXDesign + SMC Reranking')
240
+ parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
241
+ help='PXDesign input JSON')
242
+ parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/',
243
+ help='Existing PXDesign outputs (fallback if generation fails)')
244
+ parser.add_argument('--qtheta_checkpoint',
245
+ default='results/checkpoints_cam_v3/best_phase2.pt')
246
+ parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
247
+ parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
248
+ parser.add_argument('--ref_chain', default='A')
249
+ parser.add_argument('--n_particles', type=int, default=16,
250
+ help='Particles per round')
251
+ parser.add_argument('--n_rounds', type=int, default=4,
252
+ help='Number of SMC rounds')
253
+ parser.add_argument('--N_step', type=int, default=400)
254
+ parser.add_argument('--gpu', type=int, default=0)
255
+ parser.add_argument('--outdir', default='results/pxdesign_smc')
256
+ args = parser.parse_args()
257
+
258
+ smc_particle_filter(args)
259
+
260
+
261
+ if __name__ == '__main__':
262
+ main()
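Reviewer note: the best-of-K block in this file estimates, with 2000 random draws, the chance that at least one of K designs sampled without replacement has a positive margin. That quantity has a closed form (a hypergeometric tail), so a deterministic variant is possible. The sketch below is an alternative, not what the script does, and `best_of_k_exact` is a hypothetical name.

```python
from math import comb


def best_of_k_exact(margins, k):
    """Probability that at least one of k designs drawn without replacement
    has margin > 0."""
    n = len(margins)
    k = min(k, n)                                   # same clamp as the script
    n_nonpos = sum(1 for m in margins if m <= 0)
    # P(all k draws are non-positive) = C(n_nonpos, k) / C(n, k)
    return 1.0 - comb(n_nonpos, k) / comb(n, k)
```

Besides removing sampling noise, this makes repeated runs on the same pool report identical best-of-K values without seeding NumPy.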
code/scripts/pxdesign_guidance/tds_pxdesign.py ADDED
@@ -0,0 +1,323 @@
1
+ """
2
+ PXDesign + Twisted Diffusion Sampling (TDS).
3
+
4
+ Multi-round particle filtering with guided PXDesign:
5
+ Round r:
6
+ 1. Generate N particles via PXDesign with Q_theta classifier guidance
7
+ 2. Score each particle with Q_theta selectivity margin
8
+ 3. Compute importance weights w_i ~ exp(margin_i / temperature)
9
+ 4. Pool all scored particles (weights are only used to report ESS)
10
+ 5. Increase the guidance scale for the next round
11
+
12
+ This combines in-process guidance (the "twisted proposal") with post-hoc
13
+ margin-based ranking of the pooled designs.
14
+
15
+ Usage:
16
+ python code/scripts/pxdesign_guidance/tds_pxdesign.py \
17
+ --input experiments/pxdesign_cam/output/cam_binder.json \
18
+ --qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
19
+ --ref_holo data/pdbs/cam_holo/3CLN.pdb \
20
+ --ref_apo data/pdbs/cam_apo/1CFD.pdb \
21
+ --n_particles 16 --n_rounds 4 \
22
+ --guidance_scale 0.5 \
23
+ --gpu 0
24
+ """
25
+
26
+ import os
27
+ import sys
28
+ import argparse
29
+ import json
30
+ import logging
31
+ import shutil
32
+ import subprocess
33
+ from glob import glob
34
+
35
+ import numpy as np
36
+ import torch
37
+
38
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
39
+ logger = logging.getLogger(__name__)
40
+
41
+ _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
42
+ _ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
43
+ _ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
44
+
45
+ if _ALLO_CODE_DIR not in sys.path:
46
+ sys.path.insert(0, _ALLO_CODE_DIR)
47
+
48
+
49
+ def compute_ess(log_weights):
50
+ """Compute effective sample size from log-weights."""
51
+ log_weights = log_weights - log_weights.max()
52
+ weights = np.exp(log_weights)
53
+ weights = weights / weights.sum()
54
+ return 1.0 / (weights ** 2).sum()
55
+
56
+
57
+ def run_guided_pxdesign_batch(input_json, outdir, n_sample, n_step,
58
+ gpu, guidance_args):
59
+ """Run guided PXDesign as a subprocess."""
60
+ pxdesign_python = 'python'
61
+ cmd = [
62
+ pxdesign_python,
63
+ os.path.join(_SCRIPT_DIR, 'guided_pxdesign.py'),
64
+ '--input', input_json,
65
+ '--qtheta_checkpoint', guidance_args['checkpoint'],
66
+ '--ref_holo', guidance_args['ref_holo'],
67
+ '--ref_apo', guidance_args['ref_apo'],
68
+ '--ref_chain', guidance_args['ref_chain'],
69
+ '--guidance_scale', str(guidance_args['guidance_scale']),
70
+ '--guidance_start', str(guidance_args.get('guidance_start', 0.8)),
71
+ '--guidance_end', str(guidance_args.get('guidance_end', 0.1)),
72
+ '--N_sample', str(n_sample),
73
+ '--N_step', str(n_step),
74
+ '--gpu', str(gpu),
75
+ '--outdir', outdir,
76
+ ]
77
+
78
+ env = os.environ.copy()
79
+ # Inherit CUDA_VISIBLE_DEVICES from parent
80
+
81
+ logger.info(f"Running guided PXDesign: {n_sample} samples -> {outdir}")
82
+ result = subprocess.run(cmd, capture_output=True, text=True, env=env,
83
+ timeout=7200)
84
+
85
+ if result.returncode != 0:
86
+ logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
87
+ return False
88
+ return True
89
+
90
+
91
+ def run_vanilla_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
92
+ """Run vanilla PXDesign (no guidance) as a subprocess."""
93
+ pxdesign_env = 'python'
94
+
95
+ cmd = [
96
+ pxdesign_env, '-m', 'pxdesign.runner.inference',
97
+ '--dump_dir', outdir,
98
+ '--input', input_json,
99
+ '--dtype', 'bf16',
100
+ '--N_sample', str(n_sample),
101
+ '--N_step', str(n_step),
102
+ ]
103
+
104
+ env = os.environ.copy()
105
+ # Inherit CUDA_VISIBLE_DEVICES from parent
106
+
107
+ logger.info(f"Running vanilla PXDesign: {n_sample} samples -> {outdir}")
108
+ result = subprocess.run(cmd, capture_output=True, text=True, env=env,
109
+ timeout=7200)
110
+
111
+ if result.returncode != 0:
112
+ logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
113
+ return False
114
+ return True
115
+
116
+
117
+ def collect_pdbs(outdir):
118
+ """Collect PDB/CIF paths from PXDesign output directory."""
119
+ pdbs = []
120
+ for ext in ('*.pdb', '*.cif'):
121
+ pdbs.extend(glob(os.path.join(outdir, '**/' + ext), recursive=True))
122
+ pdbs = sorted(pdbs)
123
+ filtered = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
124
+ or 'design' in os.path.basename(p).lower()
125
+ or 'rank' in os.path.basename(p).lower()]
126
+ return filtered if filtered else pdbs
127
+
128
+
129
+ def tds_particle_filter(args):
130
+ """Run TDS particle filtering with PXDesign."""
131
+ from qtheta_pxdesign import QThetaPXDesignGuidance
132
+
133
+ outdir = os.path.join(_ALLO_ROOT, args.outdir)
134
+ os.makedirs(outdir, exist_ok=True)
135
+
136
+ # Initialize scorer
137
+ guidance = QThetaPXDesignGuidance(
138
+ checkpoint=os.path.join(_ALLO_ROOT, args.qtheta_checkpoint),
139
+ ref_holo=os.path.join(_ALLO_ROOT, args.ref_holo),
140
+ ref_apo=os.path.join(_ALLO_ROOT, args.ref_apo),
141
+ ref_chain=args.ref_chain,
142
+ device=f'cuda:{args.gpu}',
143
+ )
144
+ guidance._lazy_init()
145
+
146
+ guidance_args = {
147
+ 'checkpoint': args.qtheta_checkpoint,
148
+ 'ref_holo': args.ref_holo,
149
+ 'ref_apo': args.ref_apo,
150
+ 'ref_chain': args.ref_chain,
151
+ 'guidance_scale': args.guidance_scale,
152
+ 'guidance_start': args.guidance_start,
153
+ 'guidance_end': args.guidance_end,
154
+ }
155
+
156
+ all_designs = []
157
+ round_summaries = []
158
+
159
+ for round_idx in range(args.n_rounds):
160
+ round_dir = os.path.join(outdir, f'round_{round_idx}')
161
+ os.makedirs(round_dir, exist_ok=True)
162
+
163
+ logger.info(f"\n{'='*60}")
164
+ logger.info(f"TDS Round {round_idx + 1}/{args.n_rounds}")
165
+ logger.info(f"{'='*60}")
166
+
167
+ # Generate particles via guided PXDesign
168
+ gen_dir = os.path.join(round_dir, 'generated')
169
+ success = run_guided_pxdesign_batch(
170
+ input_json=os.path.join(_ALLO_ROOT, args.input),
171
+ outdir=gen_dir,
172
+ n_sample=args.n_particles,
173
+ n_step=args.N_step,
174
+ gpu=args.gpu,
175
+ guidance_args=guidance_args,
176
+ )
177
+
178
+ if not success:
179
+ logger.warning(f"Round {round_idx} generation failed, skipping")
180
+ continue
181
+
182
+ # Collect and score particles
183
+ pdbs = collect_pdbs(gen_dir)
184
+ if not pdbs:
185
+ logger.warning(f"No PDBs found in round {round_idx}")
186
+ continue
187
+
188
+ logger.info(f"Scoring {len(pdbs)} particles...")
189
+ round_results = []
190
+ for pdb_path in pdbs:
191
+ result = guidance.score_design(pdb_path)
192
+ if result is not None:
193
+ result['pdb_path'] = pdb_path
194
+ result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
195
+ result['round'] = round_idx
196
+ round_results.append(result)
197
+
198
+ if not round_results:
199
+ logger.warning(f"No scorable designs in round {round_idx}")
200
+ continue
201
+
202
+ margins = np.array([r['margin'] for r in round_results])
203
+
204
+ # Compute importance weights
205
+ log_weights = margins / args.temperature
206
+ ess = compute_ess(log_weights)
207
+
208
+ round_summary = {
209
+ 'round': round_idx,
210
+ 'n_particles': len(round_results),
211
+ 'margin_mean': float(margins.mean()),
212
+ 'margin_std': float(margins.std()),
213
+ 'margin_max': float(margins.max()),
214
+ 'frac_positive': float((margins > 0).mean()),
215
+ 'ess': float(ess),
216
+ }
217
+ round_summaries.append(round_summary)
218
+
219
+ logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
220
+ f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}, "
221
+ f"ESS={ess:.1f}/{len(round_results)}")
222
+
223
+ # Add to design pool
224
+ all_designs.extend(round_results)
225
+
226
+ # Resample for next round (top-K selection for PXDesign since
227
+ # we can't easily perturb and re-denoise)
228
+ if round_idx < args.n_rounds - 1:
229
+ # Copy best designs to inform next round
230
+ # For PXDesign, each round generates fresh samples with guidance
231
+ # Resampling influence is through the guidance strength
232
+ # Increase guidance scale for later rounds
233
+ guidance_args['guidance_scale'] = args.guidance_scale * (1.0 + 0.2 * (round_idx + 1))
234
+ logger.info(f"Increasing guidance scale to {guidance_args['guidance_scale']:.2f} "
235
+ f"for next round")
236
+
237
+ # Final summary
238
+ if all_designs:
239
+ all_designs.sort(key=lambda x: x['margin'], reverse=True)
240
+ all_margins = np.array([d['margin'] for d in all_designs])
241
+ holo_scores = np.array([d['q_holo'] for d in all_designs])
242
+
243
+ # Best-of-K
244
+ bok = {}
245
+ for K in [1, 2, 5, 10]:
246
+ n_trials = 2000
247
+ n_avail = len(all_margins)
248
+ successes = sum(
249
+ 1 for _ in range(n_trials)
250
+ if all_margins[np.random.choice(n_avail, min(K, n_avail), replace=False)].max() > 0
251
+ )
252
+ bok[K] = successes / n_trials
253
+
254
+ summary = {
255
+ 'method': 'PXDesign + TDS',
256
+ 'n_rounds': args.n_rounds,
257
+ 'n_particles_per_round': args.n_particles,
258
+ 'total_designs': len(all_designs),
259
+ 'guidance_scale': args.guidance_scale,
260
+ 'temperature': args.temperature,
261
+ 'margin_mean': float(all_margins.mean()),
262
+ 'margin_std': float(all_margins.std()),
263
+ 'margin_max': float(all_margins.max()),
264
+ 'frac_positive': float((all_margins > 0).mean()),
265
+ 'q_holo_mean': float(holo_scores.mean()),
266
+ 'best_of_k': {str(k): v for k, v in bok.items()},
267
+ 'round_summaries': round_summaries,
268
+ 'top5': all_designs[:5],
269
+ }
270
+
271
+ with open(os.path.join(outdir, 'tds_scores.json'), 'w') as f:
272
+ json.dump(all_designs, f, indent=2)
273
+ with open(os.path.join(outdir, 'tds_summary.json'), 'w') as f:
274
+ json.dump(summary, f, indent=2)
275
+
276
+ # Copy best designs to top-level
277
+ best_dir = os.path.join(outdir, 'best_designs')
278
+ os.makedirs(best_dir, exist_ok=True)
279
+ for i, d in enumerate(all_designs[:20]):
280
+ if os.path.exists(d['pdb_path']):
281
+ dest = os.path.join(best_dir, f'rank_{i:02d}_{d["design_id"]}{os.path.splitext(d["pdb_path"])[1]}')
282
+ shutil.copy2(d['pdb_path'], dest)
283
+
284
+ logger.info(f"\n{'='*60}")
285
+ logger.info(f"PXDesign + TDS Results ({len(all_designs)} total designs)")
286
+ logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
287
+ logger.info(f" Max margin: {all_margins.max():.3f}")
288
+ logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
289
+ logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
290
+ logger.info(f" Best-of-K:")
291
+ for k, v in sorted(bok.items()):
292
+ logger.info(f" K={k:3d}: {v:.3f}")
293
+ logger.info(f"{'='*60}")
294
+
295
+
296
+ def main():
297
+ parser = argparse.ArgumentParser(description='PXDesign + TDS')
298
+ parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json')
299
+ parser.add_argument('--qtheta_checkpoint',
300
+ default='results/checkpoints_cam_v3/best_phase2.pt')
301
+ parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
302
+ parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
303
+ parser.add_argument('--ref_chain', default='A')
304
+ parser.add_argument('--n_particles', type=int, default=16,
305
+ help='Particles per round')
306
+ parser.add_argument('--n_rounds', type=int, default=4,
307
+ help='Number of TDS rounds')
308
+ parser.add_argument('--guidance_scale', type=float, default=0.5,
309
+ help='Initial guidance scale')
310
+ parser.add_argument('--guidance_start', type=float, default=0.8)
311
+ parser.add_argument('--guidance_end', type=float, default=0.1)
312
+ parser.add_argument('--temperature', type=float, default=0.5,
313
+ help='Temperature for importance weights')
314
+ parser.add_argument('--N_step', type=int, default=400)
315
+ parser.add_argument('--gpu', type=int, default=0)
316
+ parser.add_argument('--outdir', default='results/pxdesign_tds')
317
+ args = parser.parse_args()
318
+
319
+ tds_particle_filter(args)
320
+
321
+
322
+ if __name__ == '__main__':
323
+ main()
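Reviewer note: this file computes importance weights and an effective sample size per round, but the weights never drive a resampling step. For reference, the sketch below shows what a weight-based resampling pass could look like, using systematic resampling. It is an illustration under stated assumptions, not the script's behaviour; `normalized_weights`, `effective_sample_size`, and `systematic_resample` are hypothetical names.

```python
import numpy as np


def normalized_weights(margins, temperature):
    """Softmax of margin / temperature, shifted by the max for stability."""
    logw = np.asarray(margins, dtype=float) / temperature
    logw -= logw.max()
    w = np.exp(logw)
    return w / w.sum()


def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2); equals n for uniform weights, 1 if one dominates."""
    return 1.0 / np.sum(weights ** 2)


def systematic_resample(weights, rng):
    """Return n particle indices drawn with a single uniform offset."""
    n = len(weights)
    positions = (rng.random() + np.arange(n)) / n   # evenly spaced in [0, 1)
    cumsum = np.cumsum(weights)
    cumsum[-1] = 1.0                                # guard against round-off
    return np.searchsorted(cumsum, positions)
```

A common convention is to resample only when the ESS falls below roughly half the particle count, which keeps diversity when the weights are already close to uniform.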
code/scripts/rescore.py ADDED
@@ -0,0 +1,178 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Re-score binder PDB designs with a Q_theta checkpoint.
4
+
5
+ Walks a directory of designs (binder PDB + sibling holo / apo receptor PDBs),
6
+ runs each through DifferentiableQTheta, and writes per-design
7
+ S = Q_theta(holo) - Q_theta(apo) plus the raw holo/apo scores to JSON.
8
+
9
+ Usage:
10
+ python code/scripts/rescore.py \\
11
+ --checkpoint checkpoints/Q_theta_phase2.pt \\
12
+ --gpu 0
13
+ """
14
+ import os, sys, json, argparse, glob, logging
15
+ import numpy as np
16
+ import torch
17
+ from pathlib import Path
18
+
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
20
+ logger = logging.getLogger(__name__)
21
+
22
+ BASE = str(Path(__file__).resolve().parent.parent.parent)
23
+ sys.path.insert(0, os.path.join(BASE, 'code'))
24
+ sys.path.insert(0, BASE)
25
+
26
+ from models.differentiable_features import DifferentiableQTheta
27
+ from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures
28
+
29
+ HOLO_PDB = os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb')
30
+ APO_PDB = os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb')
31
+
32
+
33
+ def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device):
34
+ """Score a list of design PDB files."""
35
+ results = []
36
+ for pdb_path in pdb_list:
37
+ name = os.path.basename(pdb_path).replace(".pdb", "")
38
+ try:
39
+ design_model = load_structure(pdb_path)
40
+ chains = [c.id for c in design_model.get_chains()]
41
+ rec_chain = 'A' if 'A' in chains else chains[0]
42
+ binder_chain = 'B' if 'B' in chains else [c for c in chains if c != rec_chain][0]
43
+
44
+ rec_res = get_residues(design_model[rec_chain])
45
+ binder_res = get_residues(design_model[binder_chain])
46
+ rec_coords_d, _ = get_backbone_coords(rec_res)
47
+ binder_coords, binder_mask = get_backbone_coords(binder_res)
48
+ binder_aa_idx = get_aa_indices(binder_res)
49
+
50
+ design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
51
+ common = sorted(set(design_resnums.keys()) & set(ref_resnums.keys()))
52
+ if len(common) < 10:
53
+ logger.warning(f" Skip {name}: <10 common residues")
54
+ continue
55
+
56
+ d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
57
+ r_ca = ref_coords[[ref_resnums[r] for r in common], 1]
58
+ mobile_center = d_ca.mean(0)
59
+ ref_center = r_ca.mean(0)
60
+ _, R = align_structures(d_ca, r_ca)
61
+
62
+ flat = binder_coords.reshape(-1, 3) - mobile_center
63
+ aligned_binder = (flat @ R.T + ref_center).reshape(-1, 4, 3)
64
+
65
+ coords_t = torch.from_numpy(aligned_binder).float().to(device)
66
+ mask_t = torch.from_numpy(binder_mask).bool().to(device)
67
+ aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
68
+
69
+ with torch.no_grad():
70
+ q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
71
+ receptor_label='holo').item()
72
+ q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
73
+ receptor_label='apo').item()
74
+ S = q_holo - q_apo
75
+ results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo, "S": S})
76
+ except Exception as e:
77
+ logger.warning(f" Skip {name}: {e}")
78
+ return results
79
+
80
+
81
+ def summarize(results, label):
82
+ if not results:
83
+ return {}
84
+ S = [r["S"] for r in results]
85
+ return {
86
+ "method": label, "n": len(S),
87
+ "S_mean": float(np.mean(S)), "S_std": float(np.std(S)),
88
+ "S_pos_pct": float(np.mean([s > 0 for s in S]) * 100),
89
+ "Q_holo_mean": float(np.mean([r["Q_holo"] for r in results])),
90
+ "Q_apo_mean": float(np.mean([r["Q_apo"] for r in results])),
91
+ }
92
+
93
+
94
+ def main():
95
+ parser = argparse.ArgumentParser()
96
+ parser.add_argument("--gpu", type=int, default=7)
97
+ parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt")
98
+ args = parser.parse_args()
99
+
100
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
101
+ device = "cuda:0"
102
+
103
+ logger.info(f"Loading Q_theta from {args.checkpoint}")
104
+ dq = DifferentiableQTheta(checkpoint_path=args.checkpoint, device=device,
105
+ esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
106
+ dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
107
+ dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')
108
+
109
+ ref_model = load_structure(HOLO_PDB)
110
+ ref_res = get_residues(ref_model['A'])
111
+ ref_coords, _ = get_backbone_coords(ref_res)
112
+ ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}
113
+
114
+ output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
115
+ os.makedirs(output_dir, exist_ok=True)
116
+
117
+ # Define design directories
118
+ design_sets = {
119
+ "vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
120
+ "langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
121
+ "classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
122
+ "smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
123
+ }
124
+
125
+ # Also check for TDS and PXDesign
126
+ tds_dirs = glob.glob(os.path.join(BASE, "results/tds_guidance/cam/designs"))
127
+ if tds_dirs:
128
+ design_sets["tds"] = tds_dirs[0]
129
+
130
+ # PXDesign directories
131
+ for px_method in ["pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
132
+ "pxdesign_smc", "pxdesign_langevin"]:
133
+ px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
134
+ if not os.path.exists(px_dir):
135
+ px_dir = os.path.join(BASE, f"results/{px_method}")
136
+ if os.path.exists(px_dir):
137
+ pdbs = glob.glob(os.path.join(px_dir, "*.pdb"))
138
+ if pdbs:
139
+ design_sets[px_method] = px_dir
140
+
141
+ all_results = {}
142
+ summaries = []
143
+
144
+ for method, pdb_dir in design_sets.items():
145
+ if not os.path.exists(pdb_dir):
146
+ logger.warning(f" {method}: directory not found ({pdb_dir})")
147
+ continue
148
+ pdbs = sorted(glob.glob(os.path.join(pdb_dir, "*.pdb")))
149
+ if not pdbs:
150
+ logger.warning(f" {method}: no PDB files")
151
+ continue
152
+
153
+ logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
154
+ results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
155
+ s = summarize(results, method)
156
+ if s:
157
+ summaries.append(s)
158
+ logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
159
+ f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
160
+ all_results[method] = {"results": results, "summary": s}
161
+
162
+ # Save
163
+ with open(os.path.join(output_dir, "rescore_v2_all.json"), "w") as f:
164
+ json.dump(all_results, f, indent=2)
165
+
166
+ # Print summary table
167
+ print("\n" + "=" * 70)
168
+ print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
169
+ print("=" * 70)
170
+ print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
171
+ print("-" * 70)
172
+ for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
173
+ print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
174
+ f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")
175
+
176
+
177
+ if __name__ == "__main__":
178
+ main()
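Reviewer note: unlike the PXDesign scorer, this script pairs receptor residues by residue number before fitting the rotation, which stays correct when the design and the reference cover different residue ranges. The pairing step reduces to the small helper sketched below; `matched_indices` is a hypothetical name, and both arguments are assumed to be dicts mapping residue number to array row, as built in `score_pdb_list`.

```python
def matched_indices(design_resnums, ref_resnums):
    """Return parallel row-index lists for residues present in both chains."""
    common = sorted(set(design_resnums) & set(ref_resnums))
    design_rows = [design_resnums[r] for r in common]
    ref_rows = [ref_resnums[r] for r in common]
    return design_rows, ref_rows
```

The two lists can then index the backbone arrays directly, for example `rec_coords_d[design_rows, 1]` for the Cα atoms, and a caller can skip the design when the lists are shorter than the script's threshold of 10 residues.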
code/trainers/__init__.py ADDED
File without changes
code/trainers/trainer.py ADDED
@@ -0,0 +1,674 @@
+"""
+Trainer for the Q_theta state-selectivity scorer.
+
+Implements two-phase training:
+    Phase 1: DockQ regression (learn complex quality from all data)
+    Phase 2: Selectivity fine-tuning (learn to rank X+ > X- for the same binder)
+
+Integrates with Weights & Biases for experiment tracking.
+"""
+
+import os
+import time
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
+from scipy.stats import spearmanr
+from sklearn.metrics import roc_auc_score
+
+import wandb
+
+logger = logging.getLogger(__name__)
+
+
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0.0
+        self.avg = 0.0
+        self.sum = 0.0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class AlloDesignerTrainer:
+    """
+    Two-phase trainer for Q_theta.
+
+    Phase 1 (DockQ regression):
+        - Minimizes MSE(Q_theta(X, Y), DockQ_label) on all complex types
+        - Learns general complex quality
+
+    Phase 2 (Selectivity fine-tuning):
+        - Minimizes selectivity margin loss on paired (pos, neg) data
+        - Learns to rank Q(X+, Y) > Q(X-, Y)
+        - Combined: L = L_regression + lambda_rank * (L_margin + L_InfoNCE) + lambda_ddg * L_ddG
+    """
+
+    def __init__(self, model, config, device='cuda'):
+        self.model = model.to(device)
+        self.config = config
+        self.device = device
+        self.use_sam = config.get('optimizer', 'adamw') == 'sam'
+
+        # Optimizer
+        if self.use_sam:
+            from utils.sam import SAM
+            self.optimizer = SAM(
+                model.parameters(),
+                base_optimizer=AdamW,
+                rho=0.05,
+                lr=config.get('lr', 1e-4),
+                weight_decay=config.get('weight_decay', 1e-4),
+                betas=(0.9, 0.999),
+            )
+            # SAM wraps AdamW; scheduler goes on base_optimizer
+            sched_optimizer = self.optimizer.base_optimizer
+        else:
+            self.optimizer = AdamW(
+                model.parameters(),
+                lr=config.get('lr', 1e-4),
+                weight_decay=config.get('weight_decay', 1e-4),
+                betas=(0.9, 0.999),
+            )
+            sched_optimizer = self.optimizer
+
+        # Learning rate scheduler (warmup + cosine)
+        n_warmup = config.get('warmup_steps', 100)
+        n_total = config.get('max_steps', 5000)
+
+        warmup_sched = LinearLR(sched_optimizer, start_factor=0.01, end_factor=1.0, total_iters=n_warmup)
+        cosine_sched = CosineAnnealingLR(sched_optimizer, T_max=n_total - n_warmup, eta_min=1e-6)
+        self.scheduler = SequentialLR(sched_optimizer, [warmup_sched, cosine_sched], milestones=[n_warmup])
+
+        self.global_step = 0
+        self.best_val_metric = -float('inf')
+        self.checkpoint_dir = config.get('checkpoint_dir', 'results/checkpoints')
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
+
+    # ------------------------------------------------------------------ #
+    # Phase 1: DockQ regression
+    # ------------------------------------------------------------------ #
+
+    def train_step_phase1(self, batch):
+        """Single training step for Phase 1 (DockQ regression)."""
+        self.model.train()
+        node_feats = batch['node_feats'].to(self.device)  # [B, N, node_dim]
+        edge_feats = batch['edge_feats'].to(self.device)  # [B, N, N, edge_dim]
+        node_mask = batch['node_mask'].to(self.device)  # [B, N]
+        labels = batch['label'].to(self.device)  # [B]
+        esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
+
+        self.optimizer.zero_grad()
+
+        scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)  # [B]
+        loss = nn.functional.mse_loss(scores, labels)
+
+        loss.backward()
+        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+        if self.use_sam:
+            self.optimizer.first_step()
+            # Second forward-backward pass
+            scores2 = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)
+            loss2 = nn.functional.mse_loss(scores2, labels)
+            self.optimizer.zero_grad()
+            loss2.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            self.optimizer.second_step()
+        else:
+            self.optimizer.step()
+
+        self.scheduler.step()
+        self.global_step += 1
+
+        return {'loss': loss.item(), 'scores': scores.detach(), 'labels': labels}
+
+    def run_phase1(self, train_loader, val_loader, n_epochs: int = 30, run_name: str = 'phase1'):
+        """Phase 1 training loop."""
+        logger.info(f"Starting Phase 1 (DockQ regression) for {n_epochs} epochs")
+        wandb.define_metric('phase1/step')
+        wandb.define_metric('phase1/*', step_metric='phase1/step')
+
+        for epoch in range(n_epochs):
+            # Train
+            train_meter = AverageMeter()
+            all_scores, all_labels = [], []
+
+            for batch in train_loader:
+                result = self.train_step_phase1(batch)
+                train_meter.update(result['loss'], n=len(result['scores']))
+                all_scores.append(result['scores'].cpu().numpy())
+                all_labels.append(result['labels'].cpu().numpy())
+
+                if self.global_step % 50 == 0:
+                    wandb.log({
+                        'phase1/train_loss': result['loss'],
+                        'phase1/lr': self.optimizer.param_groups[0]['lr'],
+                        'phase1/step': self.global_step,
+                    })
+
+            # Compute Spearman corr on training data
+            all_scores = np.concatenate(all_scores)
+            all_labels = np.concatenate(all_labels)
+            train_spearman = spearmanr(all_scores, all_labels).correlation
+
+            # Validate
+            val_metrics = self.evaluate_phase1(val_loader)
+
+            logger.info(
+                f"Phase1 Epoch {epoch+1}/{n_epochs} | "
+                f"Train Loss: {train_meter.avg:.4f} | "
+                f"Train Spearman: {train_spearman:.3f} | "
+                f"Val Loss: {val_metrics['val_loss']:.4f} | "
+                f"Val Spearman: {val_metrics['val_spearman']:.3f} | "
+                f"Val AUC: {val_metrics.get('val_auc', 0):.3f}"
+            )
+
+            wandb.log({
+                'phase1/epoch': epoch + 1,
+                'phase1/train_loss_epoch': train_meter.avg,
+                'phase1/train_spearman': train_spearman,
+                **{f'phase1/{k}': v for k, v in val_metrics.items()},
+            })
+
+            # Checkpoint best model
+            if val_metrics['val_spearman'] > self.best_val_metric:
+                self.best_val_metric = val_metrics['val_spearman']
+                self.save_checkpoint('best_phase1.pt', extra={'epoch': epoch, 'phase': 1})
+                logger.info(f" -> New best Phase 1 model (val_spearman={self.best_val_metric:.3f})")
+
+        logger.info("Phase 1 training complete.")
+
+    @torch.no_grad()
+    def evaluate_phase1(self, loader):
+        """Evaluate Phase 1 model on val/test set."""
+        self.model.eval()
+        all_scores, all_labels = [], []
+        total_loss = 0.0
+        n_batches = 0
+
+        for batch in loader:
+            node_feats = batch['node_feats'].to(self.device)
+            edge_feats = batch['edge_feats'].to(self.device)
+            node_mask = batch['node_mask'].to(self.device)
+            labels = batch['label'].to(self.device)
+            esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
+
+            scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)
+            loss = nn.functional.mse_loss(scores, labels)
+
+            total_loss += loss.item()
+            n_batches += 1
+            all_scores.append(scores.cpu().numpy())
+            all_labels.append(labels.cpu().numpy())
+
+        all_scores = np.concatenate(all_scores)
+        all_labels = np.concatenate(all_labels)
+
+        spearman = spearmanr(all_scores, all_labels).correlation
+        if np.isnan(spearman):
+            spearman = 0.0
+
+        metrics = {
+            'val_loss': total_loss / max(n_batches, 1),
+            'val_spearman': float(spearman),
+        }
+
+        # AUC for binary quality (label > 0.5 = positive)
+        binary_labels = (all_labels > 0.5).astype(int)
+        if binary_labels.sum() > 0 and binary_labels.sum() < len(binary_labels):
+            try:
+                metrics['val_auc'] = roc_auc_score(binary_labels, all_scores)
+            except Exception:
+                pass
+
+        return metrics
+
+    # ------------------------------------------------------------------ #
+    # Phase 2: Selectivity fine-tuning
+    # ------------------------------------------------------------------ #
+
+    def train_step_phase2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
+                          lambda_ddg: float = 0.1):
+        """Single training step for Phase 2 (selectivity margin + ddG auxiliary)."""
+        self.model.train()
+
+        pos = batch['pos']
+        neg = batch['neg']
+
+        pos_node = pos['node_feats'].to(self.device)
+        pos_edge = pos['edge_feats'].to(self.device)
+        pos_mask = pos['node_mask'].to(self.device)
+        pos_label = pos['label'].to(self.device)
+        pos_ce = pos.get('contact_energy', None)
+        if pos_ce is not None:
+            pos_ce = pos_ce.to(self.device)
+
+        neg_node = neg['node_feats'].to(self.device)
+        neg_edge = neg['edge_feats'].to(self.device)
+        neg_mask = neg['node_mask'].to(self.device)
+        pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
+        neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
+
+        self.optimizer.zero_grad()
+
+        pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm)  # [B]
+        neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm)  # [B]
+
+        # Regression loss on positive examples
+        loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
+
+        # Selectivity margin loss: pos_score - neg_score > margin
+        loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
+
+        # InfoNCE-style selectivity loss
+        eps = 1e-6
+        pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
+        neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
+        log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
+        infonce_loss = -(pos_logit - log_denom).mean()
+
+        # ddG auxiliary loss: MSE against contact-energy proxy (physics-informed soft label)
+        loss_ddg = torch.tensor(0.0, device=self.device)
+        if pos_ce is not None and pos_ce.shape[0] > 0:
+            # pos_ce is a contact-energy-based ddG proxy in [0, 1]
+            # Align positive score toward the contact energy signal
+            loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)
+
+        loss = loss_reg + lambda_rank * (loss_margin + infonce_loss) + lambda_ddg * loss_ddg
+
+        loss.backward()
+        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+        if self.use_sam:
+            self.optimizer.first_step()
+            # Second forward-backward for SAM
+            pos_scores2 = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm)
+            neg_scores2 = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm)
+            loss_reg2 = nn.functional.mse_loss(pos_scores2, pos_label)
+            loss_margin2 = nn.functional.relu(margin - (pos_scores2 - neg_scores2)).mean()
+            eps2 = 1e-6
+            pl2 = torch.log(pos_scores2.clamp(eps2, 1-eps2) / (1-pos_scores2).clamp(eps2))
+            nl2 = torch.log(neg_scores2.clamp(eps2, 1-eps2) / (1-neg_scores2).clamp(eps2))
+            ld2 = torch.stack([pl2, nl2], dim=-1).logsumexp(dim=-1)
+            infonce2 = -(pl2 - ld2).mean()
+            # Include the ddG term so both SAM passes optimize the same objective
+            loss_ddg2 = torch.tensor(0.0, device=self.device)
+            if pos_ce is not None and pos_ce.shape[0] > 0:
+                loss_ddg2 = nn.functional.mse_loss(pos_scores2, pos_ce)
+            loss2 = loss_reg2 + lambda_rank * (loss_margin2 + infonce2) + lambda_ddg * loss_ddg2
+            self.optimizer.zero_grad()
+            loss2.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+            self.optimizer.second_step()
+        else:
+            self.optimizer.step()
+
+        self.scheduler.step()
+        self.global_step += 1
+
+        selectivity_gap = (pos_scores - neg_scores).mean().item()
+
+        return {
+            'loss': loss.item(),
+            'loss_reg': loss_reg.item(),
+            'loss_margin': loss_margin.item(),
+            'loss_infonce': infonce_loss.item(),
+            'loss_ddg': loss_ddg.item(),
+            'selectivity_gap': selectivity_gap,
+            'pos_scores': pos_scores.detach(),
+            'neg_scores': neg_scores.detach(),
+        }
+
+    def train_step_phase2_v2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
+                             lambda_ddg: float = 0.0, lambda_path: float = 0.5):
+        """Phase 2 training step with multi-negative + path monotonicity."""
+        self.model.train()
+
+        pos = batch['pos']
+        neg = batch['neg']
+
+        pos_node = pos['node_feats'].to(self.device)
+        pos_edge = pos['edge_feats'].to(self.device)
+        pos_mask = pos['node_mask'].to(self.device)
+        pos_label = pos['label'].to(self.device)
+        pos_ce = pos.get('contact_energy', None)
+        if pos_ce is not None:
+            pos_ce = pos_ce.to(self.device)
+
+        neg_node = neg['node_feats'].to(self.device)
+        neg_edge = neg['edge_feats'].to(self.device)
+        neg_mask = neg['node_mask'].to(self.device)
+        pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
+        neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
+
+        self.optimizer.zero_grad()
+
+        pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm)
+        neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm)
+
+        # Score path frames if present
+        path_scores = []
+        path_taus = batch.get('path_taus', [])
+        for path_frame in batch.get('path', []):
+            p_node = path_frame['node_feats'].to(self.device)
+            p_edge = path_frame['edge_feats'].to(self.device)
+            p_mask = path_frame['node_mask'].to(self.device)
+            # Pass ESM features when the frame provides them, as for pos/neg above
+            p_esm = path_frame['esm_feats'].to(self.device) if 'esm_feats' in path_frame else None
+            p_score = self.model(p_node, p_edge, p_mask, esm_feats=p_esm)
+            path_scores.append(p_score)
+
+        # Regression loss on positive examples
+        loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
+
+        # Selectivity margin loss
+        loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
+
+        # InfoNCE-style selectivity loss
+        eps = 1e-6
+        pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
+        neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
+        log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
+        infonce_loss = -(pos_logit - log_denom).mean()
+
+        # ddG auxiliary loss
+        loss_ddg = torch.tensor(0.0, device=self.device)
+        if pos_ce is not None and pos_ce.shape[0] > 0 and lambda_ddg > 0:
+            loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)
+
+        # Path monotonicity loss
+        loss_path = torch.tensor(0.0, device=self.device)
+        if path_scores and lambda_path > 0:
+            small_margin = 0.05
+            for i in range(len(path_scores) - 1):
+                loss_path = loss_path + nn.functional.relu(
+                    path_scores[i] - path_scores[i + 1] + small_margin
+                ).mean()
+            # Last path frame < positive score
+            loss_path = loss_path + nn.functional.relu(
+                path_scores[-1] - pos_scores + margin
+            ).mean()
+            # First path frame > negative score
+            loss_path = loss_path + nn.functional.relu(
+                neg_scores - path_scores[0] + small_margin
+            ).mean()
+
+        loss = (loss_reg + lambda_rank * (loss_margin + infonce_loss)
+                + lambda_ddg * loss_ddg + lambda_path * loss_path)
+
+        loss.backward()
+        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+        # NOTE: unlike train_step_phase2, this step does not run the SAM two-pass update
+        self.optimizer.step()
+        self.scheduler.step()
+        self.global_step += 1
+
+        selectivity_gap = (pos_scores - neg_scores).mean().item()
+
+        return {
+            'loss': loss.item(),
+            'loss_reg': loss_reg.item(),
+            'loss_margin': loss_margin.item(),
+            'loss_infonce': infonce_loss.item(),
+            'loss_ddg': loss_ddg.item(),
+            'loss_path': loss_path.item(),
+            'selectivity_gap': selectivity_gap,
+            'pos_scores': pos_scores.detach(),
+            'neg_scores': neg_scores.detach(),
+        }
+
+    def run_phase2_path(self, train_loader, val_loader, n_epochs: int = 20,
+                        lambda_rank: float = 1.0, margin: float = 0.2,
+                        lambda_ddg: float = 0.0, lambda_path: float = 0.5):
+        """Phase 2 with path-aware training loop."""
+        logger.info(f"Starting Phase 2 (path-aware) for {n_epochs} epochs "
+                    f"[lambda_rank={lambda_rank}, lambda_path={lambda_path}]")
+        self.best_val_metric = -float('inf')
+
+        for epoch in range(n_epochs):
+            loss_meter = AverageMeter()
+            gap_meter = AverageMeter()
+            path_meter = AverageMeter()
+
+            for batch in train_loader:
+                result = self.train_step_phase2_v2(
+                    batch, lambda_rank, margin, lambda_ddg, lambda_path)
+                B = len(result['pos_scores'])
+                loss_meter.update(result['loss'], B)
+                gap_meter.update(result['selectivity_gap'], B)
+                path_meter.update(result['loss_path'], B)
+
+                if self.global_step % 50 == 0:
+                    wandb.log({
+                        'phase2/train_loss': result['loss'],
+                        'phase2/loss_margin': result['loss_margin'],
+                        'phase2/loss_infonce': result['loss_infonce'],
+                        'phase2/loss_path': result['loss_path'],
+                        'phase2/selectivity_gap': result['selectivity_gap'],
+                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
+                        'phase2/step': self.global_step,
+                    })
+
+            val_metrics = self.evaluate_phase2(val_loader)
+
+            logger.info(
+                f"Phase2-Path Epoch {epoch+1}/{n_epochs} | "
+                f"Loss: {loss_meter.avg:.4f} | "
+                f"Gap: {gap_meter.avg:.3f} | "
+                f"Path: {path_meter.avg:.4f} | "
+                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
+                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
+            )
+
+            wandb.log({
+                'phase2/epoch': epoch + 1,
+                'phase2/train_loss_epoch': loss_meter.avg,
+                'phase2/train_gap_epoch': gap_meter.avg,
+                'phase2/train_path_loss_epoch': path_meter.avg,
+                **{f'phase2/{k}': v for k, v in val_metrics.items()},
+            })
+
+            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
+                self.best_val_metric = val_metrics['val_selectivity_gap']
+                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
+                logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
+
+        logger.info("Phase 2 (path-aware) training complete.")
+
+    def run_phase2(self, train_loader, val_loader, n_epochs: int = 20,
+                   lambda_rank: float = 1.0, margin: float = 0.2,
+                   lambda_ddg: float = 0.1):
+        """Phase 2 training loop (selectivity fine-tuning + ddG auxiliary)."""
+        logger.info(f"Starting Phase 2 (selectivity fine-tuning) for {n_epochs} epochs "
+                    f"[lambda_rank={lambda_rank}, lambda_ddg={lambda_ddg}]")
+        self.best_val_metric = -float('inf')
+
+        for epoch in range(n_epochs):
+            loss_meter = AverageMeter()
+            gap_meter = AverageMeter()
+
+            for batch in train_loader:
+                result = self.train_step_phase2(batch, lambda_rank, margin, lambda_ddg)
+                B = len(result['pos_scores'])
+                loss_meter.update(result['loss'], B)
+                gap_meter.update(result['selectivity_gap'], B)
+
+                if self.global_step % 50 == 0:
+                    wandb.log({
+                        'phase2/train_loss': result['loss'],
+                        'phase2/loss_margin': result['loss_margin'],
+                        'phase2/loss_infonce': result['loss_infonce'],
+                        'phase2/loss_ddg': result['loss_ddg'],
+                        'phase2/selectivity_gap': result['selectivity_gap'],
+                        'phase2/lr': self.optimizer.param_groups[0]['lr'],
+                        'phase2/step': self.global_step,
+                    })
+
+            # Validate
+            val_metrics = self.evaluate_phase2(val_loader)
+
+            logger.info(
+                f"Phase2 Epoch {epoch+1}/{n_epochs} | "
+                f"Loss: {loss_meter.avg:.4f} | "
+                f"Gap: {gap_meter.avg:.3f} | "
+                f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
+                f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
+            )
+
+            wandb.log({
+                'phase2/epoch': epoch + 1,
+                'phase2/train_loss_epoch': loss_meter.avg,
+                'phase2/train_gap_epoch': gap_meter.avg,
+                **{f'phase2/{k}': v for k, v in val_metrics.items()},
+            })
+
+            # Checkpoint
+            if val_metrics['val_selectivity_gap'] > self.best_val_metric:
+                self.best_val_metric = val_metrics['val_selectivity_gap']
+                self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
+                logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
+
+        logger.info("Phase 2 training complete.")
+
+    @torch.no_grad()
+    def evaluate_phase2(self, loader):
+        """Evaluate selectivity on paired (pos, neg) val set."""
+        self.model.eval()
+        all_pos_scores, all_neg_scores = [], []
+
+        for batch in loader:
+            if 'pos' not in batch:
+                continue
+            pos = batch['pos']
+            neg = batch['neg']
+
+            pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
+            neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
+            pos_scores = self.model(
+                pos['node_feats'].to(self.device),
+                pos['edge_feats'].to(self.device),
+                pos['node_mask'].to(self.device),
+                esm_feats=pos_esm
+            )
+            neg_scores = self.model(
+                neg['node_feats'].to(self.device),
+                neg['edge_feats'].to(self.device),
+                neg['node_mask'].to(self.device),
+                esm_feats=neg_esm
+            )
+            all_pos_scores.append(pos_scores.cpu().numpy())
+            all_neg_scores.append(neg_scores.cpu().numpy())
+
+        if not all_pos_scores:
+            return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}
+
+        all_pos = np.concatenate(all_pos_scores)
+        all_neg = np.concatenate(all_neg_scores)
+
+        gap = float((all_pos - all_neg).mean())
+        acc = float((all_pos > all_neg).mean())
+
+        return {
+            'val_selectivity_gap': gap,
+            'val_ranking_acc': acc,
+            'val_pos_score_mean': float(all_pos.mean()),
+            'val_neg_score_mean': float(all_neg.mean()),
+        }
+
+    # ------------------------------------------------------------------ #
+    # Checkpointing
+    # ------------------------------------------------------------------ #
+
+    def save_checkpoint(self, filename: str, extra: dict = None):
+        path = os.path.join(self.checkpoint_dir, filename)
+        state = {
+            'model_state': self.model.state_dict(),
+            'optimizer_state': self.optimizer.state_dict(),
+            'global_step': self.global_step,
+            'config': self.config,
+        }
+        if extra:
+            state.update(extra)
+        torch.save(state, path)
+        logger.debug(f"Saved checkpoint: {path}")
+
+    def load_checkpoint(self, filename: str):
+        path = os.path.join(self.checkpoint_dir, filename)
+        if not os.path.exists(path):
+            logger.warning(f"Checkpoint not found: {path}")
+            return False
+        state = torch.load(path, map_location=self.device)
+        self.model.load_state_dict(state['model_state'])
+        self.optimizer.load_state_dict(state['optimizer_state'])
+        self.global_step = state.get('global_step', 0)
+        logger.info(f"Loaded checkpoint from {path} (step {self.global_step})")
+        return True
+
+    # ------------------------------------------------------------------ #
+    # Full evaluation (test set)
+    # ------------------------------------------------------------------ #
+
+    @torch.no_grad()
+    def evaluate_test(self, test_loader, phase: int = 2):
+        """Full evaluation on test set with all metrics."""
+        self.model.eval()
+        all_scores, all_labels, all_types = [], [], []
+
+        for batch in test_loader:
+            if 'pos' in batch:
+                # Paired batch
+                for key in ['pos', 'neg']:
+                    d = batch[key]
+                    d_esm = d['esm_feats'].to(self.device) if 'esm_feats' in d else None
+                    scores = self.model(
+                        d['node_feats'].to(self.device),
+                        d['edge_feats'].to(self.device),
+                        d['node_mask'].to(self.device),
+                        esm_feats=d_esm
+                    )
+                    all_scores.extend(scores.cpu().numpy().tolist())
+                    all_labels.extend(d['label'].numpy().tolist())
+                    all_types.extend(['pos' if key == 'pos' else 'neg'] * len(scores))
+            else:
+                esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
+                scores = self.model(
+                    batch['node_feats'].to(self.device),
+                    batch['edge_feats'].to(self.device),
+                    batch['node_mask'].to(self.device),
+                    esm_feats=esm_feats
+                )
+                all_scores.extend(scores.cpu().numpy().tolist())
+                all_labels.extend(batch['label'].numpy().tolist())
+                all_types.extend(batch['type'])
+
+        all_scores = np.array(all_scores)
+        all_labels = np.array(all_labels)
+
+        metrics = {}
+
+        # Spearman correlation (all samples). NaN is truthy, so `x or 0` would
+        # not catch it; check explicitly, as evaluate_phase1 does.
+        rho = spearmanr(all_scores, all_labels).correlation
+        metrics['test_spearman'] = 0.0 if np.isnan(rho) else float(rho)
+
+        # AUC (binary: label > 0.5 = positive quality)
+        binary = (all_labels > 0.5).astype(int)
+        if binary.sum() > 0 and binary.sum() < len(binary):
+            try:
+                metrics['test_auc'] = float(roc_auc_score(binary, all_scores))
+            except Exception:
+                pass
+
+        # Selectivity gap (pos vs neg_apo pairs)
+        pos_mask = np.array([t == 'pos' or t == 'positive' for t in all_types])
+        neg_mask = np.array([t == 'neg' or t == 'negative_apo' for t in all_types])
+        if pos_mask.sum() > 0 and neg_mask.sum() > 0:
+            metrics['test_selectivity_gap'] = float(all_scores[pos_mask].mean() - all_scores[neg_mask].mean())
+
+        logger.info(f"Test evaluation: {metrics}")
+        wandb.log({f'test/{k}': v for k, v in metrics.items()})
+
+        return metrics, all_scores, all_labels, all_types
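The Phase 2 steps in `trainer.py` combine a hinge margin on `pos_score - neg_score` with an InfoNCE-style term computed over the two logits of each pair. The sketch below reproduces those two terms for a single pair in plain Python, without PyTorch, to make the arithmetic easy to check. The function names `logit`, `pair_losses`, and `softplus` are illustrative and do not appear in the repository. With only one negative per positive, the InfoNCE term reduces to `softplus(neg_logit - pos_logit)`, which the last lines verify numerically.

```python
import math


def logit(q, eps=1e-6):
    """Inverse sigmoid, clamped the same way as the trainer's log-ratio."""
    num = min(max(q, eps), 1.0 - eps)
    den = max(1.0 - q, eps)
    return math.log(num / den)


def pair_losses(q_pos, q_neg, margin=0.2):
    """Return (hinge margin loss, two-way InfoNCE loss) for one (pos, neg) pair."""
    hinge = max(0.0, margin - (q_pos - q_neg))
    z_pos, z_neg = logit(q_pos), logit(q_neg)
    # Numerically stable log-sum-exp over the two logits.
    m = max(z_pos, z_neg)
    log_denom = m + math.log(math.exp(z_pos - m) + math.exp(z_neg - m))
    infonce = -(z_pos - log_denom)
    return hinge, infonce


def softplus(x):
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


hinge, infonce = pair_losses(0.9, 0.3)
# With a single negative, InfoNCE equals softplus(neg_logit - pos_logit).
assert abs(infonce - softplus(logit(0.3) - logit(0.9))) < 1e-9
print(f"hinge={hinge:.4f} infonce={infonce:.4f}")
```

A pair that is already separated by more than the margin contributes no hinge loss but still has a small positive InfoNCE loss, so the InfoNCE term keeps pushing the logits apart after the margin is satisfied.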
code/utils/__init__.py ADDED
File without changes
code/utils/anm.py ADDED
@@ -0,0 +1,208 @@
+"""
+Anisotropic Network Model (ANM) for conformational path interpolation.
+
+From-scratch implementation using scipy eigendecomposition.
+Projects the apo→holo displacement onto low-frequency normal modes
+to create physically motivated interpolation paths.
+"""
+
+import numpy as np
+from scipy.linalg import eigh
+
+
+def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
+    """
+    Build elastic network Hessian and compute normal modes via eigendecomposition.
+
+    Args:
+        ca_coords: [N, 3] CA atom coordinates
+        cutoff: distance cutoff for spring connections (Angstroms)
+        n_modes: number of non-trivial modes to return
+
+    Returns:
+        eigenvalues: [n_modes] array of eigenvalues (force constants)
+        eigenvectors: [n_modes, N, 3] mode displacement vectors
+    """
+    N = len(ca_coords)
+    if N < 4:
+        return np.zeros(n_modes), np.zeros((n_modes, N, 3))
+
+    # Build 3N x 3N Hessian with uniform spring constant (gamma=1)
+    H = np.zeros((3 * N, 3 * N), dtype=np.float64)
+
+    for i in range(N):
+        for j in range(i + 1, N):
+            diff = ca_coords[j] - ca_coords[i]
+            dist = np.linalg.norm(diff)
+            if dist > cutoff or dist < 1e-6:
+                continue
+
+            # Outer product of unit displacement vector
+            unit = diff / dist
+            block = np.outer(unit, unit)  # [3, 3]
+
+            # Off-diagonal: H[i,j] = -gamma * (r_ij ⊗ r_ij) / |r_ij|^2
+            # With uniform gamma=1 and unit vectors, this simplifies to:
+            ii, jj = 3 * i, 3 * j
+            H[ii:ii+3, jj:jj+3] = -block
+            H[jj:jj+3, ii:ii+3] = -block
+
+            # Diagonal: accumulate
+            H[ii:ii+3, ii:ii+3] += block
+            H[jj:jj+3, jj:jj+3] += block
+
+    # Eigendecompose; the first 6 modes are trivial (3 translation + 3 rotation)
+    n_total = min(6 + n_modes, 3 * N)
+    eigenvalues, eigvecs = eigh(H, subset_by_index=[0, n_total - 1])
+
+    # Skip the 6 trivial zero-frequency modes
+    start = min(6, len(eigenvalues) - 1)
+    n_available = len(eigenvalues) - start
+    n_return = min(n_modes, n_available)
+
+    evals = eigenvalues[start:start + n_return]
+    evecs = eigvecs[:, start:start + n_return]  # [3N, n_return]
+
+    # Reshape eigenvectors to [n_modes, N, 3]
+    mode_vectors = np.zeros((n_return, N, 3))
+    for k in range(n_return):
+        mode_vectors[k] = evecs[:, k].reshape(N, 3)
+
+    # Pad if fewer modes available than requested
+    if n_return < n_modes:
+        pad_evals = np.zeros(n_modes)
+        pad_evals[:n_return] = evals
+        pad_modes = np.zeros((n_modes, N, 3))
+        pad_modes[:n_return] = mode_vectors
+        return pad_evals, pad_modes
+
+    return evals, mode_vectors
+
+
+def _kabsch_align(mobile_ca, ref_ca):
+    """Kabsch alignment of mobile onto ref (CA atoms only)."""
+    t_mobile = mobile_ca.mean(axis=0)
+    t_ref = ref_ca.mean(axis=0)
+
+    m = mobile_ca - t_mobile
+    r = ref_ca - t_ref
+
+    H = m.T @ r
+    U, S, Vt = np.linalg.svd(H)
+    d = np.linalg.det(Vt.T @ U.T)
+    sign = np.array([1.0, 1.0, np.sign(d)])
+    R = Vt.T @ np.diag(sign) @ U.T
+
+    return R, t_mobile, t_ref
+
+
+def _reconstruct_oxygen(coords):
+    """Approximate O from CA and C: place O 1.24 Å beyond C along the CA→C direction.
+
+    This is a collinear approximation rather than true carbonyl geometry.
+    The input array is modified in place and also returned.
+    """
+    C_pos = coords[:, 2, :]
+    CA_pos = coords[:, 1, :]
+    C_CA = C_pos - CA_pos
+    C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
+    C_CA_norm = np.maximum(C_CA_norm, 1e-8)
+    O_pos = C_pos + (C_CA / C_CA_norm) * 1.24
+    coords[:, 3, :] = O_pos
+    return coords
+
+
111
+ def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
112
+ n_frames=5, n_modes=10, cutoff=15.0):
113
+ """
114
+ Interpolate backbone along dominant ANM modes from X0 toward X1.
115
+
116
+ Low-frequency modes capture global domain motions (e.g., CaM hinge bending),
117
+ creating physically informed paths where large-scale motions precede local
118
+ adjustments.
119
+
120
+ Args:
121
+ coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
122
+ coords_x1: [N1, 4, 3] backbone coords for holo state
123
+ mask_x0: [N0] bool
124
+ mask_x1: [N1] bool
125
+ n_frames: number of intermediate frames (excluding endpoints)
126
+ n_modes: number of ANM modes to use for projection
127
+ cutoff: ANM spring cutoff in Angstroms
128
+
129
+ Returns:
130
+ path_frames: list of (coords_tau, mask_tau, tau) tuples
131
+ Same interface as interpolate_backbone_path
132
+ """
133
+ n_common = min(len(coords_x0), len(coords_x1))
134
+ c0 = coords_x0[:n_common].copy()
135
+ c1 = coords_x1[:n_common].copy()
136
+ m0 = mask_x0[:n_common]
137
+ m1 = mask_x1[:n_common]
138
+
139
+ common_mask = m0 & m1
140
+ if common_mask.sum() < 5:
141
+ return []
142
+
143
+ # Kabsch-align X0 onto X1 using valid CA atoms
144
+ ca0 = c0[common_mask, 1, :]
145
+ ca1 = c1[common_mask, 1, :]
146
+ R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
147
+
148
+ # Apply alignment to all X0 backbone atoms
149
+ flat0 = c0.reshape(-1, 3)
150
+ aligned0 = (flat0 - t_mobile) @ R.T + t_ref
151
+ c0_aligned = aligned0.reshape(n_common, 4, 3)
152
+
153
+ # Compute apo→holo displacement (CA atoms, valid residues only)
154
+ ca0_aligned = c0_aligned[common_mask, 1, :] # [N_valid, 3]
155
+ ca1_valid = c1[common_mask, 1, :]
156
+
157
+ displacement = ca1_valid - ca0_aligned # [N_valid, 3]
158
+
159
+ # Compute ANM modes of the aligned apo structure
160
+ eigenvalues, mode_vectors = compute_anm_modes(
161
+ ca0_aligned, cutoff=cutoff, n_modes=n_modes
162
+ ) # mode_vectors: [n_modes, N_valid, 3]
163
+
164
+ # Project displacement onto each mode
165
+ # d_k = sum_i mode_k[i] . displacement[i]
166
+ projections = np.zeros(n_modes)
167
+ for k in range(n_modes):
168
+ projections[k] = np.sum(mode_vectors[k] * displacement)
169
+
170
+ # Reconstruct mode-projected displacement: d_mode = sum_k d_k * mode_k
171
+ mode_displacement = np.zeros_like(displacement) # [N_valid, 3]
172
+ for k in range(n_modes):
173
+ mode_displacement += projections[k] * mode_vectors[k]
174
+
175
+ # Residual displacement not captured by modes
176
+ residual = displacement - mode_displacement
177
+
178
+ # Generate intermediate frames
179
+ taus = np.linspace(0, 1, n_frames + 2)[1:-1]
180
+ path_frames = []
181
+
182
+ for tau in taus:
183
+ # Apply mode-projected + residual displacement at each tau
184
+ # Both terms use the same tau, so their sum equals linear CA interpolation
185
+ ca_interp = ca0_aligned + tau * mode_displacement + tau * residual
186
+
187
+ # Build full backbone by interpolating all 4 atom types
188
+ coords_tau = (1.0 - tau) * c0_aligned + tau * c1
189
+ # Override CA positions with ANM-interpolated values
190
+ coords_tau[common_mask, 1, :] = ca_interp
191
+
192
+ # Adjust N, C positions relative to CA shift
193
+ # The N/CA/C triangle is preserved by blending the ANM CA shift
194
+ # with the linear interpolation of N and C
195
+ ca_shift = ca_interp - ((1.0 - tau) * ca0_aligned + tau * ca1_valid)
196
+ coords_tau[common_mask, 0, :] += ca_shift # N atoms
197
+ coords_tau[common_mask, 2, :] += ca_shift # C atoms
198
+
199
+ # Reconstruct O from N, CA, C
200
+ coords_tau = _reconstruct_oxygen(coords_tau)
201
+
202
+ path_frames.append((
203
+ coords_tau.astype(np.float32),
204
+ common_mask.copy(),
205
+ float(tau),
206
+ ))
207
+
208
+ return path_frames
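
The projection step in `anm_backbone_path` can be looked at in isolation. The sketch below does not call the repository's `compute_anm_modes`; it assumes orthonormal mode vectors, built here from a random QR factorisation purely as a stand-in. Under that assumption it shows that the mode-projected part and the residual always sum back to the full displacement, so scaling both by the same `tau` gives the same CA positions as plain linear interpolation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_res, n_modes = 12, 4

# Hypothetical stand-in for ANM modes: orthonormal columns from a QR factorisation.
q, _ = np.linalg.qr(rng.standard_normal((n_res * 3, n_modes)))
mode_vectors = q.T.reshape(n_modes, n_res, 3)      # [n_modes, N, 3]

displacement = rng.standard_normal((n_res, 3))     # toy apo -> holo CA displacement

# d_k = <mode_k, displacement>, then rebuild the part that lies in the mode subspace.
projections = np.einsum('kij,ij->k', mode_vectors, displacement)
mode_displacement = np.einsum('k,kij->ij', projections, mode_vectors)
residual = displacement - mode_displacement

tau = 0.3
same_tau = tau * mode_displacement + tau * residual

# Fraction of the squared displacement captured by the chosen modes.
captured = np.sum(mode_displacement ** 2) / np.sum(displacement ** 2)
```

A path in which the mode component leads the residual would need different schedules for the two terms; which schedule to use is a design choice the source does not specify.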
code/utils/path_utils.py ADDED
@@ -0,0 +1,448 @@
1
+ """
2
+ Transition-path interpolation utilities for conformational induction.
3
+
4
+ Provides:
5
+ - Kabsch-aligned backbone interpolation between two conformational states
6
+ - Gaussian Schrödinger Bridge (DSB) stochastic interpolation
7
+ - Precomputed frame loading (for AlphaFlow / AFsample2)
8
+ - Unified dispatcher: generate_path_frames()
9
+ - Per-residue displacement computation (for allosteric hinge weighting)
10
+ - Monotonically increasing path weight generation
11
+
12
+ Used by the path-aware training, guidance, and refinement modules.
13
+ """
14
+
15
+ import os
16
+ import logging
17
+ import numpy as np
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def _kabsch_align(mobile_ca, ref_ca):
23
+ """
24
+ Kabsch alignment of mobile onto ref (CA atoms only).
25
+
26
+ Args:
27
+ mobile_ca: [N, 3] array
28
+ ref_ca: [N, 3] array
29
+
30
+ Returns:
31
+ R: [3, 3] rotation matrix
32
+ t_mobile: [3] mobile centroid
33
+ t_ref: [3] ref centroid
34
+ Such that: aligned = (mobile - t_mobile) @ R.T + t_ref
35
+ """
36
+ t_mobile = mobile_ca.mean(axis=0)
37
+ t_ref = ref_ca.mean(axis=0)
38
+
39
+ m = mobile_ca - t_mobile
40
+ r = ref_ca - t_ref
41
+
42
+ H = m.T @ r
43
+ U, S, Vt = np.linalg.svd(H)
44
+ d = np.linalg.det(Vt.T @ U.T)
45
+ sign = np.array([1.0, 1.0, np.sign(d)])
46
+ R = Vt.T @ np.diag(sign) @ U.T
47
+
48
+ return R, t_mobile, t_ref
49
+
50
+
51
+ def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
52
+ """
53
+ Generate intermediate backbone conformations along the X0 -> X1 path.
54
+
55
+ 1. Find common valid residues between X0 and X1
56
+ 2. Kabsch-align X0 onto X1 using CA atoms
57
+ 3. Linearly interpolate backbone coords at n_frames equally-spaced tau values
58
+ 4. Rebuild O 1.24 Angstroms from C along the CA->C direction (approximate geometry)
59
+
60
+ Args:
61
+ coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
62
+ coords_x1: [N1, 4, 3] backbone coords for state 1
63
+ mask_x0: [N0] bool
64
+ mask_x1: [N1] bool
65
+ n_frames: number of intermediate frames (excluding endpoints)
66
+
67
+ Returns:
68
+ path_frames: list of (coords_tau, mask_tau, tau) tuples
69
+ coords_tau: [N_common, 4, 3] interpolated backbone coords
70
+ mask_tau: [N_common] bool
71
+ tau: float in (0, 1) exclusive
72
+ """
73
+ # Use common length
74
+ n_common = min(len(coords_x0), len(coords_x1))
75
+ c0 = coords_x0[:n_common].copy()
76
+ c1 = coords_x1[:n_common].copy()
77
+ m0 = mask_x0[:n_common]
78
+ m1 = mask_x1[:n_common]
79
+
80
+ # Valid in both states
81
+ common_mask = m0 & m1
82
+ if common_mask.sum() < 5:
83
+ return []
84
+
85
+ # Kabsch-align X0 onto X1 using valid CA atoms
86
+ ca0 = c0[common_mask, 1, :] # CA atoms
87
+ ca1 = c1[common_mask, 1, :]
88
+
89
+ R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
90
+
91
+ # Apply alignment to all X0 backbone atoms
92
+ n_res = n_common
93
+ flat0 = c0.reshape(-1, 3)
94
+ aligned0 = (flat0 - t_mobile) @ R.T + t_ref
95
+ c0_aligned = aligned0.reshape(n_res, 4, 3)
96
+
97
+ # Generate intermediate frames
98
+ taus = np.linspace(0, 1, n_frames + 2)[1:-1] # exclude endpoints
99
+ path_frames = []
100
+
101
+ for tau in taus:
102
+ # Linear interpolation: X_tau = (1 - tau) * X0_aligned + tau * X1
103
+ coords_tau = (1.0 - tau) * c0_aligned + tau * c1
104
+
105
+ # Reconstruct O from N, CA, C with ideal C=O bond geometry
106
+ C_pos = coords_tau[:, 2, :] # C atoms
107
+ CA_pos = coords_tau[:, 1, :] # CA atoms
108
+ C_CA = C_pos - CA_pos
109
+ C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
110
+ C_CA_norm = np.maximum(C_CA_norm, 1e-8)
111
+ O_pos = C_pos + (C_CA / C_CA_norm) * 1.24 # ideal C=O bond length
112
+ coords_tau[:, 3, :] = O_pos
113
+
114
+ path_frames.append((
115
+ coords_tau.astype(np.float32),
116
+ common_mask.copy(),
117
+ float(tau),
118
+ ))
119
+
120
+ return path_frames
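
The align-then-interpolate recipe above can be checked on toy data. This is a minimal sketch, not the repository function: `kabsch` is a hypothetical helper that mirrors `_kabsch_align`, and the coordinates are random CA-only points rather than real backbones. Because the "apo" state is a rigidly moved copy of the "holo" state, every interpolated frame should coincide with the holo state after alignment.

```python
import numpy as np

def kabsch(mobile, ref):
    """Return R, t_mobile, t_ref so that (mobile - t_mobile) @ R.T + t_ref ~ ref."""
    t_mobile, t_ref = mobile.mean(axis=0), ref.mean(axis=0)
    H = (mobile - t_mobile).T @ (ref - t_ref)
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))         # guard against reflections
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return R, t_mobile, t_ref

rng = np.random.default_rng(1)
x1 = rng.standard_normal((20, 3))                  # toy "holo" CA trace

# Toy "apo" state: the same points rotated about z and translated.
theta = 0.7
Rz = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
x0 = x1 @ Rz.T + np.array([5.0, -2.0, 1.0])

R, t0, t1 = kabsch(x0, x1)
x0_aligned = (x0 - t0) @ R.T + t1

n_frames = 3
taus = np.linspace(0, 1, n_frames + 2)[1:-1]       # endpoints excluded
frames = [(1.0 - t) * x0_aligned + t * x1 for t in taus]
```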
121
+
122
+
123
+ def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
124
+ """
125
+ Per-residue CA displacement between X0 and X1 after Kabsch alignment.
126
+
127
+ Args:
128
+ coords_x0: [N0, 4, 3] backbone coords for state 0
129
+ coords_x1: [N1, 4, 3] backbone coords for state 1
130
+ mask_x0: [N0] bool
131
+ mask_x1: [N1] bool
132
+
133
+ Returns:
134
+ displacements: [N_common] array of per-residue CA displacement in Angstroms
135
+ common_mask: [N_common] bool — which residues are valid
136
+ """
137
+ n_common = min(len(coords_x0), len(coords_x1))
138
+ c0 = coords_x0[:n_common]
139
+ c1 = coords_x1[:n_common]
140
+ m0 = mask_x0[:n_common]
141
+ m1 = mask_x1[:n_common]
142
+ common_mask = m0 & m1
143
+
144
+ if common_mask.sum() < 5:
145
+ return np.zeros(n_common), common_mask
146
+
147
+ ca0 = c0[common_mask, 1, :]
148
+ ca1 = c1[common_mask, 1, :]
149
+
150
+ R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
151
+
152
+ # Align all CA of X0
153
+ all_ca0 = c0[:, 1, :]
154
+ aligned_ca0 = (all_ca0 - t_mobile) @ R.T + t_ref
155
+
156
+ # Per-residue displacement
157
+ all_ca1 = c1[:, 1, :]
158
+ displacements = np.linalg.norm(aligned_ca0 - all_ca1, axis=-1)
159
+
160
+ # Zero out invalid residues
161
+ displacements[~common_mask] = 0.0
162
+
163
+ return displacements.astype(np.float32), common_mask
164
+
165
+
166
+ def generate_path_weights(n_frames, mode='linear'):
167
+ """
168
+ Generate monotonically increasing weights for path frames.
169
+
170
+ The weights increase toward tau=1 (the goal state), so that
171
+ intermediate conformations closer to X1 are weighted more heavily.
172
+
173
+ Args:
174
+ n_frames: number of intermediate frames
175
+ mode: weight schedule
176
+ 'linear': w_tau = tau
177
+ 'quadratic': w_tau = tau^2
178
+ 'exponential': w_tau = (exp(tau) - 1) / (e - 1)
179
+ 'uniform': w_tau = 1/n_frames (equal weighting)
180
+
181
+ Returns:
182
+ weights: [n_frames] numpy array, normalized to sum to 1
183
+ """
184
+ if n_frames == 0:
185
+ return np.array([], dtype=np.float32)
186
+
187
+ taus = np.linspace(0, 1, n_frames + 2)[1:-1] # same as interpolation
188
+
189
+ if mode == 'linear':
190
+ weights = taus.copy()
191
+ elif mode == 'quadratic':
192
+ weights = taus ** 2
193
+ elif mode == 'exponential':
194
+ weights = (np.exp(taus) - 1.0) / (np.e - 1.0)
195
+ elif mode == 'uniform':
196
+ weights = np.ones(n_frames, dtype=np.float32)
197
+ else:
198
+ raise ValueError(f"Unknown weight mode: {mode}")
199
+
200
+ # Normalize to sum to 1
201
+ total = weights.sum()
202
+ if total > 0:
203
+ weights = weights / total
204
+
205
+ return weights.astype(np.float32)
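
As a worked example of the schedules listed in the docstring, the sketch below recomputes the normalised weights for `n_frames = 3` without importing the module. The dictionary layout is only an illustration; the formulas are the ones given above.

```python
import numpy as np

n_frames = 3
taus = np.linspace(0, 1, n_frames + 2)[1:-1]       # [0.25, 0.5, 0.75]

schedules = {
    'linear': taus,
    'quadratic': taus ** 2,
    'exponential': (np.exp(taus) - 1.0) / (np.e - 1.0),
    'uniform': np.ones(n_frames),
}
# Normalise each schedule so its weights sum to 1.
weights = {name: w / w.sum() for name, w in schedules.items()}

for name, w in weights.items():
    print(f"{name:12s}", np.round(w, 3))
```

For the linear schedule this gives weights of 1/6, 1/3 and 1/2, so the frame closest to X1 carries half of the total weight.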
206
+
207
+
208
+ # ---------------------------------------------------------------------------
209
+ # Gaussian Schrödinger Bridge (AlignDSB) interpolation
210
+ # ---------------------------------------------------------------------------
211
+
212
+ def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
213
+ n_frames=5, sigma=0.5, n_samples=20, seed=42):
214
+ """
215
+ Gaussian Schrödinger Bridge with t*(1-t) variance schedule.
216
+
217
+ Analytic formula (no neural network):
218
+ X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z
219
+
220
+ Variance peaks at t=0.5 (maximum uncertainty mid-transition) and vanishes
221
+ at endpoints. sigma controls noise amplitude in Angstroms.
222
+
223
+ For each tau, samples n_samples noisy interpolations and selects the
224
+ median (by RMSD to the mean) for robustness.
225
+
226
+ Args:
227
+ coords_x0: [N0, 4, 3] backbone coords for state 0
228
+ coords_x1: [N1, 4, 3] backbone coords for state 1
229
+ mask_x0: [N0] bool
230
+ mask_x1: [N1] bool
231
+ n_frames: number of intermediate frames
232
+ sigma: noise amplitude (Angstroms)
233
+ n_samples: number of samples per frame for median selection
234
+ seed: random seed
235
+
236
+ Returns:
237
+ path_frames: list of (coords_tau, mask_tau, tau) tuples
238
+ """
239
+ rng = np.random.RandomState(seed)
240
+
241
+ n_common = min(len(coords_x0), len(coords_x1))
242
+ c0 = coords_x0[:n_common].copy()
243
+ c1 = coords_x1[:n_common].copy()
244
+ m0 = mask_x0[:n_common]
245
+ m1 = mask_x1[:n_common]
246
+
247
+ common_mask = m0 & m1
248
+ if common_mask.sum() < 5:
249
+ return []
250
+
251
+ # Kabsch-align X0 onto X1
252
+ ca0 = c0[common_mask, 1, :]
253
+ ca1 = c1[common_mask, 1, :]
254
+ R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
255
+
256
+ flat0 = c0.reshape(-1, 3)
257
+ aligned0 = (flat0 - t_mobile) @ R.T + t_ref
258
+ c0_aligned = aligned0.reshape(n_common, 4, 3)
259
+
260
+ taus = np.linspace(0, 1, n_frames + 2)[1:-1]
261
+ path_frames = []
262
+
263
+ for tau in taus:
264
+ noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
265
+
266
+ # Generate n_samples noisy interpolations
267
+ samples = []
268
+ for _ in range(n_samples):
269
+ Z = rng.randn(n_common, 4, 3).astype(np.float64)
270
+ X_t = (1.0 - tau) * c0_aligned + tau * c1 + noise_scale * Z
271
+ samples.append(X_t)
272
+
273
+ samples = np.array(samples) # [n_samples, N, 4, 3]
274
+ mean_sample = samples.mean(axis=0) # [N, 4, 3]
275
+
276
+ # Select median sample by RMSD to mean (CA atoms)
277
+ rmsds = []
278
+ for s in samples:
279
+ diff = s[common_mask, 1, :] - mean_sample[common_mask, 1, :]
280
+ rmsd = np.sqrt((diff ** 2).sum() / common_mask.sum())
281
+ rmsds.append(rmsd)
282
+ median_idx = np.argsort(rmsds)[len(rmsds) // 2]
283
+ coords_tau = samples[median_idx]
284
+
285
+ # Reconstruct O from N, CA, C
286
+ C_pos = coords_tau[:, 2, :]
287
+ CA_pos = coords_tau[:, 1, :]
288
+ C_CA = C_pos - CA_pos
289
+ C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
290
+ C_CA_norm = np.maximum(C_CA_norm, 1e-8)
291
+ coords_tau[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24
292
+
293
+ path_frames.append((
294
+ coords_tau.astype(np.float32),
295
+ common_mask.copy(),
296
+ float(tau),
297
+ ))
298
+
299
+ return path_frames
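
The `t * (1 - t)` variance schedule is easy to inspect numerically. The sketch below uses toy start and end states and NumPy's `default_rng` instead of the `RandomState` used above, so the random draws will not match the function's; it only illustrates that the noise scale is symmetric in `tau` and peaks at `0.5 * sigma` mid-path.

```python
import numpy as np

sigma = 0.5                                        # noise amplitude in Angstroms
n_frames = 5
taus = np.linspace(0, 1, n_frames + 2)[1:-1]

# Per-coordinate standard deviation of the added noise at each tau.
noise_scale = np.sqrt(taus * (1.0 - taus)) * sigma

rng = np.random.default_rng(42)
x0 = np.zeros((50, 3))                             # toy aligned start state
x1 = np.ones((50, 3))                              # toy end state

# One noisy bridge sample per tau: linear mean plus scaled Gaussian noise.
samples = [(1.0 - t) * x0 + t * x1 + s * rng.standard_normal(x0.shape)
           for t, s in zip(taus, noise_scale)]
```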
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # Precomputed frame loading (for AlphaFlow / AFsample2)
304
+ # ---------------------------------------------------------------------------
305
+
306
+ def load_precomputed_frames(target, method, precomputed_dir,
307
+ coords_x0, coords_x1, mask_x0, mask_x1,
308
+ n_frames=5):
309
+ """
310
+ Load pre-generated frames from .npz and Kabsch-align to this complex's
311
+ receptor coordinate frame.
312
+
313
+ Expected file: {precomputed_dir}/{target}/{method}/frames.npz
314
+ with keys: 'frames' [n_frames, N_ref, 4, 3], 'taus' [n_frames],
315
+ 'mask' [N_ref] bool
316
+
317
+ Args:
318
+ target: target name (e.g. 'cam')
319
+ method: method name ('alphaflow' or 'afsample2')
320
+ precomputed_dir: root directory for precomputed frames
321
+ coords_x0, coords_x1: apo/holo backbone coords for alignment
322
+ mask_x0, mask_x1: residue masks
323
+ n_frames: number of frames to return
324
+
325
+ Returns:
326
+ path_frames: list of (coords_tau, mask_tau, tau) tuples
327
+ """
328
+ npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')
329
+ if not os.path.exists(npz_path):
330
+ logger.warning(f"Precomputed frames not found: {npz_path}, "
331
+ f"falling back to linear interpolation")
332
+ return interpolate_backbone_path(coords_x0, coords_x1,
333
+ mask_x0, mask_x1, n_frames)
334
+
335
+ data = np.load(npz_path)
336
+ pre_frames = data['frames'] # [K, N_ref, 4, 3]
337
+ pre_taus = data['taus'] # [K]
338
+ pre_mask = data['mask'] # [N_ref]
339
+
340
+ n_common = min(len(coords_x0), len(coords_x1), len(pre_mask))
341
+ m0 = mask_x0[:n_common]
342
+ m1 = mask_x1[:n_common]
343
+ pm = pre_mask[:n_common]
344
+ common_mask = m0 & m1 & pm
345
+
346
+ if common_mask.sum() < 5:
347
+ logger.warning(f"Too few common residues for {target}/{method}, "
348
+ f"falling back to linear")
349
+ return interpolate_backbone_path(coords_x0, coords_x1,
350
+ mask_x0, mask_x1, n_frames)
351
+
352
+ # Align precomputed frames to the holo receptor (X1) coordinate frame
353
+ # The precomputed frames were generated from the reference apo sequence
354
+ # and may be in a different coordinate frame
355
+ ref_ca = coords_x1[:n_common][common_mask, 1, :] # holo CA as reference
356
+
357
+ path_frames = []
358
+ K = min(len(pre_frames), n_frames)
359
+
360
+ # Select n_frames evenly spaced from available frames
361
+ if len(pre_frames) > n_frames:
362
+ indices = np.linspace(0, len(pre_frames) - 1, n_frames).astype(int)
363
+ else:
364
+ indices = np.arange(K)
365
+
366
+ for idx in indices:
367
+ frame = pre_frames[idx, :n_common].copy() # [N_common, 4, 3]
368
+ tau = float(pre_taus[idx])
369
+
370
+ # Kabsch-align frame CA to holo CA
371
+ frame_ca = frame[common_mask, 1, :]
372
+ R, t_frame, t_ref = _kabsch_align(frame_ca, ref_ca)
373
+
374
+ flat_frame = frame.reshape(-1, 3)
375
+ aligned = (flat_frame - t_frame) @ R.T + t_ref
376
+ frame_aligned = aligned.reshape(n_common, 4, 3)
377
+
378
+ # Reconstruct O
379
+ C_pos = frame_aligned[:, 2, :]
380
+ CA_pos = frame_aligned[:, 1, :]
381
+ C_CA = C_pos - CA_pos
382
+ C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
383
+ C_CA_norm = np.maximum(C_CA_norm, 1e-8)
384
+ frame_aligned[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24
385
+
386
+ path_frames.append((
387
+ frame_aligned.astype(np.float32),
388
+ common_mask.copy(),
389
+ tau,
390
+ ))
391
+
392
+ return path_frames
393
+
394
+
395
+ # ---------------------------------------------------------------------------
396
+ # Unified dispatcher
397
+ # ---------------------------------------------------------------------------
398
+
399
+ def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
400
+ method='linear', n_frames=5,
401
+ precomputed_dir=None, target=None, **kwargs):
402
+ """
403
+ Dispatch to method-specific frame generation.
404
+
405
+ Args:
406
+ coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
407
+ mask_x0, mask_x1: [N] bool masks
408
+ method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
409
+ n_frames: number of intermediate frames
410
+ precomputed_dir: directory for precomputed frames (alphaflow/afsample2)
411
+ target: target name (needed for precomputed methods)
412
+ **kwargs: method-specific parameters (sigma, n_modes, etc.)
413
+
414
+ Returns:
415
+ path_frames: list of (coords_tau, mask_tau, tau) tuples
416
+ """
417
+ if method == 'linear':
418
+ return interpolate_backbone_path(
419
+ coords_x0, coords_x1, mask_x0, mask_x1, n_frames)
420
+
421
+ elif method in ('alphaflow', 'afsample2'):
422
+ if precomputed_dir is None:
423
+ raise ValueError(f"precomputed_dir required for method '{method}'")
424
+ if target is None:
425
+ raise ValueError(f"target name required for method '{method}'")
426
+ return load_precomputed_frames(
427
+ target, method, precomputed_dir,
428
+ coords_x0, coords_x1, mask_x0, mask_x1, n_frames)
429
+
430
+ elif method == 'dsb':
431
+ return dsb_backbone_path(
432
+ coords_x0, coords_x1, mask_x0, mask_x1,
433
+ n_frames=n_frames,
434
+ sigma=kwargs.get('sigma', 0.5),
435
+ n_samples=kwargs.get('n_samples', 20),
436
+ seed=kwargs.get('seed', 42))
437
+
438
+ elif method == 'anm':
439
+ from utils.anm import anm_backbone_path
440
+ return anm_backbone_path(
441
+ coords_x0, coords_x1, mask_x0, mask_x1,
442
+ n_frames=n_frames,
443
+ n_modes=kwargs.get('n_modes', 10),
444
+ cutoff=kwargs.get('cutoff', 15.0))
445
+
446
+ else:
447
+ raise ValueError(f"Unknown path method: '{method}'. "
448
+ f"Choose from: linear, alphaflow, afsample2, dsb, anm")
code/utils/pdb_utils.py ADDED
@@ -0,0 +1,472 @@
1
+ """
2
+ PDB parsing utilities for Allo-Designer.
3
+ Extracts backbone geometry, computes local frames, and identifies interface residues.
4
+ """
5
+
6
+ import numpy as np
7
+ from Bio import PDB
8
+ from Bio.PDB import PDBParser, MMCIFParser, PDBIO
9
+ from Bio.PDB.Polypeptide import is_aa
10
+ import warnings
11
+ warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)
12
+
13
+ AA3_TO_IDX = {
14
+ 'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
15
+ 'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
16
+ 'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
17
+ 'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
18
+ 'UNK': 20,
19
+ }
20
+ NUM_AA = 21 # 20 standard + UNK
21
+
22
+
23
+ def load_structure(pdb_path: str, model_id: int = 0):
24
+ """Load a PDB/CIF file and return the model at index model_id (default: the first model)."""
25
+ if pdb_path.endswith('.cif') or pdb_path.endswith('.mmcif'):
26
+ parser = MMCIFParser(QUIET=True)
27
+ else:
28
+ parser = PDBParser(QUIET=True)
29
+ struct = parser.get_structure("protein", pdb_path)
30
+ return list(struct.get_models())[model_id]
31
+
32
+
33
+ def get_residues(chain, only_standard: bool = True):
34
+ """Return a list of standard amino acid residues from a chain."""
35
+ residues = []
36
+ for res in chain.get_residues():
37
+ if only_standard and not is_aa(res, standard=True):
38
+ continue
39
+ if res.get_id()[0] != ' ': # skip HETATM
40
+ continue
41
+ residues.append(res)
42
+ return residues
43
+
44
+
45
+ def get_backbone_coords(residues):
46
+ """
47
+ Extract backbone atom coordinates (N, CA, C, O) for each residue.
48
+ Returns: coords [N_res, 4, 3], mask [N_res] (True = all backbone atoms present)
49
+ """
50
+ N = len(residues)
51
+ coords = np.zeros((N, 4, 3), dtype=np.float32)
52
+ mask = np.zeros(N, dtype=bool)
53
+
54
+ for i, res in enumerate(residues):
55
+ try:
56
+ coords[i, 0] = res['N'].get_vector().get_array()
57
+ coords[i, 1] = res['CA'].get_vector().get_array()
58
+ coords[i, 2] = res['C'].get_vector().get_array()
59
+ if 'O' in res:
60
+ coords[i, 3] = res['O'].get_vector().get_array()
61
+ else:
62
+ # O missing: use the C position as a placeholder
63
+ coords[i, 3] = coords[i, 2]
64
+ mask[i] = True
65
+ except KeyError:
66
+ pass
67
+ return coords, mask
68
+
69
+
70
+ def get_aa_indices(residues):
71
+ """Return integer amino acid indices for each residue."""
72
+ return np.array([
73
+ AA3_TO_IDX.get(res.get_resname(), AA3_TO_IDX['UNK'])
74
+ for res in residues
75
+ ], dtype=np.int64)
76
+
77
+
78
+ def compute_backbone_frames(coords, mask):
79
+ """
80
+ Compute SE(3)-equivariant backbone frames from N, CA, C atoms.
81
+ Frame: z-axis = CA->C, y-axis = component of CA->N perpendicular to z, x-axis = y x z.
82
+
83
+ Returns:
84
+ origins: [N, 3] = CA positions
85
+ rotations: [N, 3, 3] = rotation matrices (columns are x, y, z axes)
86
+ """
87
+ N_res = coords.shape[0]
88
+ origins = coords[:, 1, :] # CA positions [N, 3]
89
+ rotations = np.zeros((N_res, 3, 3), dtype=np.float32)
90
+
91
+ for i in range(N_res):
92
+ if not mask[i]:
93
+ rotations[i] = np.eye(3)
94
+ continue
95
+ ca = coords[i, 1]
96
+ n = coords[i, 0]
97
+ c = coords[i, 2]
98
+
99
+ # z-axis: CA -> C
100
+ z = c - ca
101
+ z_norm = np.linalg.norm(z)
102
+ if z_norm < 1e-6:
103
+ rotations[i] = np.eye(3)
104
+ continue
105
+ z = z / z_norm
106
+
107
+ # y-axis: CA -> N, orthogonalized
108
+ y = n - ca
109
+ y = y - np.dot(y, z) * z
110
+ y_norm = np.linalg.norm(y)
111
+ if y_norm < 1e-6:
112
+ rotations[i] = np.eye(3)
113
+ continue
114
+ y = y / y_norm
115
+
116
+ # x-axis: y cross z
117
+ x = np.cross(y, z)
118
+
119
+ rotations[i] = np.stack([x, y, z], axis=-1) # columns are axes
120
+
121
+ return origins, rotations
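
The frame construction above is a per-residue Gram-Schmidt step. A vectorised variant is sketched below under one simplifying assumption: degenerate residues are clamped with a small `eps` instead of falling back to the identity matrix as the loop version does. `backbone_frames` is a hypothetical name, and the input is random toy data.

```python
import numpy as np

def backbone_frames(coords, eps=1e-6):
    """Vectorised sketch: coords [N, 4, 3] -> rotations [N, 3, 3] with columns x, y, z."""
    n, ca, c = coords[:, 0], coords[:, 1], coords[:, 2]
    z = c - ca                                           # z-axis: CA -> C
    z = z / np.maximum(np.linalg.norm(z, axis=-1, keepdims=True), eps)
    y = n - ca                                           # y-axis: CA -> N ...
    y = y - np.sum(y * z, axis=-1, keepdims=True) * z    # ... minus its component along z
    y = y / np.maximum(np.linalg.norm(y, axis=-1, keepdims=True), eps)
    x = np.cross(y, z)                                   # x-axis completes a right-handed frame
    return np.stack([x, y, z], axis=-1)

rng = np.random.default_rng(0)
coords = rng.standard_normal((8, 4, 3))
rot = backbone_frames(coords)
```

Each returned matrix should be orthonormal with determinant +1, which is a quick sanity check to run on real structures as well.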
122
+
123
+
124
+ def compute_torsion_angles(coords, mask):
125
+ """
126
+ Compute backbone torsion angles (phi, psi, omega) for each residue.
127
+ Returns sin/cos of each angle. [N, 6]
128
+ """
129
+ N = len(coords)
130
+ angles = np.zeros((N, 6), dtype=np.float32)
131
+
132
+ def dihedral(p0, p1, p2, p3):
133
+ """Dihedral angle between four 3D points (radians)."""
134
+ b1 = p1 - p0
135
+ b2 = p2 - p1
136
+ b3 = p3 - p2
137
+ n1 = np.cross(b1, b2)
138
+ n2 = np.cross(b2, b3)
139
+ n1_norm = np.linalg.norm(n1)
140
+ n2_norm = np.linalg.norm(n2)
141
+ if n1_norm < 1e-6 or n2_norm < 1e-6:
142
+ return 0.0
143
+ n1 = n1 / n1_norm
144
+ n2 = n2 / n2_norm
145
+ m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
146
+ cos_a = np.clip(np.dot(n1, n2), -1, 1)
147
+ sin_a = np.dot(m1, n2)
148
+ return np.arctan2(sin_a, cos_a)
149
+
150
+ for i in range(N):
151
+ if not mask[i]:
152
+ continue
153
+ ca_i = coords[i, 1]
154
+ n_i = coords[i, 0]
155
+ c_i = coords[i, 2]
156
+
157
+ # Phi: C_{i-1} - N_i - CA_i - C_i
158
+ if i > 0 and mask[i - 1]:
159
+ c_prev = coords[i - 1, 2]
160
+ phi = dihedral(c_prev, n_i, ca_i, c_i)
161
+ angles[i, 0] = np.sin(phi)
162
+ angles[i, 1] = np.cos(phi)
163
+
164
+ # Psi: N_i - CA_i - C_i - N_{i+1}
165
+ if i < N - 1 and mask[i + 1]:
166
+ n_next = coords[i + 1, 0]
167
+ psi = dihedral(n_i, ca_i, c_i, n_next)
168
+ angles[i, 2] = np.sin(psi)
169
+ angles[i, 3] = np.cos(psi)
170
+
171
+ # Omega: CA_{i-1} - C_{i-1} - N_i - CA_i
172
+ if i > 0 and mask[i - 1]:
173
+ ca_prev = coords[i - 1, 1]
174
+ c_prev = coords[i - 1, 2]
175
+ omega = dihedral(ca_prev, c_prev, n_i, ca_i)
176
+ angles[i, 4] = np.sin(omega)
177
+ angles[i, 5] = np.cos(omega)
178
+
179
+ return angles
180
+
181
+
182
+ def get_interface_residues(rec_coords, binder_coords, rec_mask, binder_mask, cutoff: float = 8.0):
183
+ """
184
+ Find interface residues: receptor residues within cutoff of any binder Cα, and vice versa.
185
+ Uses CA-CA distances.
186
+
187
+ Returns:
188
+ rec_interface: bool array [N_rec]
189
+ binder_interface: bool array [N_binder]
190
+ """
191
+ rec_ca = rec_coords[:, 1, :] # [N_rec, 3]
192
+ binder_ca = binder_coords[:, 1, :] # [N_binder, 3]
193
+
194
+ # Pairwise CA-CA distances [N_rec, N_binder]
195
+ diff = rec_ca[:, None, :] - binder_ca[None, :, :] # [N_rec, N_binder, 3]
196
+ dist = np.sqrt((diff ** 2).sum(axis=-1)) # [N_rec, N_binder]
197
+
198
+ # Mask out residues without coordinates
199
+ dist[~rec_mask, :] = np.inf
200
+ dist[:, ~binder_mask] = np.inf
201
+
202
+ rec_interface = (dist < cutoff).any(axis=1)
203
+ binder_interface = (dist < cutoff).any(axis=0)
204
+
205
+ return rec_interface, binder_interface
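
A small hand-checkable case makes the masking behaviour concrete. The coordinates below are made up, and the snippet repeats the same CA-CA logic inline instead of importing the function.

```python
import numpy as np

rec_ca = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [30.0, 0.0, 0.0]])
binder_ca = np.array([[0.0, 6.0, 0.0], [50.0, 0.0, 0.0]])
rec_mask = np.array([True, True, True])
binder_mask = np.array([True, False])              # second binder residue lacks coordinates

# Pairwise CA-CA distances [N_rec, N_binder]; masked residues can never be in contact.
dist = np.linalg.norm(rec_ca[:, None, :] - binder_ca[None, :, :], axis=-1)
dist[~rec_mask, :] = np.inf
dist[:, ~binder_mask] = np.inf

rec_interface = (dist < 8.0).any(axis=1)
binder_interface = (dist < 8.0).any(axis=0)
```

The first two receptor residues sit 6.0 and about 7.1 Angstroms from the first binder residue, so they are flagged; the third is about 30 Angstroms away and is not.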
206
+
207
+
208
+ def align_structures(mobile_ca, ref_ca, mobile_coords=None):
209
+ """
210
+ Kabsch alignment: align mobile to ref using CA positions.
211
+ Returns aligned CA coords and optionally full backbone coords.
212
+ """
213
+ assert mobile_ca.shape == ref_ca.shape, "Must have same number of residues"
214
+
215
+ # Center
216
+ mobile_center = mobile_ca.mean(axis=0)
217
+ ref_center = ref_ca.mean(axis=0)
218
+ m = mobile_ca - mobile_center
219
+ r = ref_ca - ref_center
220
+
221
+ # SVD
222
+ H = m.T @ r
223
+ U, S, Vt = np.linalg.svd(H)
224
+ d = np.sign(np.linalg.det(Vt.T @ U.T))
225
+ D = np.diag([1, 1, d])
226
+ R = Vt.T @ D @ U.T # rotation matrix
227
+
228
+ mobile_ca_aligned = (m @ R.T) + ref_center
229
+
230
+ if mobile_coords is not None:
231
+ # Apply same rotation to full backbone
232
+ N_res, N_atoms, _ = mobile_coords.shape
233
+ flat = mobile_coords.reshape(-1, 3) - mobile_center
234
+ aligned_flat = (flat @ R.T) + ref_center
235
+ mobile_coords_aligned = aligned_flat.reshape(N_res, N_atoms, 3)
236
+ return mobile_ca_aligned, R, mobile_coords_aligned
237
+
238
+ return mobile_ca_aligned, R
239
+
240
+
241
+ def compute_ca_rmsd(coords1, coords2, mask=None):
242
+ """Compute CA-RMSD between two sets of backbone coordinates."""
243
+ ca1 = coords1[:, 1, :]
244
+ ca2 = coords2[:, 1, :]
245
+ if mask is not None:
246
+ ca1 = ca1[mask]
247
+ ca2 = ca2[mask]
248
+ diff = ca1 - ca2
249
+ return np.sqrt((diff ** 2).sum(axis=-1).mean())
250
+
251
+
252
+ def compute_fraction_native_contacts(
253
+ native_rec_ca, native_binder_ca,
254
+ model_rec_ca=None, model_binder_ca=None,
255
+ cutoff=8.0,
256
+ # Legacy 2-arg signature support
257
+ mask=None, delta=1.0,
258
+ ):
259
+ """
260
+ Compute fraction of native inter-chain contacts (fNAT).
261
+
262
+ fNAT = |recovered inter-chain contacts| / |native inter-chain contacts|
263
+
264
+ A native contact is a (receptor_i, binder_j) pair with CA-CA distance
265
+ < cutoff in the native complex. A contact is "recovered" if the same
266
+ pair is < cutoff in the model complex.
267
+
268
+ Args:
269
+ native_rec_ca: [N_rec, 3] receptor CA coords in native complex
270
+ native_binder_ca: [N_bind, 3] binder CA coords in native complex
271
+ model_rec_ca: [N_rec, 3] receptor CA in model (default: same as native)
272
+ model_binder_ca: [N_bind, 3] binder CA in model (default: same as native)
273
+ cutoff: contact distance threshold in Angstroms (default 8.0 for CA-CA)
274
+
275
+ Returns:
276
+ fNAT in [0, 1]. Returns 0.0 if no native contacts exist.
277
+ """
278
+ if model_rec_ca is None:
279
+ model_rec_ca = native_rec_ca
280
+ if model_binder_ca is None:
281
+ model_binder_ca = native_binder_ca
282
+
283
+ # Inter-chain distance matrices [N_rec, N_bind]
284
+ native_dist = np.sqrt(
285
+ ((native_rec_ca[:, None, :] - native_binder_ca[None, :, :]) ** 2).sum(-1)
286
+ )
287
+ model_dist = np.sqrt(
288
+ ((model_rec_ca[:, None, :] - model_binder_ca[None, :, :]) ** 2).sum(-1)
289
+ )
290
+
291
+ native_contacts = native_dist < cutoff
292
+ recovered = native_contacts & (model_dist < cutoff)
293
+
294
+ n_native = native_contacts.sum()
295
+ if n_native == 0:
296
+ return 0.0
297
+ return float(recovered.sum()) / float(n_native)
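
To make the fNAT definition concrete, here is a worked example on four made-up CA positions. `fnat` is a compact, hypothetical re-statement of the formula in the docstring and omits the legacy `mask` and `delta` arguments.

```python
import numpy as np

def fnat(native_rec, native_bind, model_rec, model_bind, cutoff=8.0):
    """Fraction of native receptor-binder CA contacts that are kept in the model."""
    def dist(a, b):
        return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    native = dist(native_rec, native_bind) < cutoff
    kept = native & (dist(model_rec, model_bind) < cutoff)
    return float(kept.sum()) / float(native.sum()) if native.any() else 0.0

rec = np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0]])
native_bind = np.array([[0.0, 5.0, 0.0], [4.0, 5.0, 0.0]])
model_bind = native_bind + np.array([0.0, 2.5, 0.0])   # binder pushed 2.5 A away

score_native = fnat(rec, native_bind, rec, native_bind)
score_model = fnat(rec, native_bind, rec, model_bind)
```

All four native pairs are within 8 Angstroms. After the shift the two diagonal pairs move out to 8.5 Angstroms, so two of four contacts survive and the score drops from 1.0 to 0.5.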
298
+
299
+
300
+ def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
301
+ """
302
+ RBF encoding of distances using Gaussian basis functions.
303
+ Returns: [*distances.shape, n_bins]
304
+ """
305
+ centers = np.linspace(d_min, d_max, n_bins)
306
+ sigma = (d_max - d_min) / (n_bins - 1)
307
+ encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
308
+ return encoded.astype(np.float32)
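
The encoding can be reproduced inline to see what a feature vector looks like. The sketch uses the default parameters from the signature above; the distance values are arbitrary examples.

```python
import numpy as np

d_min, d_max, n_bins = 0.0, 20.0, 16
centers = np.linspace(d_min, d_max, n_bins)
sigma = (d_max - d_min) / (n_bins - 1)             # equals the spacing between centres

distances = np.array([[0.0, 4.0], [10.0, 25.0]])   # toy distances in Angstroms
encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
```

A distance that coincides with a centre (0.0 or 4.0 here) activates that bin at 1.0, while a distance beyond `d_max` (25.0) produces only weak activations, so pairs outside the covered range are nearly indistinguishable from each other.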
309
+
310
+
311
+ # Candidate sidechain atoms for chi1 (first atom after CB)
312
+ _CHI1_ATOMS = ['CG', 'CG1', 'OG', 'OG1', 'SG']
313
+ # Candidate sidechain atoms for chi2 (second dihedral: CA-CB-XG-XD)
314
+ _CHI2_ATOMS = ['CD', 'CD1', 'SD', 'OD1', 'ND1', 'CE', 'NE', 'OE1']
315
+
316
+
317
+ def _dihedral_4pts(p0, p1, p2, p3):
318
+ """Compute dihedral angle between four 3D points (radians)."""
319
+ b1 = p1 - p0
320
+ b2 = p2 - p1
321
+ b3 = p3 - p2
322
+ n1 = np.cross(b1, b2)
323
+ n2 = np.cross(b2, b3)
324
+ n1_norm = np.linalg.norm(n1)
325
+ n2_norm = np.linalg.norm(n2)
326
+ if n1_norm < 1e-6 or n2_norm < 1e-6:
327
+ return 0.0
328
+ n1 = n1 / n1_norm
329
+ n2 = n2 / n2_norm
330
+ m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
331
+ return np.arctan2(np.dot(m1, n2), np.dot(n1, n2))
332
+
333
+
334
+ def compute_chi_angles(residues, mask):
335
+ """
336
+ Compute chi1 and chi2 sidechain torsion angles for each residue.
337
+
338
+ Chi1: N - CA - CB - XG (first sidechain dihedral)
339
+ Chi2: CA - CB - XG - XD (second sidechain dihedral)
340
+
341
+ For residues lacking the atoms (Gly, or missing coordinates), returns zeros.
342
+
343
+ Returns:
344
+ chi_feats: [N, 4] (sin_chi1, cos_chi1, sin_chi2, cos_chi2)
345
+ """
346
+ N = len(residues)
347
+ chi_feats = np.zeros((N, 4), dtype=np.float32)
348
+
349
+ for i, res in enumerate(residues):
350
+ if not mask[i]:
351
+ continue
352
+ atoms = {atom.get_name(): atom.get_vector().get_array() for atom in res.get_atoms()
353
+ if atom.get_name() in ('N', 'CA', 'CB') + tuple(_CHI1_ATOMS) + tuple(_CHI2_ATOMS)}
354
+
355
+ n_pos = atoms.get('N')
356
+ ca_pos = atoms.get('CA')
357
+ cb_pos = atoms.get('CB')
358
+
359
+ if n_pos is None or ca_pos is None or cb_pos is None:
360
+ continue
361
+
362
+ # Chi1: N - CA - CB - XG
363
+ xg_pos = None
364
+ for aname in _CHI1_ATOMS:
365
+ if aname in atoms:
366
+ xg_pos = atoms[aname]
367
+ break
368
+
369
+ if xg_pos is not None:
370
+ chi1 = _dihedral_4pts(np.array(n_pos), np.array(ca_pos),
371
+ np.array(cb_pos), np.array(xg_pos))
372
+ chi_feats[i, 0] = np.sin(chi1)
373
+ chi_feats[i, 1] = np.cos(chi1)
374
+
375
+ # Chi2: CA - CB - XG - XD
376
+ xd_pos = None
377
+ for aname in _CHI2_ATOMS:
378
+ if aname in atoms:
379
+ xd_pos = atoms[aname]
380
+ break
381
+
382
+ if xg_pos is not None and xd_pos is not None:
383
+ chi2 = _dihedral_4pts(np.array(ca_pos), np.array(cb_pos),
384
+ np.array(xg_pos), np.array(xd_pos))
385
+ chi_feats[i, 2] = np.sin(chi2)
386
+ chi_feats[i, 3] = np.cos(chi2)
387
+
388
+ return chi_feats
389
+
390
+
+ def get_cb_positions(residues, coords, mask):
+     """
+     Return CB positions for each residue (CA position for Gly or missing CB).
+
+     Returns:
+         cb_pos: [N, 3]
+     """
+     N = len(residues)
+     cb_pos = coords[:, 1, :].copy()  # default to CA
+
+     for i, res in enumerate(residues):
+         if not mask[i]:
+             continue
+         try:
+             cb_pos[i] = res['CB'].get_vector().get_array()
+         except KeyError:
+             pass  # Gly or missing CB: keep CA
+
+     return cb_pos.astype(np.float32)
+
+
+ # Simplified hydrophobicity groups for contact energy
+ _HYDROPHOBIC = {'ALA', 'VAL', 'ILE', 'LEU', 'MET', 'PHE', 'TRP', 'PRO', 'TYR'}
+ _POS_CHARGED = {'ARG', 'LYS', 'HIS'}
+ _NEG_CHARGED = {'ASP', 'GLU'}
+
+
+ def _residue_group(resname):
+     if resname in _HYDROPHOBIC:
+         return 'H'
+     if resname in _POS_CHARGED:
+         return '+'
+     if resname in _NEG_CHARGED:
+         return '-'
+     return 'P'  # polar
+
+
+ def compute_contact_energy(rec_residues, binder_residues,
+                            rec_cb, binder_cb,
+                            rec_mask, binder_mask,
+                            cutoff: float = 8.0):
+     """
+     Compute a simple CB-CB contact energy as a physics-based ddG proxy.
+
+     Uses a 4-group hydrophobicity potential:
+         HH:   -1.0  (hydrophobic-hydrophobic, favorable)
+         +-:   -0.5  (opposite charges, favorable)
+         H+/-: +0.3  (hydrophobic-charged, unfavorable)
+         else:  0.0
+
+     Returns a scalar in (0, 1) via sigmoid normalization. As written, a more
+     favorable (more negative) energy maps to a lower score.
+     """
+     n_rec = len(rec_residues)
+     n_binder = len(binder_residues)
+
+     # CB-CB distance matrix [n_rec, n_binder]
+     diff = rec_cb[:, None, :] - binder_cb[None, :, :]  # [n_rec, n_binder, 3]
+     dist = np.sqrt((diff ** 2).sum(axis=-1))           # [n_rec, n_binder]
+
+     # Mask invalid residues
+     dist[~rec_mask, :] = np.inf
+     dist[:, ~binder_mask] = np.inf
+
+     contact_mask = dist < cutoff
+
+     energy = 0.0
+     for i in range(n_rec):
+         for j in range(n_binder):
+             if not contact_mask[i, j]:
+                 continue
+             gi = _residue_group(rec_residues[i].get_resname())
+             gj = _residue_group(binder_residues[j].get_resname())
+             if gi == 'H' and gj == 'H':
+                 energy -= 1.0
+             elif (gi == '+' and gj == '-') or (gi == '-' and gj == '+'):
+                 energy -= 0.5
+             elif (gi == 'H' and gj in ('+', '-')) or (gj == 'H' and gi in ('+', '-')):
+                 energy += 0.3
+
+     # Normalize: sigmoid of (energy - 5) / 5, so 0 contacts → score ≈ 0.27
+     score = 1.0 / (1.0 + np.exp(-(energy - 5.0) / 5.0))
+     return float(score)
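The pairwise potential and sigmoid normalization in `compute_contact_energy` can also be written as a table lookup, which makes the score's behavior easy to check by hand. The following is a minimal sketch under stated assumptions, not the repository's implementation: `contact_score`, `GROUPS`, and `PAIR_ENERGY` are hypothetical names, residues are passed as pre-computed group labels (`'H'`, `'+'`, `'-'`, `'P'`) rather than Biopython residues, and the validity masks are omitted.

```python
import numpy as np

# Group labels in a fixed order so they can index a 4x4 energy table.
GROUPS = "H+-P"
h, pos, neg = (GROUPS.index(g) for g in "H+-")

# Pair energies copied from the docstring of compute_contact_energy.
PAIR_ENERGY = np.zeros((4, 4))
PAIR_ENERGY[h, h] = -1.0                              # hydrophobic-hydrophobic
PAIR_ENERGY[pos, neg] = PAIR_ENERGY[neg, pos] = -0.5  # opposite charges
PAIR_ENERGY[h, [pos, neg]] = 0.3                      # hydrophobic-charged
PAIR_ENERGY[[pos, neg], h] = 0.3


def contact_score(rec_groups, binder_groups, rec_cb, binder_cb, cutoff=8.0):
    """Return (energy, score) for two sets of CB coordinates (hypothetical helper)."""
    rec_idx = np.array([GROUPS.index(g) for g in rec_groups])
    binder_idx = np.array([GROUPS.index(g) for g in binder_groups])

    # Pairwise CB-CB distances, shape [n_rec, n_binder].
    dist = np.linalg.norm(rec_cb[:, None, :] - binder_cb[None, :, :], axis=-1)
    contacts = dist < cutoff

    # Sum the table entries over all residue pairs in contact.
    pair = PAIR_ENERGY[rec_idx[:, None], binder_idx[None, :]]
    energy = float((pair * contacts).sum())

    # Same normalization as the original: sigmoid((energy - 5) / 5).
    score = 1.0 / (1.0 + np.exp(-(energy - 5.0) / 5.0))
    return energy, float(score)


rec_cb = np.array([[0.0, 0.0, 0.0], [20.0, 0.0, 0.0]])
binder_cb = np.array([[3.0, 0.0, 0.0], [22.0, 0.0, 0.0]])
energy, score = contact_score("H+", "H-", rec_cb, binder_cb)
print(energy, score)
```

In this toy case one hydrophobic pair and one salt bridge are within the cutoff, so the energy is -1.5. With no contacts the energy is 0 and the score is `1 / (1 + e)`, roughly 0.27.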
code/utils/sam.py ADDED
@@ -0,0 +1,54 @@
+ """
+ Sharpness-Aware Minimization (SAM) optimizer wrapper.
+ Seeks parameters in flatter minima for better OOD generalization.
+ Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021)
+ """
+
+ import torch
+
+
+ class SAM(torch.optim.Optimizer):
+     def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
+         defaults = dict(rho=rho, **kwargs)
+         super().__init__(params, defaults)
+         self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+
+     @torch.no_grad()
+     def first_step(self):
+         # Ascent step: move to w + e(w), with e(w) = rho * g / ||g||.
+         grad_norm = self._grad_norm()
+         for group in self.param_groups:
+             scale = group['rho'] / (grad_norm + 1e-12)
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 e_w = p.grad * scale
+                 p.add_(e_w)
+                 self.state[p]['e_w'] = e_w
+
+     @torch.no_grad()
+     def second_step(self):
+         # Undo the perturbation, then let the base optimizer update the
+         # original weights using the gradients currently stored in p.grad.
+         for group in self.param_groups:
+             for p in group['params']:
+                 if p.grad is None:
+                     continue
+                 p.sub_(self.state[p]['e_w'])
+         self.base_optimizer.step()
+
+     def _grad_norm(self):
+         # Global L2 norm over all parameter gradients.
+         shared_device = self.param_groups[0]['params'][0].device
+         norm = torch.norm(
+             torch.stack([
+                 p.grad.norm(p=2).to(shared_device)
+                 for group in self.param_groups
+                 for p in group['params']
+                 if p.grad is not None
+             ]),
+             p=2,
+         )
+         return norm
+
+     def step(self, closure=None):
+         raise NotImplementedError("SAM requires manual first_step() and second_step() calls")
+
+     def zero_grad(self):
+         self.base_optimizer.zero_grad()
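Because `step()` raises `NotImplementedError`, a training loop has to call `first_step()` and `second_step()` itself, with one forward/backward pass before each. The sketch below shows one way to drive the wrapper on a toy regression problem. The `SAM` class is repeated from `code/utils/sam.py` only so the example runs standalone; the model, data, and hyperparameters are illustrative assumptions, and the repository's own training script may organize the loop differently.

```python
import torch


class SAM(torch.optim.Optimizer):
    """Copy of the wrapper in code/utils/sam.py, repeated for a standalone example."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])
        self.base_optimizer.step()

    def _grad_norm(self):
        grads = [p.grad.norm(p=2) for g in self.param_groups
                 for p in g['params'] if p.grad is not None]
        return torch.norm(torch.stack(grads), p=2)

    def step(self, closure=None):
        raise NotImplementedError("SAM requires manual first_step() and second_step() calls")

    def zero_grad(self):
        self.base_optimizer.zero_grad()


# Toy problem (assumption for illustration): fit a noiseless linear map.
torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
loss_fn = torch.nn.MSELoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1)

losses = []
for _ in range(100):
    # Pass 1: gradients at the current weights, then climb to w + e(w).
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.first_step()
    optimizer.zero_grad()

    # Pass 2: gradients at the perturbed weights, then restore and update.
    loss_fn(model(x), y).backward()
    optimizer.second_step()
    optimizer.zero_grad()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

Clearing gradients between the two passes matters: `first_step()` does not reset them, so skipping the intermediate `zero_grad()` would accumulate the gradients from both passes into the update.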
data/sample/README.md ADDED
@@ -0,0 +1,49 @@
+ # Sample dataset (in-repo)
+
+ This directory ships a single pre-built target, **`cam` (Calmodulin)**, so users
+ can run a smoke test of the training and evaluation pipeline without first
+ downloading the full multi-target dataset (~10 GB on Zenodo) or rebuilding
+ from raw PDB files (~30 min per target).
+
+ ## Contents
+
+ ```
+ sample/
+ └── cam/
+     ├── train.pkl    # 84 paired holo/apo complex graphs (~24 MB)
+     ├── val.pkl      # 12 validation graphs (~1.3 MB)
+     └── test.pkl     # 96 held-out evaluation graphs (~25 MB)
+ ```
+
+ Each pickle is a list of dicts produced by `code/data/build_dataset.py`.
+ Splits follow the family-stratified scheme used in the paper
+ (equivalent to `data/processed_familysplit/cam/` train+val and
+ `data/processed_familysplit_v5/cam/test.pkl` in the source tree).
+
+ ## Smoke test (1-epoch end-to-end)
+
+ ```bash
+ # Train both phases for 1 epoch
+ python code/scripts/train.py \
+     --target cam \
+     --phase both \
+     --data_dir data/sample \
+     --checkpoint_dir checkpoints_smoke \
+     --epochs 1 \
+     --no_wandb
+
+ # Evaluate
+ python code/scripts/evaluate.py \
+     --target cam \
+     --checkpoint checkpoints_smoke/best_phase2.pt \
+     --data_dir data/sample \
+     --outdir eval_smoke
+ ```
+
+ Expected runtime: ~1 minute on a single GPU.
+
+ ## Want more data?
+
+ - All 12 paper targets, pre-built: see `data/DOWNLOAD.md` for the Zenodo link.
+ - Build from raw PDBs locally: `scripts/build_data.sh paper12`.
+ - Per-target PDB lists and chain mappings: `data/target_lists/*.txt` (68 targets).
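Since the README describes each split as a pickled list of dicts, and the sample files are tracked with Git LFS, a quick check before the smoke test can confirm that a split was actually downloaded and see what it contains. This is a minimal sketch with hypothetical helper names (`is_lfs_pointer`, `summarize_split`); it makes no assumption about which keys the dicts hold, and unpickling the real files may additionally require the repository's dependencies to be importable.

```python
import pickle
from pathlib import Path

# First line of every Git LFS pointer file, as seen in the pointer diffs.
LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path):
    """True if the file is still a Git LFS pointer rather than the real payload."""
    with open(path, "rb") as fh:
        return fh.read(len(LFS_HEADER)) == LFS_HEADER


def summarize_split(path):
    """Return (number of samples, sorted keys of the first sample) for one split."""
    path = Path(path)
    if is_lfs_pointer(path):
        raise RuntimeError(f"{path} is a Git LFS pointer; fetch it with `git lfs pull`")
    with open(path, "rb") as fh:
        samples = pickle.load(fh)
    keys = sorted(samples[0].keys()) if samples else []
    return len(samples), keys
```

For example, `summarize_split("data/sample/cam/val.pkl")` should report 12 samples if the split matches the table above.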
data/sample/cam/test.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38e36d092bbcf4222c762e351fe305e8627f47c78c4acda74170c650ac09e1e8
+ size 25608454
data/sample/esm2_embeddings/cam/1IWQ_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c52e5a467dc0e73ef7139475bdaafd05e6df8872c345e31fb3da1d35497c00a1
+ size 712791
data/sample/esm2_embeddings/cam/1IWQ_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b7afc0e91e9095d3945adbf4fcf287fd167ac4fed51d10718db65ee80b89890e
+ size 93271
data/sample/esm2_embeddings/cam/1K93_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef76f07da93f60f77ee00fbd719fd81621e50492d245bf62f1a820140e853d61
+ size 2484311
data/sample/esm2_embeddings/cam/1K93_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c9d65d07aaa8b1c219f5458d2a0502739f482b698c5579bbb4ebee681b5aecb
+ size 2392151
data/sample/esm2_embeddings/cam/1NWD_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4138c902d0ca0e40ff0c8b0522f71974ba15d3ec1a11db2bb00f9fe8227339f9
+ size 758871
data/sample/esm2_embeddings/cam/1NWD_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ee429a3b3cbaedf0471f02f7d636e4737f98630f655402317dd3438f8c69c30d
+ size 144471
data/sample/esm2_embeddings/cam/1SY9_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a7ca8538945c4b1e524ef3440c5aa7f73afd5d273034063518a0251e2a59f01
+ size 758871
data/sample/esm2_embeddings/cam/1SY9_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b6258dc842103adc893aacee23c195e2b7da3038e4704e0c9166e1e5581ac784
+ size 98391
data/sample/esm2_embeddings/cam/2BBM_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0630d3cbc267244f3a40ee83155373675f654cfd720aeaf03271c6438cb68b1d
+ size 758871
data/sample/esm2_embeddings/cam/2BBM_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:485ff76a5d7ed5459035608a25322601d8fc3d1b73acd49a13d1dcb1fec22ac3
+ size 134231
data/sample/esm2_embeddings/cam/2HQW_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e8b37dc4704f3c83e4c1d723da66d1873a8dbf27a6e9313450312b9e847f9f44
+ size 707671
data/sample/esm2_embeddings/cam/2HQW_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:564220203f56b067cda236c35cd153d3b20dcaf0fb2080fadb664eed8f413607
+ size 113751
data/sample/esm2_embeddings/cam/2O5G_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1187a8edd2d25b219c7ca4cc970a6a771fe8605ab2937163dc3ee595fad97bab
+ size 753751
data/sample/esm2_embeddings/cam/2O5G_B.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:604f44b0738f0c9d0f3e1410187a94d95dd7002cead28e23635ddd33bf419b60
+ size 98391
data/sample/esm2_embeddings/cam/3D33_A.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b4c553c3e7dfe746babf92dfcae09dab8ecbcf1490cf475aead72fc5f2d30a43
+ size 472151