Commit · ad9572d
Parent(s):
AlloGen public release: Q_theta scorer + PXDesign guidance + Colab demo
Single commit, no prior history.
Contents:
- Q_theta scorer (graph transformer, SE(3)-invariant + ESM-2 conditioning,
~898K params, v4-S2 target-swap checkpoint via Git LFS).
- PXDesign guidance scripts (Langevin / SMC / TDS / classifier) under
code/scripts/pxdesign_guidance/.
- CaM inference sample (96 binder-CaM graphs + matching ESM-2 features).
- Colab demo at notebooks/AlloGen_CaM_demo.ipynb (one-click for biology users:
load scorer, score 96 designs, view ROC/best-of-K, guidance recipe).
- README with method figure, inference quickstart, full Python scoring API.
MIT license.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +14 -0
- .gitignore +41 -0
- LICENSE +21 -0
- README.md +133 -0
- checkpoints/Q_theta_phase1.pt +3 -0
- checkpoints/Q_theta_phase2.pt +3 -0
- checkpoints/Q_theta_train_curve.csv +16 -0
- code/__init__.py +0 -0
- code/data/__init__.py +0 -0
- code/data/dataset.py +832 -0
- code/models/__init__.py +0 -0
- code/models/differentiable_features.py +622 -0
- code/models/features.py +250 -0
- code/models/scorer.py +585 -0
- code/requirements.txt +22 -0
- code/scripts/README.md +55 -0
- code/scripts/evaluate.py +332 -0
- code/scripts/pxdesign_guidance/__init__.py +1 -0
- code/scripts/pxdesign_guidance/convert_cif_to_pdb.py +132 -0
- code/scripts/pxdesign_guidance/guided_pxdesign.py +408 -0
- code/scripts/pxdesign_guidance/iterative_refinement.py +338 -0
- code/scripts/pxdesign_guidance/langevin_pxdesign.py +374 -0
- code/scripts/pxdesign_guidance/qtheta_pxdesign.py +477 -0
- code/scripts/pxdesign_guidance/smc_pxdesign.py +262 -0
- code/scripts/pxdesign_guidance/tds_pxdesign.py +323 -0
- code/scripts/rescore.py +178 -0
- code/trainers/__init__.py +0 -0
- code/trainers/trainer.py +674 -0
- code/utils/__init__.py +0 -0
- code/utils/anm.py +208 -0
- code/utils/path_utils.py +448 -0
- code/utils/pdb_utils.py +472 -0
- code/utils/sam.py +54 -0
- data/sample/README.md +49 -0
- data/sample/cam/test.pkl +3 -0
- data/sample/esm2_embeddings/cam/1IWQ_A.pt +3 -0
- data/sample/esm2_embeddings/cam/1IWQ_B.pt +3 -0
- data/sample/esm2_embeddings/cam/1K93_A.pt +3 -0
- data/sample/esm2_embeddings/cam/1K93_B.pt +3 -0
- data/sample/esm2_embeddings/cam/1NWD_A.pt +3 -0
- data/sample/esm2_embeddings/cam/1NWD_B.pt +3 -0
- data/sample/esm2_embeddings/cam/1SY9_A.pt +3 -0
- data/sample/esm2_embeddings/cam/1SY9_B.pt +3 -0
- data/sample/esm2_embeddings/cam/2BBM_A.pt +3 -0
- data/sample/esm2_embeddings/cam/2BBM_B.pt +3 -0
- data/sample/esm2_embeddings/cam/2HQW_A.pt +3 -0
- data/sample/esm2_embeddings/cam/2HQW_B.pt +3 -0
- data/sample/esm2_embeddings/cam/2O5G_A.pt +3 -0
- data/sample/esm2_embeddings/cam/2O5G_B.pt +3 -0
- data/sample/esm2_embeddings/cam/3D33_A.pt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,14 @@
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.tar.gz filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.svg -text
.gitignore
ADDED
@@ -0,0 +1,41 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+.pytest_cache/
+.coverage
+*.egg-info/
+
+# Env
+.env
+.venv/
+venv/
+env/
+
+# OS
+.DS_Store
+Thumbs.db
+
+# IDE
+.idea/
+.vscode/
+*.swp
+
+# Logs / runs / caches
+*.log
+logs/
+outputs/
+wandb/
+.ipynb_checkpoints/
+
+# Local scoring outputs
+results/
+/tmp_*
+
+# Misc
+*.tmp
+*.bak
+.allogen_test
+.agent*_done.txt
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Hanqun Cao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,133 @@
+---
+license: mit
+tags:
+- protein-design
+- allosteric
+- state-selectivity
+- guided-generation
+- rfdiffusion
+- pxdesign
+- proteina
+library_name: pytorch
+---
+
+# AlloGen
+
+<p align="center">
+  <img src="figures/allogen_main.png" alt="AlloGen method overview" width="100%"/>
+</p>
+
+State-selectivity scoring + guided generation for allosteric binder design.
+
+🧪 **One-click demo for biology users:**
+[Open in Colab](https://colab.research.google.com/#fileId=https%3A//huggingface.co/ChatterjeeLab/AlloGen/raw/main/notebooks/AlloGen_CaM_demo.ipynb): score CaM binders and run Q_θ-guided PXDesign sampling in 5 minutes. Notebook lives at [`notebooks/AlloGen_CaM_demo.ipynb`](notebooks/AlloGen_CaM_demo.ipynb).
+
+AlloGen trains a scorer Q_θ(X, Y) ∈ (0,1) that ranks how well a binder Y discriminates a target's **holo** (active) state X¹ from its **apo** (inactive) state X⁰. The selectivity score is:
+
+S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y)
+
+Q_θ serves as both a re-ranker (best-of-K) and a gradient signal for guided generation on top of frozen priors (RFdiffusion, PXDesign, Proteina-ComplexA) via Langevin, SMC, TDS, or classifier guidance.
+
+This repository accompanies the paper *AlloGen: State-Selective Scoring for Allosteric Binder Design* (NeurIPS 2026).
+
+## Installation
+
+```bash
+conda env create -f environment.yml
+conda activate allogen
+```
+
+Or pip-only:
+
+```bash
+python -m venv .venv && source .venv/bin/activate
+pip install -r requirements.txt
+```
+
+Python 3.10 + PyTorch 2.x are required. A CUDA GPU is recommended for guidance, but CPU works for scoring single designs.
+
+## Inference quickstart
+
+```bash
+# Score the bundled CaM inference sample against the v4-S2 (target-swap) checkpoint
+python code/scripts/evaluate.py \
+    --target cam \
+    --checkpoint checkpoints/Q_theta_phase2.pt \
+    --data_dir data/sample/ \
+    --outdir /tmp/cam_inference \
+    --no_wandb
+```
+
+See [`inference.md`](inference.md) for the scoring API + guidance command lines.
+
+## Repo layout
+
+```
+code/
+  data/          dataset / graph construction, PDB I/O, target YAMLs
+  models/        Q_θ scorer (graph transformer) + differentiable wrapper
+  trainers/      two-phase training loop (DockQ regression + selectivity)
+  utils/         PDB I/O, backbone frames, SAM optimizer
+  scripts/       evaluate, rescore, PXDesign guidance (see scripts/README.md)
+checkpoints/     Q_θ paper weights (v4-S2 target-swap split, via Git LFS)
+data/sample/     tiny CaM inference sample (test split only)
+```
+
+## Checkpoints
+
+Paper weights for the **v4-S2 target-swap** split are bundled via **Git LFS**:
+
+```bash
+git lfs install
+git lfs pull
+```
+
+| File | Use |
+|---|---|
+| `checkpoints/Q_theta_phase1.pt` | Phase 1 (DockQ regression) intermediate checkpoint |
+| `checkpoints/Q_theta_phase2.pt` | Phase 2 (selectivity): main paper result |
+| `checkpoints/Q_theta_train_curve.csv` | Training curve metadata |
+
+## Scoring a single design
+
+```python
+import sys; sys.path.insert(0, 'code')
+from models.differentiable_features import DifferentiableQTheta
+
+scorer = DifferentiableQTheta(
+    checkpoint='checkpoints/Q_theta_phase2.pt',
+    device='cuda:0',
+)
+scorer.load_receptor(
+    holo_path='your_holo.pdb', rec_chain='A',
+    apo_path='your_apo.pdb', apo_chain='A',
+)
+q_holo = scorer.score('design.pdb', binder_chain='B', state='holo')
+q_apo = scorer.score('design.pdb', binder_chain='B', state='apo')
+print(f'S = {q_holo - q_apo:.3f}')
+```
+
+## Guidance methods
+
+The shipped guidance code wraps **PXDesign** as the prior and uses Q_θ as the gradient / classifier signal. All four method variants (Langevin, SMC, TDS, classifier guidance) live in `code/scripts/pxdesign_guidance/`.
+
+See [`inference.md`](inference.md) §3 for command lines.
+
+To deploy Q_θ with **RFdiffusion**, **Proteina-ComplexA**, or any other backbone prior, see [`code/scripts/README.md`](code/scripts/README.md). Q_θ exposes `DifferentiableQTheta` for `∇_x S(x)`, and the PXDesign code is a worked template to mirror.
+
+## Citation
+
+```bibtex
+@inproceedings{cao2026allogen,
+  title = {AlloGen: State-Selective Scoring for Allosteric Binder Design},
+  author = {Cao, Hanqun and others},
+  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
+  year = {2026}
+}
+```
+
+(BibTeX key will be finalized at camera-ready.)
+
+## License
+
+MIT; see [`LICENSE`](LICENSE).
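Note on the README's selectivity score: since S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y) is a difference of two per-design scores, best-of-K re-ranking reduces to a sort over score pairs. The sketch below illustrates that under stated assumptions: `selectivity` and `rank_by_selectivity` are hypothetical helper names, and the score pairs are made-up numbers rather than AlloGen outputs. In real use the two scores would come from `scorer.score(..., state='holo')` and `scorer.score(..., state='apo')` as in the README snippet.

```python
from typing import Dict, List, Tuple


def selectivity(q_holo: float, q_apo: float) -> float:
    """S(Y) = Q(X1, Y) - Q(X0, Y); positive values favour the holo state."""
    return q_holo - q_apo


def rank_by_selectivity(
    scores: Dict[str, Tuple[float, float]], k: int
) -> List[Tuple[str, float]]:
    """Return the k designs with the highest selectivity, best first."""
    ranked = sorted(
        ((name, selectivity(q_holo, q_apo)) for name, (q_holo, q_apo) in scores.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    return ranked[:k]


# Made-up (q_holo, q_apo) pairs, for illustration only.
example_scores = {
    "design_a": (0.80, 0.30),
    "design_b": (0.60, 0.55),
    "design_c": (0.90, 0.70),
}
top2 = rank_by_selectivity(example_scores, k=2)
print([name for name, _ in top2])
```

Note that `design_c` has the highest holo score but is not the most selective design, which is the distinction the selectivity score is meant to capture.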
checkpoints/Q_theta_phase1.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1684955f481c1406b12cc0e1ec3509a2a2e2def8b0a9071ec0c96be00d330e7c
+size 3617774
checkpoints/Q_theta_phase2.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:716e4716c0014a46cfd4a2e26c238486c8ad1e3a03809320247cfb734538f8c6
+size 3618158
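Both checkpoint entries above are Git LFS pointer files (three `key value` lines), not the weights themselves. If `git lfs pull` has not been run, the file on disk is this text stub and will not load as a PyTorch checkpoint. A minimal sketch of how one might detect that case; `parse_lfs_pointer` and `is_lfs_pointer` are hypothetical helpers, not part of the repository:

```python
from typing import Dict

LFS_SPEC_URL = "https://git-lfs.github.com/spec/v1"


def parse_lfs_pointer(text: str) -> Dict[str, str]:
    """Split each non-empty 'key value' line of a pointer file into a dict entry."""
    fields: Dict[str, str] = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        key, _, value = line.partition(" ")
        fields[key] = value.strip()
    return fields


def is_lfs_pointer(text: str) -> bool:
    """True if the text has the version, sha256 oid, and numeric size fields."""
    fields = parse_lfs_pointer(text)
    return (
        fields.get("version") == LFS_SPEC_URL
        and fields.get("oid", "").startswith("sha256:")
        and fields.get("size", "").isdigit()
    )


# Pointer content copied from checkpoints/Q_theta_phase2.pt in this commit.
pointer_text = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:716e4716c0014a46cfd4a2e26c238486c8ad1e3a03809320247cfb734538f8c6\n"
    "size 3618158\n"
)
info = parse_lfs_pointer(pointer_text)
print(is_lfs_pointer(pointer_text), int(info["size"]))
```

A check like this could run before loading a checkpoint so that a missing `git lfs pull` produces a clear message instead of a deserialization error.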
checkpoints/Q_theta_train_curve.csv
ADDED
@@ -0,0 +1,16 @@
+epoch,loss,test_rho,cam_rho,rho_cam,rho_bcl2,rho_era,rho_mdm2,rho_ran,rho_a2a,rho_pai1,rho_integrin
+1,5.527279376983643,0.44785173068985473,0.48303455690607044,0.48303455690607044,0.5733270876635986,0.6666325520384175,0.655984079559599,0.4957646288057514,0.28072319745279745,0.32040567098387046,0.10694207210873281
+2,5.452569961547852,0.45406694660326774,0.4863130719981931,0.4863130719981931,0.5755749254666933,0.6683595793753047,0.6593234065151792,0.49658672598782,0.29627094377326,0.33010796670396186,0.11999895300572927
+3,5.449374318122864,0.44742845709394635,0.4966343232141348,0.4966343232141348,0.5758960451528496,0.6612355916106455,0.6531535587702088,0.4753625218665822,0.23926254059823043,0.33456826378855364,0.14331481175036578
+4,5.44292688369751,0.4461415166784438,0.49274867569754505,0.49274867569754505,0.5728775201029797,0.6614514700277563,0.6451799702468982,0.46946643826626333,0.242717595336111,0.34013213953325055,0.1445583242167464
+5,5.441892623901367,0.43839115930587513,0.49080585193925014,0.49080585193925014,0.5653633194469204,0.6586450506053149,0.6361543905961949,0.4544470120828291,0.21334963006412605,0.33976427997988223,0.1485997397324834
+6,5.439353764057159,0.4284122153367078,0.46931336411311275,0.46931336411311275,0.5655559912586142,0.6597244426908693,0.634838828707037,0.4423570277236172,0.19089177426790224,0.3362236317787114,0.1283926621537984
+7,5.439265787601471,0.42722475631179213,0.45923496586695794,0.45923496586695794,0.5661982306309269,0.6614514700277563,0.635386597507575,0.4425217082123247,0.20816704795730515,0.3344762989002116,0.1103617313912795
+8,5.439032256603241,0.423813755390132,0.45146367083377825,0.45146367083377825,0.5662624545681583,0.6633943757817542,0.6325423252001796,0.4388258604749022,0.2029844658504843,0.33617764933454036,0.09885924107725882
+9,5.435689151287079,0.4221364421714867,0.45122081786399143,0.45122081786399143,0.5656202151958456,0.660156199525091,0.6294276063720167,0.4378717142280125,0.19952941111260372,0.33502808823026414,0.09823748484406851
+10,5.43562650680542,0.419534099250895,0.4498851265301637,0.4498851265301637,0.5641430646395261,0.662099105279089,0.6279539020257997,0.4329763310874754,0.1822541374232008,0.33530398289529045,0.10165714412661521
+11,5.435124695301056,0.41942099211035866,0.451949376773352,0.451949376773352,0.5633723773927508,0.6612355916106455,0.6249514872613453,0.43204285159336153,0.1822541374232008,0.3360397020020272,0.10352241282618613
+12,5.440709412097931,0.41971570222256677,0.4518279502884586,0.4518279502884586,0.5642072885767574,0.6612355916106455,0.6261524531671271,0.43317581854869713,0.1822541374232008,0.33534996533946143,0.10352241282618613
+13,5.430392742156982,0.4198042517465396,0.4523136562280323,0.4523136562280323,0.5629228098321319,0.6627467405304217,0.6259874349510655,0.4334288217301594,0.1822541374232008,0.33525800045111936,0.10352241282618613
+14,5.436960756778717,0.4199312463250485,0.4515850973186718,0.4515850973186718,0.5639503928278323,0.6627467405304217,0.6264251916075623,0.4329262960644863,0.1822541374232008,0.3360397020020272,0.10352241282618613
+15,5.433559775352478,0.419866834088955,0.4515850973186718,0.4515850973186718,0.5635650492044447,0.6627467405304217,0.6264251916075623,0.432842324243296,0.1822541374232008,0.3359937195578562,0.10352241282618613
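The training curve above can be summarized with the standard library alone. The sketch below is illustrative: `best_epoch` is a hypothetical helper, and the embedded excerpt keeps only three columns of the first three epochs, rounded to four decimals. On the full table, `test_rho` also peaks at epoch 2, and the `cam_rho` and `rho_cam` columns carry identical values in every row.

```python
import csv
import io
from typing import Tuple

# Excerpt of checkpoints/Q_theta_train_curve.csv (epochs 1-3, rounded to 4 decimals).
CURVE_EXCERPT = """epoch,loss,test_rho
1,5.5273,0.4479
2,5.4526,0.4541
3,5.4494,0.4474
"""


def best_epoch(csv_text: str, metric: str = "test_rho") -> Tuple[int, float]:
    """Return (epoch, value) for the row that maximizes the given metric column."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = max(rows, key=lambda row: float(row[metric]))
    return int(best["epoch"]), float(best[metric])


epoch, value = best_epoch(CURVE_EXCERPT)
print(epoch, value)
```

To run this on the real file, the excerpt string could be replaced by the contents of `checkpoints/Q_theta_train_curve.csv`, and `metric` by any of the per-target `rho_*` columns.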
code/__init__.py
ADDED
File without changes
code/data/__init__.py
ADDED
File without changes
code/data/dataset.py
ADDED
@@ -0,0 +1,832 @@
+"""
+PyTorch Dataset for two-state complex scoring.
+
+Loads preprocessed graph data and provides batched tensors
+with padding for variable-sized interface graphs.
+"""
+
+import os
+import json
+import pickle
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+# Global ESM embedding cache: {file_path: tensor}
+_ESM_CACHE = {}
+
+
+def preload_esm_cache(esm_dir, targets):
+    """Preload all ESM .pt files into global cache before DataLoader workers fork.
+
+    This ensures forked workers inherit the populated cache via copy-on-write,
+    avoiding redundant I/O across workers.
+    """
+    import glob as glob_mod
+    n = 0
+    for target in targets:
+        target_dir = os.path.join(esm_dir, target)
+        if not os.path.isdir(target_dir):
+            continue
+        for pt_file in glob_mod.glob(os.path.join(target_dir, '*.pt')):
+            if pt_file not in _ESM_CACHE:
+                _ESM_CACHE[pt_file] = torch.load(pt_file, map_location='cpu', weights_only=True)
+                n += 1
+    return n
+
+
+def load_esm_for_sample(sample, esm_dir, target_name, max_nodes=128):
+    """Load and index ESM-2 embeddings for a sample's interface residues.
+
+    Returns: esm_feats [max_nodes, 1280] or None if unavailable.
+    """
+    graph = sample['graph']
+    rec_idx = graph.get('rec_iface_idx')
+    binder_idx = graph.get('binder_iface_idx')
+    if rec_idx is None or binder_idx is None:
+        return None
+
+    # Get PDB ID (strip chain suffix like "2G1T_AE" -> "2G1T")
+    pdb_id = sample.get('pdb', '')
+    base_pdb = pdb_id.split('_')[0] if '_' in pdb_id else pdb_id
+    rec_chain = sample.get('rec_chain_id', 'A')
+    binder_chain = sample.get('binder_chain_id', 'B')
+
+    # Load ESM embeddings (cached)
+    rec_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{rec_chain}.pt')
+    binder_path = os.path.join(esm_dir, target_name, f'{base_pdb}_{binder_chain}.pt')
+
+    def _load_cached(path):
+        if path not in _ESM_CACHE:
+            if not os.path.exists(path):
+                return None
+            _ESM_CACHE[path] = torch.load(path, map_location='cpu', weights_only=True)
+        return _ESM_CACHE[path]
+
+    rec_esm = _load_cached(rec_path)
+    binder_esm = _load_cached(binder_path)
+    if rec_esm is None or binder_esm is None:
+        return None
+
+    esm_dim = rec_esm.shape[-1]  # 1280
+    n_rec = len(rec_idx)
+    n_binder = len(binder_idx)
+
+    # Index ESM embeddings by interface residue indices (clamp to valid range)
+    rec_idx_safe = np.clip(rec_idx, 0, len(rec_esm) - 1)
+    binder_idx_safe = np.clip(binder_idx, 0, len(binder_esm) - 1)
+
+    esm_feats = np.zeros((max_nodes, esm_dim), dtype=np.float32)
+    esm_feats[:n_rec] = rec_esm[rec_idx_safe].numpy()
+    esm_feats[n_rec:n_rec + n_binder] = binder_esm[binder_idx_safe].numpy()
+
+    return esm_feats
+
+
+def load_rosetta_labels(rosetta_dir, target):
+    """Load Rosetta dG labels for a target and normalize to [0,1]."""
+    path = os.path.join(rosetta_dir, f'{target}_rosetta.json')
+    if not os.path.exists(path):
+        return None
+    with open(path) as f:
+        raw = json.load(f)
+    if not raw:
+        return None
+    # Filter outliers: dG values outside [-500, 500] are failed Rosetta runs
+    dG_MIN, dG_MAX = -500.0, 500.0
+    # Normalize: sigmoid(-dG / tau) maps dG to [0,1]
+    # More negative dG = better binding = higher score
+    tau = 15.0  # temperature; dG=-30 -> 0.88, dG=-15 -> 0.73, dG=0 -> 0.5
+    labels = {}
+    for pdb_id, metrics in raw.items():
+        dG = metrics.get('dG_separated', 0.0)
+        if not np.isfinite(dG) or dG < dG_MIN or dG > dG_MAX:
+            continue  # skip failed Rosetta runs
+        labels[pdb_id] = 1.0 / (1.0 + np.exp(dG / tau))
+        labels[pdb_id.upper()] = labels[pdb_id]
+        labels[pdb_id.lower()] = labels[pdb_id]
+    return labels
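Review note: the label normalization in `load_rosetta_labels` is a sigmoid of `-dG / tau`. The following stand-alone check reproduces the values quoted in the `tau` comment, using `math` instead of NumPy; `dg_to_label` is a hypothetical name introduced only for illustration.

```python
import math


def dg_to_label(dg: float, tau: float = 15.0) -> float:
    """sigmoid(-dG / tau): more negative dG (better binding) maps closer to 1."""
    return 1.0 / (1.0 + math.exp(dg / tau))


for dg in (-30.0, -15.0, 0.0, 15.0):
    print(f"dG={dg:+.0f} -> {dg_to_label(dg):.2f}")
```

With `tau = 15.0`, the outlier bounds of ±500 map to labels extremely close to 1 and 0 respectively, so the filter mainly guards against failed runs rather than numerical overflow.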
+
+
+def apply_rosetta_labels(samples, rosetta_labels, label_source='rosetta', alpha=0.5):
+    """Replace or combine sample labels with Rosetta-derived labels."""
+    if rosetta_labels is None:
+        return 0
+    n_replaced = 0
+    for s in samples:
+        pdb_id = s.get('pdb', '')
+        # Strip chain suffixes: "2G1T_AE" -> "2G1T"
+        base_pdb = pdb_id.split('_')[0] if '_' in pdb_id else pdb_id
+        rosetta_val = rosetta_labels.get(base_pdb) or rosetta_labels.get(base_pdb.upper())
+        if rosetta_val is None:
+            continue
+        if s['type'] == 'positive':
+            new_label = rosetta_val
+        elif s['type'].startswith('negative'):
+            # Apo-mismatch negatives keep their existing label of 0
+            continue
+        elif s['type'].startswith('decoy'):
+            # Scale Rosetta label by DockQ-proxy quality
+            new_label = s['label'] * rosetta_val
+        else:
+            continue
+        if label_source == 'rosetta':
+            s['label'] = float(new_label)
+        elif label_source == 'combined':
+            s['label'] = float(alpha * s['label'] + (1 - alpha) * new_label)
+        n_replaced += 1
+    return n_replaced
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TwoStateComplexDataset(Dataset):
|
| 142 |
+
"""
|
| 143 |
+
Dataset of protein complex interface graphs with two-state labels.
|
| 144 |
+
|
| 145 |
+
Each sample contains:
|
| 146 |
+
node_feats: [N, node_dim] interface residue features
|
| 147 |
+
edge_feats: [N, N, edge_dim] pairwise SE(3)-invariant features
|
| 148 |
+
node_mask: [N] bool
|
| 149 |
+
label: scalar float in [0, 1] (DockQ proxy / selectivity label)
|
| 150 |
+
type: str (positive / negative_apo / decoy_*)
|
| 151 |
+
pdb: str
|
| 152 |
+
"""
|
| 153 |
+
|
| 154 |
+
def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
|
| 155 |
+
rosetta_labels: dict = None, label_source: str = 'dockq',
|
| 156 |
+
esm_dir: str = None, target_name: str = None,
|
| 157 |
+
binder_dropout: float = 0.0):
|
| 158 |
+
with open(data_path, 'rb') as f:
|
| 159 |
+
self.samples = pickle.load(f)
|
| 160 |
+
self.max_nodes = max_nodes
|
| 161 |
+
self.augment = augment
|
| 162 |
+
self.esm_dir = esm_dir
|
| 163 |
+
self.target_name = target_name
|
| 164 |
+
self.binder_dropout = binder_dropout
|
| 165 |
+
if label_source != 'dockq' and rosetta_labels:
|
| 166 |
+
apply_rosetta_labels(self.samples, rosetta_labels, label_source)
|
| 167 |
+
|
| 168 |
+
def __len__(self):
|
| 169 |
+
return len(self.samples)
|
| 170 |
+
|
| 171 |
+
def __getitem__(self, idx):
|
| 172 |
+
sample = self.samples[idx]
|
| 173 |
+
graph = sample['graph']
|
| 174 |
+
|
| 175 |
+
node_feats = graph['node_feats'] # [N, node_dim]
|
| 176 |
+
edge_feats = graph['edge_feats'] # [N, N, edge_dim]
|
| 177 |
+
node_mask = graph['node_mask'] # [N]
|
| 178 |
+
|
| 179 |
+
N = len(node_feats)
|
| 180 |
+
assert N <= self.max_nodes, f"Too many nodes: {N} > {self.max_nodes}"
|
| 181 |
+
|
| 182 |
+
# Pad to max_nodes
|
| 183 |
+
node_dim = node_feats.shape[-1]
|
| 184 |
+
edge_dim = edge_feats.shape[-1]
|
| 185 |
+
|
| 186 |
+
node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
|
| 187 |
+
edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
|
| 188 |
+
node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
|
| 189 |
+
|
| 190 |
+
node_feats_pad[:N] = node_feats
|
| 191 |
+
edge_feats_pad[:N, :N] = edge_feats
|
| 192 |
+
node_mask_pad[:N] = node_mask
|
| 193 |
+
|
| 194 |
+
# Optional: random coordinate noise augmentation
|
| 195 |
+
if self.augment:
|
| 196 |
+
noise = np.random.randn(*node_feats_pad.shape) * 0.01
|
| 197 |
+
node_feats_pad = node_feats_pad + noise.astype(np.float32)
|
| 198 |
+
|
| 199 |
+
# Binder-dropout: simulate backbone-only designs by masking binder
|
| 200 |
+
# sequence features (AA one-hot → UNK, chi angles → 0)
|
| 201 |
+
apply_binder_drop = (self.binder_dropout > 0
|
| 202 |
+
and np.random.rand() < self.binder_dropout)
|
| 203 |
+
if apply_binder_drop:
|
| 204 |
+
n_rec = graph.get('n_rec', N // 2)
|
| 205 |
+
# Zero out binder AA one-hot (dims 0-20), set UNK (dim 20 = 1)
|
| 206 |
+
node_feats_pad[n_rec:N, :21] = 0.0
|
| 207 |
+
node_feats_pad[n_rec:N, 20] = 1.0 # UNK
|
| 208 |
+
# Zero out binder chi angles (dims 27-30)
|
| 209 |
+
node_feats_pad[n_rec:N, 27:31] = 0.0
|
| 210 |
+
# Keep backbone torsions (dims 21-26) and chain indicator (dim 31)
|
| 211 |
+
|
| 212 |
+
result = {
|
| 213 |
+
'node_feats': torch.from_numpy(node_feats_pad), # [max_nodes, node_dim]
|
| 214 |
+
'edge_feats': torch.from_numpy(edge_feats_pad), # [max_nodes, max_nodes, edge_dim]
|
| 215 |
+
'node_mask': torch.from_numpy(node_mask_pad), # [max_nodes]
|
| 216 |
+
'label': torch.tensor(sample['label'], dtype=torch.float32),
|
| 217 |
+
'type': sample['type'],
|
| 218 |
+
'pdb': sample['pdb'],
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
# ESM-2 features (lazy load; zero-fill if unavailable)
|
| 222 |
+
if self.esm_dir:
|
| 223 |
+
esm = load_esm_for_sample(sample, self.esm_dir,
|
| 224 |
+
self.target_name or '', self.max_nodes)
|
| 225 |
+
if esm is not None:
|
| 226 |
+
esm_feats = esm
|
| 227 |
+
else:
|
| 228 |
+
esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
|
| 229 |
+
# Zero binder ESM if binder-dropout active
|
| 230 |
+
if apply_binder_drop:
|
| 231 |
+
n_rec = graph.get('n_rec', N // 2)
|
| 232 |
+
n_binder = graph.get('n_binder', N - n_rec)
|
| 233 |
+
esm_feats[n_rec:n_rec + n_binder] = 0.0
|
| 234 |
+
result['esm_feats'] = torch.from_numpy(esm_feats)
|
| 235 |
+
|
| 236 |
+
return result
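The binder-dropout branch above can be exercised in isolation. The following is a minimal NumPy sketch, assuming the 32-dimensional node layout implied by the comments (AA one-hot in dims 0-20 with UNK at 20, backbone torsions in 21-26, chi angles in 27-30, chain indicator in 31). `mask_binder_sequence` is a hypothetical helper for illustration, not a function in this repository.

```python
import numpy as np

def mask_binder_sequence(node_feats, n_rec, n_valid):
    """Return a copy where binder rows [n_rec, n_valid) carry no sequence information."""
    out = node_feats.copy()
    out[n_rec:n_valid, :21] = 0.0    # clear the AA one-hot block
    out[n_rec:n_valid, 20] = 1.0     # mark every binder residue as UNK
    out[n_rec:n_valid, 27:31] = 0.0  # clear chi angles
    return out                       # torsions (21-26) and chain flag (31) untouched

# Toy graph: 4 receptor residues followed by 2 binder residues (assumed sizes)
feats = np.random.default_rng(0).random((6, 32)).astype(np.float32)
masked = mask_binder_sequence(feats, n_rec=4, n_valid=6)
```

Working on a copy keeps the stored graph intact, which matters when the same sample is drawn again in a later epoch without dropout.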
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def collate_fn(batch):
    """Collate a list of samples into batched tensors."""
    node_feats = torch.stack([s['node_feats'] for s in batch])
    edge_feats = torch.stack([s['edge_feats'] for s in batch])
    node_mask = torch.stack([s['node_mask'] for s in batch])
    labels = torch.stack([s['label'] for s in batch])
    types = [s['type'] for s in batch]
    pdbs = [s['pdb'] for s in batch]

    result = {
        'node_feats': node_feats,  # [B, N, node_dim]
        'edge_feats': edge_feats,  # [B, N, N, edge_dim]
        'node_mask': node_mask,    # [B, N]
        'label': labels,           # [B]
        'type': types,
        'pdb': pdbs,
    }

    # Stack ESM features if present (handle mixed availability with zero-fill)
    has_esm = any('esm_feats' in s for s in batch)
    if has_esm:
        esm_list = []
        for s in batch:
            if 'esm_feats' in s:
                esm_list.append(s['esm_feats'])
            else:
                # Get shape from a sample that has ESM
                ref = next(x['esm_feats'] for x in batch if 'esm_feats' in x)
                esm_list.append(torch.zeros_like(ref))
        result['esm_feats'] = torch.stack(esm_list)

    return result
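The zero-fill rule used when only some samples carry `esm_feats` can be sketched on its own. This is a dependency-light illustration in NumPy, where `np.stack` and `np.zeros_like` stand in for `torch.stack` and `torch.zeros_like`; `stack_with_zero_fill` is a hypothetical name, and the 4 x 8 feature shape is an arbitrary stand-in for `[max_nodes, 1280]`.

```python
import numpy as np

def stack_with_zero_fill(batch, key):
    """Stack batch[i][key]; samples lacking the key contribute an all-zero array."""
    # Use the first available entry as a shape/dtype reference
    ref = next((s[key] for s in batch if key in s), None)
    if ref is None:
        return None  # no sample has the key, so nothing to stack
    rows = [s[key] if key in s else np.zeros_like(ref) for s in batch]
    return np.stack(rows)

batch = [
    {"esm_feats": np.ones((4, 8), dtype=np.float32)},
    {},  # sample without embeddings
    {"esm_feats": np.full((4, 8), 2.0, dtype=np.float32)},
]
stacked = stack_with_zero_fill(batch, "esm_feats")
```

Looking up the reference once, before the loop, also gives a natural place to handle the case where no sample in the batch has the key.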
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class TwoStateDatasetPaired(Dataset):
|
| 274 |
+
"""
|
| 275 |
+
Paired dataset: returns (positive, negative) pairs for selectivity training.
|
| 276 |
+
Groups samples by PDB ID and pairs positive (holo) with negative (apo) examples.
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False,
|
| 280 |
+
esm_dir: str = None, target_name: str = None,
|
| 281 |
+
binder_dropout: float = 0.0):
|
| 282 |
+
with open(data_path, 'rb') as f:
|
| 283 |
+
samples = pickle.load(f)
|
| 284 |
+
self.max_nodes = max_nodes
|
| 285 |
+
self.augment = augment
|
| 286 |
+
self.esm_dir = esm_dir
|
| 287 |
+
self.target_name = target_name
|
| 288 |
+
self.binder_dropout = binder_dropout
|
| 289 |
+
|
| 290 |
+
# Group by PDB
|
| 291 |
+
from collections import defaultdict
|
| 292 |
+
by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
|
| 293 |
+
for s in samples:
|
| 294 |
+
pdb = s['pdb']
|
| 295 |
+
t = s['type']
|
| 296 |
+
if t == 'positive':
|
| 297 |
+
by_pdb[pdb]['positive'].append(s)
|
| 298 |
+
elif t.startswith('negative'):
|
| 299 |
+
by_pdb[pdb]['negative'].append(s)
|
| 300 |
+
elif t.startswith('decoy'):
|
| 301 |
+
by_pdb[pdb]['decoy'].append(s)
|
| 302 |
+
|
| 303 |
+
# Build pairs: (positive, negative) per PDB
|
| 304 |
+
self.pairs = []
|
| 305 |
+
for pdb, groups in by_pdb.items():
|
| 306 |
+
if len(groups['positive']) > 0 and len(groups['negative']) > 0:
|
| 307 |
+
for pos in groups['positive']:
|
| 308 |
+
for neg in groups['negative']:
|
| 309 |
+
self.pairs.append((pos, neg))
|
| 310 |
+
# Also add (positive, decoy_large_rmsd) pairs
|
| 311 |
+
if len(groups['positive']) > 0 and len(groups['decoy']) > 0:
|
| 312 |
+
large_decoys = [s for s in groups['decoy'] if 'rmsd' in s['type'] and
|
| 313 |
+
float(s['type'].replace('decoy_rmsd', '')) > 4.0]
|
| 314 |
+
for pos in groups['positive']:
|
| 315 |
+
for neg in large_decoys[:3]: # limit to 3 hard decoys per positive
|
| 316 |
+
self.pairs.append((pos, neg))
|
| 317 |
+
|
| 318 |
+
def __len__(self):
|
| 319 |
+
return len(self.pairs)
|
| 320 |
+
|
| 321 |
+
def _prepare(self, sample, apply_binder_drop=False):
|
| 322 |
+
graph = sample['graph']
|
| 323 |
+
node_feats = graph['node_feats']
|
| 324 |
+
edge_feats = graph['edge_feats']
|
| 325 |
+
node_mask = graph['node_mask']
|
| 326 |
+
N = len(node_feats)
|
| 327 |
+
node_dim = node_feats.shape[-1]
|
| 328 |
+
edge_dim = edge_feats.shape[-1]
|
| 329 |
+
|
| 330 |
+
node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
|
| 331 |
+
edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
|
| 332 |
+
node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
|
| 333 |
+
|
| 334 |
+
n = min(N, self.max_nodes)
|
| 335 |
+
node_feats_pad[:n] = node_feats[:n]
|
| 336 |
+
edge_feats_pad[:n, :n] = edge_feats[:n, :n]
|
| 337 |
+
node_mask_pad[:n] = node_mask[:n]
|
| 338 |
+
|
| 339 |
+
# Binder-dropout: simulate backbone-only designs
|
| 340 |
+
if apply_binder_drop:
|
| 341 |
+
n_rec = graph.get('n_rec', n // 2)
|
| 342 |
+
node_feats_pad[n_rec:n, :21] = 0.0
|
| 343 |
+
node_feats_pad[n_rec:n, 20] = 1.0 # UNK
|
| 344 |
+
node_feats_pad[n_rec:n, 27:31] = 0.0
|
| 345 |
+
|
| 346 |
+
result = {
|
| 347 |
+
'node_feats': torch.from_numpy(node_feats_pad),
|
| 348 |
+
'edge_feats': torch.from_numpy(edge_feats_pad),
|
| 349 |
+
'node_mask': torch.from_numpy(node_mask_pad),
|
| 350 |
+
'label': torch.tensor(sample['label'], dtype=torch.float32),
|
| 351 |
+
'contact_energy': torch.tensor(
|
| 352 |
+
sample.get('contact_energy', 0.5), dtype=torch.float32
|
| 353 |
+
),
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
# ESM-2 features (zero-fill if unavailable)
|
| 357 |
+
if self.esm_dir:
|
| 358 |
+
esm = load_esm_for_sample(sample, self.esm_dir,
|
| 359 |
+
self.target_name or '', self.max_nodes)
|
| 360 |
+
if esm is not None:
|
| 361 |
+
esm_feats = esm
|
| 362 |
+
else:
|
| 363 |
+
esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
|
| 364 |
+
if apply_binder_drop:
|
| 365 |
+
n_rec = graph.get('n_rec', n // 2)
|
| 366 |
+
n_binder = graph.get('n_binder', n - n_rec)
|
| 367 |
+
esm_feats[n_rec:n_rec + n_binder] = 0.0
|
| 368 |
+
result['esm_feats'] = torch.from_numpy(esm_feats)
|
| 369 |
+
|
| 370 |
+
return result
|
| 371 |
+
|
| 372 |
+
def __getitem__(self, idx):
|
| 373 |
+
pos_sample, neg_sample = self.pairs[idx]
|
| 374 |
+
# Same dropout decision for both pos and neg in a pair
|
| 375 |
+
drop = (self.binder_dropout > 0
|
| 376 |
+
and np.random.rand() < self.binder_dropout)
|
| 377 |
+
return {
|
| 378 |
+
'pos': self._prepare(pos_sample, apply_binder_drop=drop),
|
| 379 |
+
'neg': self._prepare(neg_sample, apply_binder_drop=drop),
|
| 380 |
+
}
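The pairing rule in `TwoStateDatasetPaired.__init__` can be reproduced on toy records. A minimal sketch follows, assuming decoy types are strings of the form `decoy_rmsd<value>` (as the `float(...replace('decoy_rmsd', ''))` parse implies); `build_pairs`, the PDB IDs, and the `negative_apo` type name are hypothetical. Within one PDB group the ordering of pairs may differ from the source, which appends all negatives before any decoys.

```python
from collections import defaultdict

def build_pairs(samples, rmsd_cutoff=4.0, max_decoys=3):
    """Pair each positive with every negative and up to max_decoys large-RMSD decoys."""
    by_pdb = defaultdict(lambda: {"positive": [], "negative": [], "decoy": []})
    for s in samples:
        t = s["type"]
        if t == "positive":
            by_pdb[s["pdb"]]["positive"].append(s)
        elif t.startswith("negative"):
            by_pdb[s["pdb"]]["negative"].append(s)
        elif t.startswith("decoy"):
            by_pdb[s["pdb"]]["decoy"].append(s)

    pairs = []
    for groups in by_pdb.values():
        # Only decoys whose RMSD exceeds the cutoff are kept as negatives
        hard = [d for d in groups["decoy"]
                if "rmsd" in d["type"]
                and float(d["type"].replace("decoy_rmsd", "")) > rmsd_cutoff]
        for pos in groups["positive"]:
            pairs.extend((pos, neg) for neg in groups["negative"])
            pairs.extend((pos, neg) for neg in hard[:max_decoys])
    return pairs

samples = [
    {"pdb": "1abc", "type": "positive"},
    {"pdb": "1abc", "type": "negative_apo"},
    {"pdb": "1abc", "type": "decoy_rmsd2.0"},   # below cutoff, skipped
    {"pdb": "1abc", "type": "decoy_rmsd5.5"},   # above cutoff, kept
    {"pdb": "2xyz", "type": "positive"},        # no negatives, contributes nothing
]
pairs = build_pairs(samples)
pair_types = [(p["pdb"], n["type"]) for p, n in pairs]
```

A PDB entry with positives but no negatives or large-RMSD decoys yields no pairs, so the paired dataset can be much smaller than the sample list.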
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def collate_paired_fn(batch):
|
| 384 |
+
"""Collate paired (positive, negative) samples."""
|
| 385 |
+
pos_batch = {
|
| 386 |
+
'node_feats': torch.stack([s['pos']['node_feats'] for s in batch]),
|
| 387 |
+
'edge_feats': torch.stack([s['pos']['edge_feats'] for s in batch]),
|
| 388 |
+
'node_mask': torch.stack([s['pos']['node_mask'] for s in batch]),
|
| 389 |
+
'label': torch.stack([s['pos']['label'] for s in batch]),
|
| 390 |
+
'contact_energy': torch.stack([s['pos']['contact_energy'] for s in batch]),
|
| 391 |
+
}
|
| 392 |
+
neg_batch = {
|
| 393 |
+
'node_feats': torch.stack([s['neg']['node_feats'] for s in batch]),
|
| 394 |
+
'edge_feats': torch.stack([s['neg']['edge_feats'] for s in batch]),
|
| 395 |
+
'node_mask': torch.stack([s['neg']['node_mask'] for s in batch]),
|
| 396 |
+
'label': torch.stack([s['neg']['label'] for s in batch]),
|
| 397 |
+
'contact_energy': torch.stack([s['neg']['contact_energy'] for s in batch]),
|
| 398 |
+
}
|
| 399 |
+
# ESM features (handle mixed availability)
|
| 400 |
+
has_pos_esm = any('esm_feats' in s['pos'] for s in batch)
|
| 401 |
+
if has_pos_esm:
|
| 402 |
+
def _stack_esm(batch_list, key):
|
| 403 |
+
esm_list = []
|
| 404 |
+
ref = next((x[key]['esm_feats'] for x in batch_list if 'esm_feats' in x[key]), None)
|
| 405 |
+
for s in batch_list:
|
| 406 |
+
if 'esm_feats' in s[key]:
|
| 407 |
+
esm_list.append(s[key]['esm_feats'])
|
| 408 |
+
else:
|
| 409 |
+
esm_list.append(torch.zeros_like(ref))
|
| 410 |
+
return torch.stack(esm_list)
|
| 411 |
+
pos_batch['esm_feats'] = _stack_esm(batch, 'pos')
|
| 412 |
+
neg_batch['esm_feats'] = _stack_esm(batch, 'neg')
|
| 413 |
+
return {'pos': pos_batch, 'neg': neg_batch}
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
class PathAwareDatasetPaired(Dataset):
|
| 417 |
+
"""
|
| 418 |
+
Paired dataset with transition-path frames for path-aware Phase 2 training.
|
| 419 |
+
|
| 420 |
+
Extends TwoStateDatasetPaired: each sample returns (positive, negative, path_frames)
|
| 421 |
+
where path_frames is a list of prepared graph dicts for intermediate conformations
|
| 422 |
+
stored in the positive sample's 'path_graphs' field.
|
| 423 |
+
"""
|
| 424 |
+
|
| 425 |
+
def __init__(self, data_path: str, max_nodes: int = 128, augment: bool = False):
|
| 426 |
+
with open(data_path, 'rb') as f:
|
| 427 |
+
samples = pickle.load(f)
|
| 428 |
+
self.max_nodes = max_nodes
|
| 429 |
+
self.augment = augment
|
| 430 |
+
|
| 431 |
+
from collections import defaultdict
|
| 432 |
+
by_pdb = defaultdict(lambda: {'positive': [], 'negative': [], 'decoy': []})
|
| 433 |
+
for s in samples:
|
| 434 |
+
pdb = s['pdb']
|
| 435 |
+
t = s['type']
|
| 436 |
+
if t == 'positive':
|
| 437 |
+
by_pdb[pdb]['positive'].append(s)
|
| 438 |
+
elif t.startswith('negative'):
|
| 439 |
+
by_pdb[pdb]['negative'].append(s)
|
| 440 |
+
elif t.startswith('decoy'):
|
| 441 |
+
by_pdb[pdb]['decoy'].append(s)
|
| 442 |
+
|
| 443 |
+
self.pairs = []
|
| 444 |
+
for pdb, groups in by_pdb.items():
|
| 445 |
+
if len(groups['positive']) > 0 and len(groups['negative']) > 0:
|
| 446 |
+
for pos in groups['positive']:
|
| 447 |
+
for neg in groups['negative']:
|
| 448 |
+
self.pairs.append((pos, neg))
|
| 449 |
+
if len(groups['positive']) > 0 and len(groups['decoy']) > 0:
|
| 450 |
+
large_decoys = [s for s in groups['decoy'] if 'rmsd' in s['type'] and
|
| 451 |
+
float(s['type'].replace('decoy_rmsd', '')) > 4.0]
|
| 452 |
+
for pos in groups['positive']:
|
| 453 |
+
for neg in large_decoys[:3]:
|
| 454 |
+
self.pairs.append((pos, neg))
|
| 455 |
+
|
| 456 |
+
def _prepare(self, sample):
|
| 457 |
+
graph = sample['graph']
|
| 458 |
+
node_feats = graph['node_feats']
|
| 459 |
+
edge_feats = graph['edge_feats']
|
| 460 |
+
node_mask = graph['node_mask']
|
| 461 |
+
N = len(node_feats)
|
| 462 |
+
node_dim = node_feats.shape[-1]
|
| 463 |
+
edge_dim = edge_feats.shape[-1]
|
| 464 |
+
|
| 465 |
+
node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
|
| 466 |
+
edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
|
| 467 |
+
node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
|
| 468 |
+
|
| 469 |
+
n = min(N, self.max_nodes)
|
| 470 |
+
node_feats_pad[:n] = node_feats[:n]
|
| 471 |
+
edge_feats_pad[:n, :n] = edge_feats[:n, :n]
|
| 472 |
+
node_mask_pad[:n] = node_mask[:n]
|
| 473 |
+
|
| 474 |
+
return {
|
| 475 |
+
'node_feats': torch.from_numpy(node_feats_pad),
|
| 476 |
+
'edge_feats': torch.from_numpy(edge_feats_pad),
|
| 477 |
+
'node_mask': torch.from_numpy(node_mask_pad),
|
| 478 |
+
'label': torch.tensor(sample.get('label', 0.0), dtype=torch.float32),
|
| 479 |
+
'contact_energy': torch.tensor(
|
| 480 |
+
sample.get('contact_energy', 0.5), dtype=torch.float32
|
| 481 |
+
),
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
def _prepare_graph_only(self, path_entry):
|
| 485 |
+
"""Prepare a path frame graph (no label/contact_energy needed)."""
|
| 486 |
+
graph = path_entry['graph']
|
| 487 |
+
node_feats = graph['node_feats']
|
| 488 |
+
edge_feats = graph['edge_feats']
|
| 489 |
+
node_mask = graph['node_mask']
|
| 490 |
+
N = len(node_feats)
|
| 491 |
+
node_dim = node_feats.shape[-1]
|
| 492 |
+
edge_dim = edge_feats.shape[-1]
|
| 493 |
+
|
| 494 |
+
node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
|
| 495 |
+
edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
|
| 496 |
+
node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
|
| 497 |
+
|
| 498 |
+
n = min(N, self.max_nodes)
|
| 499 |
+
node_feats_pad[:n] = node_feats[:n]
|
| 500 |
+
edge_feats_pad[:n, :n] = edge_feats[:n, :n]
|
| 501 |
+
node_mask_pad[:n] = node_mask[:n]
|
| 502 |
+
|
| 503 |
+
return {
|
| 504 |
+
'node_feats': torch.from_numpy(node_feats_pad),
|
| 505 |
+
'edge_feats': torch.from_numpy(edge_feats_pad),
|
| 506 |
+
'node_mask': torch.from_numpy(node_mask_pad),
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
def __len__(self):
|
| 510 |
+
return len(self.pairs)
|
| 511 |
+
|
| 512 |
+
def __getitem__(self, idx):
|
| 513 |
+
pos_sample, neg_sample = self.pairs[idx]
|
| 514 |
+
result = {
|
| 515 |
+
'pos': self._prepare(pos_sample),
|
| 516 |
+
'neg': self._prepare(neg_sample),
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
# Prepare path frames if available
|
| 520 |
+
path_graphs = pos_sample.get('path_graphs', [])
|
| 521 |
+
prepared_paths = []
|
| 522 |
+
path_taus = []
|
| 523 |
+
for pg in path_graphs:
|
| 524 |
+
prepared_paths.append(self._prepare_graph_only(pg))
|
| 525 |
+
path_taus.append(pg['tau'])
|
| 526 |
+
|
| 527 |
+
result['path'] = prepared_paths
|
| 528 |
+
result['path_taus'] = path_taus
|
| 529 |
+
|
| 530 |
+
return result
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def collate_path_paired_fn(batch):
|
| 534 |
+
"""Collate paired samples with variable-length path frames."""
|
| 535 |
+
pos_batch = {
|
| 536 |
+
'node_feats': torch.stack([s['pos']['node_feats'] for s in batch]),
|
| 537 |
+
'edge_feats': torch.stack([s['pos']['edge_feats'] for s in batch]),
|
| 538 |
+
'node_mask': torch.stack([s['pos']['node_mask'] for s in batch]),
|
| 539 |
+
'label': torch.stack([s['pos']['label'] for s in batch]),
|
| 540 |
+
'contact_energy': torch.stack([s['pos']['contact_energy'] for s in batch]),
|
| 541 |
+
}
|
| 542 |
+
neg_batch = {
|
| 543 |
+
'node_feats': torch.stack([s['neg']['node_feats'] for s in batch]),
|
| 544 |
+
'edge_feats': torch.stack([s['neg']['edge_feats'] for s in batch]),
|
| 545 |
+
'node_mask': torch.stack([s['neg']['node_mask'] for s in batch]),
|
| 546 |
+
'label': torch.stack([s['neg']['label'] for s in batch]),
|
| 547 |
+
'contact_energy': torch.stack([s['neg']['contact_energy'] for s in batch]),
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
# Collate path frames: find max K across batch, pad shorter ones
|
| 551 |
+
max_k = max((len(s['path']) for s in batch), default=0)
|
| 552 |
+
path_batches = []
|
| 553 |
+
path_taus = []
|
| 554 |
+
|
| 555 |
+
if max_k > 0:
|
| 556 |
+
# Build a zero-filled placeholder for padding (graph-only keys)
|
| 557 |
+
ref = batch[0]['path'][0] if batch[0]['path'] else batch[0]['pos']
|
| 558 |
+
zero_placeholder = {
|
| 559 |
+
'node_feats': torch.zeros_like(ref['node_feats']),
|
| 560 |
+
'edge_feats': torch.zeros_like(ref['edge_feats']),
|
| 561 |
+
'node_mask': torch.zeros_like(ref['node_mask']),
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
for k_idx in range(max_k):
|
| 565 |
+
frames_at_k = []
|
| 566 |
+
taus_at_k = []
|
| 567 |
+
for s in batch:
|
| 568 |
+
if k_idx < len(s['path']):
|
| 569 |
+
frames_at_k.append(s['path'][k_idx])
|
| 570 |
+
taus_at_k.append(s['path_taus'][k_idx])
|
| 571 |
+
else:
|
| 572 |
+
frames_at_k.append(zero_placeholder)
|
| 573 |
+
taus_at_k.append(1.0)
|
| 574 |
+
|
| 575 |
+
path_batches.append({
|
| 576 |
+
'node_feats': torch.stack([f['node_feats'] for f in frames_at_k]),
|
| 577 |
+
'edge_feats': torch.stack([f['edge_feats'] for f in frames_at_k]),
|
| 578 |
+
'node_mask': torch.stack([f['node_mask'] for f in frames_at_k]),
|
| 579 |
+
})
|
| 580 |
+
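# Take tau from the first sample that has a real frame at this index,
# so a padded placeholder (tau = 1.0) from batch[0] is not reported.
path_taus.append(next(t for t, s in zip(taus_at_k, batch) if k_idx < len(s['path'])))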
|
| 581 |
+
|
| 582 |
+
result = {'pos': pos_batch, 'neg': neg_batch}
|
| 583 |
+
if path_batches:
|
| 584 |
+
result['path'] = path_batches
|
| 585 |
+
result['path_taus'] = path_taus
|
| 586 |
+
return result
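The frame-major padding performed by `collate_path_paired_fn` can be shown with plain lists. This is a small sketch under assumed inputs: each sample is reduced to its list of tau values, and a `valid` flag plays the role that the placeholder's all-zero `node_mask` plays in the source.

```python
# Per-sample tau lists of different lengths (hypothetical values)
paths = [[0.25, 0.5, 0.75], [0.5], []]

max_k = max((len(p) for p in paths), default=0)
padded, valid = [], []
for k in range(max_k):
    # One entry per sample at frame index k; short paths are padded with tau = 1.0
    padded.append([p[k] if k < len(p) else 1.0 for p in paths])
    valid.append([k < len(p) for p in paths])
```

The result is indexed as `padded[k][b]`, i.e. one batch per frame index, so a model can score frame `k` for the whole batch in a single call and use `valid[k]` to ignore padded entries.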
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
class MultiTargetDataset(Dataset):
|
| 590 |
+
"""
|
| 591 |
+
Pooled dataset combining samples from multiple targets.
|
| 592 |
+
Supports balanced sampling across targets.
|
| 593 |
+
"""
|
| 594 |
+
|
| 595 |
+
def __init__(self, data_paths: list, max_nodes: int = 128, augment: bool = False,
|
| 596 |
+
balance: bool = True, rosetta_dir: str = None, label_source: str = 'dockq',
|
| 597 |
+
esm_dir: str = None, binder_dropout: float = 0.0):
|
| 598 |
+
"""
|
| 599 |
+
Args:
|
| 600 |
+
data_paths: list of (target_name, pkl_path) tuples
|
| 601 |
+
max_nodes: max interface graph size
|
| 602 |
+
augment: apply noise augmentation
|
| 603 |
+
balance: if True, oversample smaller targets to balance
|
| 604 |
+
rosetta_dir: directory containing Rosetta label JSONs
|
| 605 |
+
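label_source: 'dockq', 'rosetta', or 'combined'
esm_dir: directory with ESM-2 embeddings; if unset, no 'esm_feats' are returned
binder_dropout: per-sample probability of masking binder sequence features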
|
| 606 |
+
"""
|
| 607 |
+
self.max_nodes = max_nodes
|
| 608 |
+
self.augment = augment
|
| 609 |
+
self.esm_dir = esm_dir
|
| 610 |
+
self.binder_dropout = binder_dropout
|
| 611 |
+
|
| 612 |
+
# Load all samples with target labels
|
| 613 |
+
self.samples = []
|
| 614 |
+
self.target_indices = {} # target_name -> list of indices
|
| 615 |
+
|
| 616 |
+
for target_name, path in data_paths:
|
| 617 |
+
if not os.path.exists(path):
|
| 618 |
+
continue
|
| 619 |
+
with open(path, 'rb') as f:
|
| 620 |
+
target_samples = pickle.load(f)
|
| 621 |
+
|
| 622 |
+
# Apply Rosetta labels if requested
|
| 623 |
+
if label_source != 'dockq' and rosetta_dir:
|
| 624 |
+
rl = load_rosetta_labels(rosetta_dir, target_name)
|
| 625 |
+
if rl:
|
| 626 |
+
apply_rosetta_labels(target_samples, rl, label_source)
|
| 627 |
+
|
| 628 |
+
start_idx = len(self.samples)
|
| 629 |
+
for s in target_samples:
|
| 630 |
+
s['_target'] = target_name
|
| 631 |
+
self.samples.append(s)
|
| 632 |
+
end_idx = len(self.samples)
|
| 633 |
+
self.target_indices[target_name] = list(range(start_idx, end_idx))
|
| 634 |
+
|
| 635 |
+
# Build balanced sampling weights
|
| 636 |
+
if balance and len(self.target_indices) > 1:
|
| 637 |
+
non_empty = {k: v for k, v in self.target_indices.items() if len(v) > 0}
|
| 638 |
+
max_count = max(len(idxs) for idxs in non_empty.values()) if non_empty else 1
|
| 639 |
+
self.weights = np.zeros(len(self.samples))
|
| 640 |
+
for target_name, idxs in self.target_indices.items():
|
| 641 |
+
if len(idxs) == 0:
|
| 642 |
+
continue
|
| 643 |
+
weight = max_count / len(idxs)
|
| 644 |
+
for i in idxs:
|
| 645 |
+
self.weights[i] = weight
|
| 646 |
+
self.weights /= self.weights.sum()
|
| 647 |
+
else:
|
| 648 |
+
self.weights = None
|
| 649 |
+
|
| 650 |
+
def __len__(self):
|
| 651 |
+
return len(self.samples)
|
| 652 |
+
|
| 653 |
+
def __getitem__(self, idx):
|
| 654 |
+
sample = self.samples[idx]
|
| 655 |
+
graph = sample['graph']
|
| 656 |
+
node_feats = graph['node_feats']
|
| 657 |
+
edge_feats = graph['edge_feats']
|
| 658 |
+
node_mask = graph['node_mask']
|
| 659 |
+
N = len(node_feats)
|
| 660 |
+
node_dim = node_feats.shape[-1]
|
| 661 |
+
edge_dim = edge_feats.shape[-1]
|
| 662 |
+
|
| 663 |
+
node_feats_pad = np.zeros((self.max_nodes, node_dim), dtype=np.float32)
|
| 664 |
+
edge_feats_pad = np.zeros((self.max_nodes, self.max_nodes, edge_dim), dtype=np.float32)
|
| 665 |
+
node_mask_pad = np.zeros(self.max_nodes, dtype=bool)
|
| 666 |
+
|
| 667 |
+
n = min(N, self.max_nodes)
|
| 668 |
+
node_feats_pad[:n] = node_feats[:n]
|
| 669 |
+
edge_feats_pad[:n, :n] = edge_feats[:n, :n]
|
| 670 |
+
node_mask_pad[:n] = node_mask[:n]
|
| 671 |
+
|
| 672 |
+
if self.augment:
|
| 673 |
+
noise = np.random.randn(*node_feats_pad.shape) * 0.01
|
| 674 |
+
node_feats_pad = node_feats_pad + noise.astype(np.float32)
|
| 675 |
+
|
| 676 |
+
# Binder-dropout: simulate backbone-only designs
|
| 677 |
+
apply_binder_drop = (self.binder_dropout > 0
|
| 678 |
+
and np.random.rand() < self.binder_dropout)
|
| 679 |
+
if apply_binder_drop:
|
| 680 |
+
n_rec = graph.get('n_rec', N // 2)
|
| 681 |
+
node_feats_pad[n_rec:N, :21] = 0.0
|
| 682 |
+
node_feats_pad[n_rec:N, 20] = 1.0 # UNK
|
| 683 |
+
node_feats_pad[n_rec:N, 27:31] = 0.0
|
| 684 |
+
|
| 685 |
+
result = {
|
| 686 |
+
'node_feats': torch.from_numpy(node_feats_pad),
|
| 687 |
+
'edge_feats': torch.from_numpy(edge_feats_pad),
|
| 688 |
+
'node_mask': torch.from_numpy(node_mask_pad),
|
| 689 |
+
'label': torch.tensor(sample['label'], dtype=torch.float32),
|
| 690 |
+
'type': sample['type'],
|
| 691 |
+
'pdb': sample['pdb'],
|
| 692 |
+
'target': sample.get('_target', 'unknown'),
|
| 693 |
+
}
|
| 694 |
+
|
| 695 |
+
# ESM-2 features (zero-fill if unavailable)
|
| 696 |
+
if self.esm_dir:
|
| 697 |
+
target_name = sample.get('_target', 'unknown')
|
| 698 |
+
esm = load_esm_for_sample(sample, self.esm_dir, target_name, self.max_nodes)
|
| 699 |
+
if esm is not None:
|
| 700 |
+
esm_feats = esm
|
| 701 |
+
else:
|
| 702 |
+
esm_feats = np.zeros((self.max_nodes, 1280), dtype=np.float32)
|
| 703 |
+
if apply_binder_drop:
|
| 704 |
+
n_rec = graph.get('n_rec', N // 2)
|
| 705 |
+
n_binder = graph.get('n_binder', N - n_rec)
|
| 706 |
+
esm_feats[n_rec:n_rec + n_binder] = 0.0
|
| 707 |
+
result['esm_feats'] = torch.from_numpy(esm_feats)
|
| 708 |
+
|
| 709 |
+
return result
|
| 710 |
+
|
| 711 |
+
@staticmethod
|
| 712 |
+
def get_pooled_dataloaders(data_dir, targets, batch_size=16, max_nodes=128,
|
| 713 |
+
num_workers=4, paired=False,
|
| 714 |
+
rosetta_dir=None, label_source='dockq',
|
| 715 |
+
esm_dir=None, binder_dropout=0.0):
|
| 716 |
+
"""Build pooled dataloaders from multiple targets.
|
| 717 |
+
|
| 718 |
+
Args:
|
| 719 |
+
data_dir: root data directory
|
| 720 |
+
targets: list of target names
|
| 721 |
+
batch_size: batch size
|
| 722 |
+
max_nodes: max interface nodes
|
| 723 |
+
num_workers: dataloader workers
|
| 724 |
+
paired: if True, build paired dataloaders for Phase 2
|
| 725 |
+
rosetta_dir: directory with Rosetta label JSONs
|
| 726 |
+
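label_source: 'dockq', 'rosetta', or 'combined'
esm_dir: directory with ESM-2 embeddings (preloaded into the cache if given)
binder_dropout: binder-dropout probability, applied to the train split only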
|
| 727 |
+
"""
|
| 728 |
+
from torch.utils.data import WeightedRandomSampler
|
| 729 |
+
|
| 730 |
+
# Preload ESM embeddings into global cache before creating datasets/workers
|
| 731 |
+
if esm_dir:
|
| 732 |
+
n_loaded = preload_esm_cache(esm_dir, targets)
|
| 733 |
+
|
| 734 |
+
loaders = {}
|
| 735 |
+
for split in ['train', 'val', 'test']:
|
| 736 |
+
data_paths = []
|
| 737 |
+
for target in targets:
|
| 738 |
+
path = os.path.join(data_dir, target, f"{split}.pkl")
|
| 739 |
+
if os.path.exists(path):
|
| 740 |
+
data_paths.append((target, path))
|
| 741 |
+
|
| 742 |
+
if not data_paths:
|
| 743 |
+
continue
|
| 744 |
+
|
| 745 |
+
augment = (split == 'train')
|
| 746 |
+
bd = binder_dropout if split == 'train' else 0.0
|
| 747 |
+
|
| 748 |
+
if paired:
|
| 749 |
+
# For paired mode, concatenate paired datasets
|
| 750 |
+
all_pairs = []
|
| 751 |
+
for target, path in data_paths:
|
| 752 |
+
ds = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=augment,
|
| 753 |
+
esm_dir=esm_dir, target_name=target,
|
| 754 |
+
binder_dropout=bd)
|
| 755 |
+
all_pairs.append(ds)
|
| 756 |
+
|
| 757 |
+
if not all_pairs:
|
| 758 |
+
continue
|
| 759 |
+
|
| 760 |
+
# Use ConcatDataset
|
| 761 |
+
from torch.utils.data import ConcatDataset
|
| 762 |
+
concat_ds = ConcatDataset(all_pairs)
|
| 763 |
+
p_batch = min(batch_size, max(1, len(concat_ds) // 2))
|
| 764 |
+
loaders[split] = DataLoader(
|
| 765 |
+
concat_ds, batch_size=p_batch,
|
| 766 |
+
shuffle=(split == 'train'),
|
| 767 |
+
num_workers=num_workers,
|
| 768 |
+
collate_fn=collate_paired_fn,
|
| 769 |
+
pin_memory=True,
|
| 770 |
+
)
|
| 771 |
+
else:
|
| 772 |
+
dataset = MultiTargetDataset(data_paths, max_nodes=max_nodes,
|
| 773 |
+
augment=augment, balance=(split == 'train'),
|
| 774 |
+
rosetta_dir=rosetta_dir, label_source=label_source,
|
| 775 |
+
esm_dir=esm_dir, binder_dropout=bd)
|
| 776 |
+
|
| 777 |
+
sampler = None
|
| 778 |
+
shuffle = (split == 'train')
|
| 779 |
+
if split == 'train' and dataset.weights is not None:
|
| 780 |
+
sampler = WeightedRandomSampler(
|
| 781 |
+
weights=dataset.weights,
|
| 782 |
+
num_samples=len(dataset),
|
| 783 |
+
replacement=True
|
| 784 |
+
)
|
| 785 |
+
shuffle = False
|
| 786 |
+
|
| 787 |
+
loaders[split] = DataLoader(
|
| 788 |
+
dataset, batch_size=batch_size,
|
| 789 |
+
shuffle=shuffle, sampler=sampler,
|
| 790 |
+
num_workers=num_workers,
|
| 791 |
+
collate_fn=collate_fn,
|
| 792 |
+
pin_memory=True,
|
| 793 |
+
drop_last=(split == 'train' and len(dataset) > batch_size),
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
return loaders
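The balanced-sampling weights built in `MultiTargetDataset.__init__` give every non-empty target the same total probability mass. A minimal NumPy sketch with two made-up targets of unequal size illustrates this; the target names and counts are assumptions for the example only.

```python
import numpy as np

# Hypothetical index layout: 6 samples for one target, 2 for another
target_indices = {"targetA": list(range(0, 6)), "targetB": list(range(6, 8))}
n_samples = 8

max_count = max(len(idxs) for idxs in target_indices.values())
weights = np.zeros(n_samples)
for name, idxs in target_indices.items():
    weights[idxs] = max_count / len(idxs)  # smaller targets get larger per-sample weight
weights /= weights.sum()

# Total sampling probability assigned to each target
per_target_mass = {name: float(weights[idxs].sum())
                   for name, idxs in target_indices.items()}
```

Passed to a sampler that draws with replacement, such weights mean samples from the smaller target are repeated within an epoch rather than new data being created.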
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def get_dataloaders(data_dir: str, target: str, batch_size: int = 16,
                    max_nodes: int = 128, num_workers: int = 4,
                    paired: bool = False, esm_dir: str = None,
                    binder_dropout: float = 0.0):
    """Build train/val/test dataloaders for a given target."""
    loaders = {}
    for split in ['train', 'val', 'test']:
        path = os.path.join(data_dir, target, f"{split}.pkl")
        if not os.path.exists(path):
            continue

        augment = (split == 'train')
        bd = binder_dropout if split == 'train' else 0.0
        if paired and split == 'train':
            dataset = TwoStateDatasetPaired(path, max_nodes=max_nodes, augment=augment,
                                            esm_dir=esm_dir, target_name=target,
                                            binder_dropout=bd)
            collate = collate_paired_fn
        else:
            dataset = TwoStateComplexDataset(path, max_nodes=max_nodes, augment=augment,
                                             esm_dir=esm_dir, target_name=target,
                                             binder_dropout=bd)
            collate = collate_fn

        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            collate_fn=collate,
            pin_memory=True,
            drop_last=(split == 'train' and len(dataset) > batch_size),
        )
    return loaders
|
code/models/__init__.py
ADDED
|
File without changes
|
code/models/differentiable_features.py
ADDED
|
@@ -0,0 +1,622 @@
"""
Differentiable feature extraction for Q_theta guidance.

This module re-implements the key feature extraction functions from features.py
and pdb_utils.py using PyTorch operations, enabling gradient computation through
Q_theta with respect to backbone coordinates.

The differentiable path:
    coords (N,4,3) → backbone frames → torsions, distances, directions, rotations
    → node_feats, edge_feats → Q_theta → score → backward() → ∇coords

Non-differentiable features (AA one-hot, chain_id, seq_sep, same_chain) are
treated as constants.
"""

import os
import torch
import torch.nn.functional as F
import numpy as np


# ── Differentiable backbone frame computation ────────────────────────────────

def compute_backbone_frames_torch(coords, mask):
    """
    Compute SE(3)-equivariant backbone frames from N, CA, C atoms.
    Differentiable w.r.t. coords.

    Args:
        coords: [N, 4, 3] backbone coords (N, CA, C, O); requires_grad=True for binder
        mask: [N] bool tensor

    Returns:
        origins: [N, 3] = CA positions
        rotations: [N, 3, 3] = rotation matrices (columns are x, y, z axes)
    """
    N_res = coords.shape[0]
    device = coords.device

    origins = coords[:, 1, :]  # CA positions [N, 3]
    rotations = torch.eye(3, device=device, dtype=coords.dtype).unsqueeze(0).expand(N_res, -1, -1).clone()

    ca = coords[:, 1, :]      # [N, 3]
    n_atom = coords[:, 0, :]  # [N, 3]
    c_atom = coords[:, 2, :]  # [N, 3]

    # z-axis: CA -> C
    z = c_atom - ca  # [N, 3]
    z_norm = torch.norm(z, dim=-1, keepdim=True).clamp(min=1e-6)  # [N, 1]
    z = z / z_norm  # [N, 3]

    # y-axis: CA -> N, orthogonalized against z
    y = n_atom - ca  # [N, 3]
    y_proj = (y * z).sum(dim=-1, keepdim=True)  # [N, 1]
    y = y - y_proj * z  # [N, 3]
    y_norm = torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-6)  # [N, 1]
    y = y / y_norm  # [N, 3]

    # x-axis: y cross z
    x = torch.cross(y, z, dim=-1)  # [N, 3]

    # Stack columns: [N, 3, 3] where columns are x, y, z
    rot = torch.stack([x, y, z], dim=-1)  # [N, 3, 3]

    # Apply mask: identity for masked residues
    mask_f = mask.float().unsqueeze(-1).unsqueeze(-1)  # [N, 1, 1]
    eye = torch.eye(3, device=device, dtype=coords.dtype).unsqueeze(0)  # [1, 3, 3]
    rotations = rot * mask_f + eye * (1 - mask_f)

    return origins, rotations

# ── Differentiable torsion angle computation ─────────────────────────────────
|
| 74 |
+
|
| 75 |
+
def _dihedral_torch(p0, p1, p2, p3):
|
| 76 |
+
"""
|
| 77 |
+
Compute dihedral angle for batches of 4 points. Returns sin, cos.
|
| 78 |
+
Differentiable w.r.t. all inputs.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
p0, p1, p2, p3: [N, 3] tensors
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
sin_angle: [N]
|
| 85 |
+
cos_angle: [N]
|
| 86 |
+
"""
|
| 87 |
+
b1 = p1 - p0 # [N, 3]
|
| 88 |
+
b2 = p2 - p1
|
| 89 |
+
b3 = p3 - p2
|
| 90 |
+
|
| 91 |
+
n1 = torch.cross(b1, b2, dim=-1) # [N, 3]
|
| 92 |
+
n2 = torch.cross(b2, b3, dim=-1)
|
| 93 |
+
|
| 94 |
+
n1_norm = torch.norm(n1, dim=-1, keepdim=True).clamp(min=1e-8)
|
| 95 |
+
n2_norm = torch.norm(n2, dim=-1, keepdim=True).clamp(min=1e-8)
|
| 96 |
+
n1 = n1 / n1_norm
|
| 97 |
+
n2 = n2 / n2_norm
|
| 98 |
+
|
| 99 |
+
b2_norm = torch.norm(b2, dim=-1, keepdim=True).clamp(min=1e-8)
|
| 100 |
+
m1 = torch.cross(n1, b2 / b2_norm, dim=-1) # [N, 3]
|
| 101 |
+
|
| 102 |
+
cos_angle = (n1 * n2).sum(dim=-1) # [N]
|
| 103 |
+
sin_angle = (m1 * n2).sum(dim=-1) # [N]
|
| 104 |
+
|
| 105 |
+
return sin_angle, cos_angle
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def compute_torsion_angles_torch(coords, mask):
|
| 109 |
+
"""
|
| 110 |
+
Compute backbone torsion angles (phi, psi, omega) as sin/cos pairs.
|
| 111 |
+
Differentiable w.r.t. coords.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
coords: [N, 4, 3] backbone coords (N, CA, C, O)
|
| 115 |
+
mask: [N] bool tensor
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
torsions: [N, 6] (sin_phi, cos_phi, sin_psi, cos_psi, sin_omega, cos_omega)
|
| 119 |
+
"""
|
| 120 |
+
N = coords.shape[0]
|
| 121 |
+
device = coords.device
|
| 122 |
+
torsions = torch.zeros(N, 6, device=device, dtype=coords.dtype)
|
| 123 |
+
|
| 124 |
+
if N < 2:
|
| 125 |
+
return torsions
|
| 126 |
+
|
| 127 |
+
n_atoms = coords[:, 0, :] # N atoms [N, 3]
|
| 128 |
+
ca_atoms = coords[:, 1, :] # CA atoms
|
| 129 |
+
c_atoms = coords[:, 2, :] # C atoms
|
| 130 |
+
|
| 131 |
+
# Phi: C_{i-1} - N_i - CA_i - C_i (for i >= 1)
|
| 132 |
+
if N > 1:
|
| 133 |
+
phi_mask = mask[1:] & mask[:-1] # [N-1]
|
| 134 |
+
sin_phi, cos_phi = _dihedral_torch(
|
| 135 |
+
c_atoms[:-1], # C_{i-1}
|
| 136 |
+
n_atoms[1:], # N_i
|
| 137 |
+
ca_atoms[1:], # CA_i
|
| 138 |
+
c_atoms[1:] # C_i
|
| 139 |
+
)
|
| 140 |
+
torsions[1:, 0] = sin_phi * phi_mask.float()
|
| 141 |
+
torsions[1:, 1] = cos_phi * phi_mask.float()
|
| 142 |
+
|
| 143 |
+
# Psi: N_i - CA_i - C_i - N_{i+1} (for i < N-1)
|
| 144 |
+
if N > 1:
|
| 145 |
+
psi_mask = mask[:-1] & mask[1:] # [N-1]
|
| 146 |
+
sin_psi, cos_psi = _dihedral_torch(
|
| 147 |
+
n_atoms[:-1], # N_i
|
| 148 |
+
ca_atoms[:-1], # CA_i
|
| 149 |
+
c_atoms[:-1], # C_i
|
| 150 |
+
n_atoms[1:] # N_{i+1}
|
| 151 |
+
)
|
| 152 |
+
torsions[:-1, 2] = sin_psi * psi_mask.float()
|
| 153 |
+
torsions[:-1, 3] = cos_psi * psi_mask.float()
|
| 154 |
+
|
| 155 |
+
# Omega: CA_{i-1} - C_{i-1} - N_i - CA_i (for i >= 1)
|
| 156 |
+
if N > 1:
|
| 157 |
+
omega_mask = mask[1:] & mask[:-1] # [N-1]
|
| 158 |
+
sin_omega, cos_omega = _dihedral_torch(
|
| 159 |
+
ca_atoms[:-1], # CA_{i-1}
|
| 160 |
+
c_atoms[:-1], # C_{i-1}
|
| 161 |
+
n_atoms[1:], # N_i
|
| 162 |
+
ca_atoms[1:] # CA_i
|
| 163 |
+
)
|
| 164 |
+
torsions[1:, 4] = sin_omega * omega_mask.float()
|
| 165 |
+
torsions[1:, 5] = cos_omega * omega_mask.float()
|
| 166 |
+
|
| 167 |
+
return torsions
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ── Differentiable RBF distance encoding ─────────────────────────────────────
|
| 171 |
+
|
| 172 |
+
def rbf_encode_torch(distances, d_min=0.0, d_max=20.0, n_bins=16):
|
| 173 |
+
"""
|
| 174 |
+
RBF encoding of distances using Gaussian basis functions.
|
| 175 |
+
Differentiable w.r.t. distances.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
distances: [...] tensor
|
| 179 |
+
Returns:
|
| 180 |
+
encoded: [..., n_bins] tensor
|
| 181 |
+
"""
|
| 182 |
+
centers = torch.linspace(d_min, d_max, n_bins, device=distances.device, dtype=distances.dtype)
|
| 183 |
+
sigma = (d_max - d_min) / (n_bins - 1)
|
| 184 |
+
return torch.exp(-((distances.unsqueeze(-1) - centers) ** 2) / (2 * sigma ** 2))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ── Differentiable edge feature computation ──────────────────────────────────
|
| 188 |
+
|
| 189 |
+
def compute_edge_features_torch(origins, rotations, seq_idx, chain_ids, mask,
|
| 190 |
+
n_bins_rbf=16, n_bins_sep=8, max_sep=32):
|
| 191 |
+
"""
|
| 192 |
+
Compute SE(3)-invariant edge features between all residue pairs.
|
| 193 |
+
Differentiable w.r.t. origins and rotations (which derive from coords).
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
origins: [N, 3] CA positions
|
| 197 |
+
rotations: [N, 3, 3] backbone frame rotations
|
| 198 |
+
seq_idx: [N] int tensor — sequence indices (non-differentiable)
|
| 199 |
+
chain_ids: [N] int tensor — chain labels (non-differentiable)
|
| 200 |
+
mask: [N] bool tensor
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
edge_feats: [N, N, 37]
|
| 204 |
+
"""
|
| 205 |
+
N = origins.shape[0]
|
| 206 |
+
device = origins.device
|
| 207 |
+
dtype = origins.dtype
|
| 208 |
+
|
| 209 |
+
# --- Distance features (differentiable) ---
|
| 210 |
+
diff = origins.unsqueeze(0) - origins.unsqueeze(1) # [N, N, 3], diff[i, j] = ca_j - ca_i (matches features.py)
|
| 211 |
+
dist = torch.norm(diff, dim=-1).clamp(min=1e-8) # [N, N]
|
| 212 |
+
dist_rbf = rbf_encode_torch(dist, d_min=0., d_max=20., n_bins=n_bins_rbf) # [N, N, 16]
|
| 213 |
+
|
| 214 |
+
# --- Direction in local frame (differentiable) ---
|
| 215 |
+
unit_diff = diff / dist.unsqueeze(-1) # [N, N, 3]
|
| 216 |
+
# Intended: local_dir[i,j] = R_i^T @ (ca_j - ca_i) / dist. NOTE: this einsum computes R_i @ unit_diff (no transpose), same as features.py
|
| 217 |
+
# rotations: [N, 3, 3], unit_diff: [N, N, 3]
|
| 218 |
+
local_dir = torch.einsum('ikl,ijl->ijk', rotations, unit_diff) # [N, N, 3]
|
| 219 |
+
|
| 220 |
+
# --- Relative rotation (differentiable) ---
|
| 221 |
+
# Intended: rel_rot[i,j] = R_i^T @ R_j. NOTE: this einsum computes R_i @ R_j (no transpose), same as features.py -> [N, N, 3, 3] -> [N, N, 9]
|
| 222 |
+
rel_rot = torch.einsum('ikl,jlm->ijkm', rotations, rotations) # [N, N, 3, 3]
|
| 223 |
+
rel_rot_flat = rel_rot.reshape(N, N, 9) # [N, N, 9]
|
| 224 |
+
|
| 225 |
+
# --- Sequence separation (non-differentiable, constant) ---
|
| 226 |
+
sep = seq_idx.unsqueeze(0) - seq_idx.unsqueeze(1) # [N, N], sep[i, j] = seq_j - seq_i (matches features.py)
|
| 227 |
+
bins = torch.linspace(-max_sep, max_sep, n_bins_sep + 1, device=device)
|
| 228 |
+
sep_clipped = sep.float().clamp(-max_sep, max_sep)
|
| 229 |
+
# Hard one-hot binning; constant w.r.t. coords, so no soft assignment is needed
|
| 230 |
+
sep_enc = torch.zeros(N, N, n_bins_sep, device=device, dtype=dtype)
|
| 231 |
+
bin_idx = torch.bucketize(sep_clipped, bins, right=True) - 1 # right=True gives [lo, hi) bins, matching np.digitize in features.py
|
| 232 |
+
bin_idx = bin_idx.clamp(0, n_bins_sep - 1)
|
| 233 |
+
# Scatter one-hot
|
| 234 |
+
sep_enc.scatter_(2, bin_idx.unsqueeze(-1).long(), 1.0)
|
| 235 |
+
|
| 236 |
+
# Cross-chain pairs get an all-zero separation encoding
|
| 237 |
+
same_chain = (chain_ids.unsqueeze(1) == chain_ids.unsqueeze(0)) # [N, N]
|
| 238 |
+
cross_chain = ~same_chain
|
| 239 |
+
sep_enc[cross_chain] = 0.0
|
| 240 |
+
|
| 241 |
+
# --- Same chain indicator (non-differentiable, constant) ---
|
| 242 |
+
same_chain_feat = same_chain.float().unsqueeze(-1) # [N, N, 1]
|
| 243 |
+
|
| 244 |
+
# --- Concatenate ---
|
| 245 |
+
edge_feats = torch.cat([
|
| 246 |
+
dist_rbf, # [N, N, 16]
|
| 247 |
+
local_dir, # [N, N, 3]
|
| 248 |
+
rel_rot_flat, # [N, N, 9]
|
| 249 |
+
sep_enc, # [N, N, 8]
|
| 250 |
+
same_chain_feat # [N, N, 1]
|
| 251 |
+
], dim=-1) # [N, N, 37]
|
| 252 |
+
|
| 253 |
+
# Zero out edges involving masked residues
|
| 254 |
+
mask_2d = mask.unsqueeze(1) & mask.unsqueeze(0) # [N, N]
|
| 255 |
+
edge_feats = edge_feats * mask_2d.unsqueeze(-1).float()
|
| 256 |
+
|
| 257 |
+
return edge_feats
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ── Full differentiable interface graph builder ──────────────────────────────
|
| 261 |
+
|
| 262 |
+
def build_differentiable_interface_graph(
|
| 263 |
+
rec_coords, rec_mask, rec_aa_idx, rec_chi,
|
| 264 |
+
binder_coords, binder_mask, binder_aa_idx, binder_chi,
|
| 265 |
+
cutoff=8.0, max_nodes=128
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
Build interface graph with differentiable features w.r.t. binder_coords.
|
| 269 |
+
Receptor coords are treated as constants (they are expected to carry no gradient; nothing is detached here).
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
rec_coords: [N_rec, 4, 3] — receptor backbone coords (constant, no grad)
|
| 273 |
+
rec_mask: [N_rec] bool
|
| 274 |
+
rec_aa_idx: [N_rec] int — amino acid indices (constant)
|
| 275 |
+
rec_chi: [N_rec, 4] — chi1/chi2 sin/cos (constant)
|
| 276 |
+
binder_coords: [N_binder, 4, 3] — binder backbone coords (requires_grad)
|
| 277 |
+
binder_mask: [N_binder] bool
|
| 278 |
+
binder_aa_idx: [N_binder] int — amino acid indices (constant, UNK for designed)
|
| 279 |
+
binder_chi: [N_binder, 4] — chi1/chi2 sin/cos (zeros for backbone-only)
|
| 280 |
+
cutoff: interface distance cutoff (Å)
|
| 281 |
+
max_nodes: maximum total nodes in the graph; each chain is truncated to max_nodes // 2
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
node_feats: [1, N_total, 32] tensor
|
| 285 |
+
edge_feats: [1, N_total, N_total, 37] tensor
|
| 286 |
+
node_mask: [1, N_total] bool tensor
|
| 287 |
+
n_rec: int
|
| 288 |
+
n_binder: int
|
| 289 |
+
Returned as a dict with the keys above, or None if no interface residues are found.
|
| 290 |
+
"""
|
| 291 |
+
device = binder_coords.device
|
| 292 |
+
dtype = binder_coords.dtype
|
| 293 |
+
NUM_AA = 21
|
| 294 |
+
|
| 295 |
+
# ── Find interface residues (differentiable distances but hard threshold) ──
|
| 296 |
+
rec_ca = rec_coords[:, 1, :] # [N_rec, 3]
|
| 297 |
+
binder_ca = binder_coords[:, 1, :] # [N_binder, 3]
|
| 298 |
+
|
| 299 |
+
# Pairwise CA distances
|
| 300 |
+
dist_mat = torch.cdist(rec_ca.unsqueeze(0), binder_ca.unsqueeze(0)).squeeze(0) # [N_rec, N_binder]
|
| 301 |
+
# Mask invalid residues
|
| 302 |
+
dist_mat = dist_mat.clone()
|
| 303 |
+
dist_mat[~rec_mask, :] = float('inf')
|
| 304 |
+
dist_mat[:, ~binder_mask] = float('inf')
|
| 305 |
+
|
| 306 |
+
rec_iface = (dist_mat < cutoff).any(dim=1) # [N_rec]
|
| 307 |
+
binder_iface = (dist_mat < cutoff).any(dim=0) # [N_binder]
|
| 308 |
+
|
| 309 |
+
rec_iface_idx = torch.where(rec_iface)[0]
|
| 310 |
+
binder_iface_idx = torch.where(binder_iface)[0]
|
| 311 |
+
|
| 312 |
+
# Truncate if too many
|
| 313 |
+
if len(rec_iface_idx) > max_nodes // 2:
|
| 314 |
+
rec_iface_idx = rec_iface_idx[:max_nodes // 2]
|
| 315 |
+
if len(binder_iface_idx) > max_nodes // 2:
|
| 316 |
+
binder_iface_idx = binder_iface_idx[:max_nodes // 2]
|
| 317 |
+
|
| 318 |
+
n_rec = len(rec_iface_idx)
|
| 319 |
+
n_binder = len(binder_iface_idx)
|
| 320 |
+
n_total = n_rec + n_binder
|
| 321 |
+
|
| 322 |
+
if n_total == 0:
|
| 323 |
+
return None
|
| 324 |
+
|
| 325 |
+
# ── Extract interface subsets ──
|
| 326 |
+
rec_iface_coords = rec_coords[rec_iface_idx] # [n_rec, 4, 3]
|
| 327 |
+
binder_iface_coords = binder_coords[binder_iface_idx] # [n_binder, 4, 3]
|
| 328 |
+
rec_iface_mask = rec_mask[rec_iface_idx]
|
| 329 |
+
binder_iface_mask = binder_mask[binder_iface_idx]
|
| 330 |
+
|
| 331 |
+
# ── Compute backbone frames (differentiable) ──
|
| 332 |
+
rec_origins, rec_rotations = compute_backbone_frames_torch(rec_iface_coords, rec_iface_mask)
|
| 333 |
+
binder_origins, binder_rotations = compute_backbone_frames_torch(binder_iface_coords, binder_iface_mask)
|
| 334 |
+
|
| 335 |
+
# ── Compute torsion angles (differentiable) ──
|
| 336 |
+
rec_torsion = compute_torsion_angles_torch(rec_iface_coords, rec_iface_mask) # [n_rec, 6]
|
| 337 |
+
binder_torsion = compute_torsion_angles_torch(binder_iface_coords, binder_iface_mask) # [n_binder, 6]
|
| 338 |
+
|
| 339 |
+
# ── Node features ──
|
| 340 |
+
# AA one-hot (non-differentiable constant)
|
| 341 |
+
rec_aa_onehot = F.one_hot(rec_aa_idx[rec_iface_idx].long(), NUM_AA).float() # [n_rec, 21]
|
| 342 |
+
binder_aa_onehot = F.one_hot(binder_aa_idx[binder_iface_idx].long(), NUM_AA).float() # [n_binder, 21]
|
| 343 |
+
|
| 344 |
+
# Chi angles (constant for receptor, zeros for backbone-only binder)
|
| 345 |
+
rec_chi_iface = rec_chi[rec_iface_idx] # [n_rec, 4]
|
| 346 |
+
binder_chi_iface = binder_chi[binder_iface_idx] # [n_binder, 4]
|
| 347 |
+
|
| 348 |
+
# Chain indicator
|
| 349 |
+
rec_chain_feat = torch.zeros(n_rec, 1, device=device, dtype=dtype)
|
| 350 |
+
binder_chain_feat = torch.ones(n_binder, 1, device=device, dtype=dtype)
|
| 351 |
+
|
| 352 |
+
# Concatenate node features: [AA(21) + torsions(6) + chi(4) + chain(1)] = 32
|
| 353 |
+
rec_node = torch.cat([rec_aa_onehot, rec_torsion, rec_chi_iface, rec_chain_feat], dim=-1)
|
| 354 |
+
binder_node = torch.cat([binder_aa_onehot, binder_torsion, binder_chi_iface, binder_chain_feat], dim=-1)
|
| 355 |
+
node_feats = torch.cat([rec_node, binder_node], dim=0) # [N_total, 32]
|
| 356 |
+
node_mask_flat = torch.cat([rec_iface_mask, binder_iface_mask], dim=0) # [N_total]
|
| 357 |
+
|
| 358 |
+
# ── Edge features (differentiable) ──
|
| 359 |
+
all_origins = torch.cat([rec_origins, binder_origins], dim=0) # [N_total, 3]
|
| 360 |
+
all_rotations = torch.cat([rec_rotations, binder_rotations], dim=0) # [N_total, 3, 3]
|
| 361 |
+
|
| 362 |
+
# Sequence indices
|
| 363 |
+
rec_seq_idx = rec_iface_idx
|
| 364 |
+
binder_seq_idx = binder_iface_idx + rec_coords.shape[0]
|
| 365 |
+
all_seq_idx = torch.cat([rec_seq_idx, binder_seq_idx], dim=0)
|
| 366 |
+
|
| 367 |
+
# Chain IDs
|
| 368 |
+
all_chain_ids = torch.cat([
|
| 369 |
+
torch.zeros(n_rec, device=device, dtype=torch.long),
|
| 370 |
+
torch.ones(n_binder, device=device, dtype=torch.long)
|
| 371 |
+
], dim=0)
|
| 372 |
+
|
| 373 |
+
edge_feats = compute_edge_features_torch(
|
| 374 |
+
all_origins, all_rotations, all_seq_idx, all_chain_ids, node_mask_flat
|
| 375 |
+
) # [N_total, N_total, 37]
|
| 376 |
+
|
| 377 |
+
# Add batch dimension
|
| 378 |
+
return {
|
| 379 |
+
'node_feats': node_feats.unsqueeze(0), # [1, N, 32]
|
| 380 |
+
'edge_feats': edge_feats.unsqueeze(0), # [1, N, N, 37]
|
| 381 |
+
'node_mask': node_mask_flat.unsqueeze(0), # [1, N]
|
| 382 |
+
'n_rec': n_rec,
|
| 383 |
+
'n_binder': n_binder,
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# ── Differentiable Q_theta scoring function ──────────────────────────────────
|
| 388 |
+
|
| 389 |
+
class DifferentiableQTheta:
|
| 390 |
+
"""
|
| 391 |
+
Wraps the Q_theta scorer for differentiable scoring w.r.t. binder backbone
|
| 392 |
+
coordinates. Receptor structures are pre-loaded and cached.
|
| 393 |
+
|
| 394 |
+
Usage:
|
| 395 |
+
dq = DifferentiableQTheta(checkpoint_path, device)
|
| 396 |
+
dq.load_receptor(holo_pdb, chain='A', label='holo')
|
| 397 |
+
dq.load_receptor(apo_pdb, chain='A', label='apo')
|
| 398 |
+
|
| 399 |
+
binder_coords = torch.tensor(...) # [N_binder, 4, 3], requires_grad=True
|
| 400 |
+
score_holo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='holo')
|
| 401 |
+
score_apo = dq.score(binder_coords, binder_mask, binder_aa_idx, receptor_label='apo')
|
| 402 |
+
selectivity = score_holo - score_apo
|
| 403 |
+
selectivity.backward()
|
| 404 |
+
# binder_coords.grad now contains ∂S/∂coords
|
| 405 |
+
"""
|
| 406 |
+
|
| 407 |
+
def __init__(self, checkpoint_path, device='cuda:0', esm_dir=None):
|
| 408 |
+
import sys, os
|
| 409 |
+
_code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 410 |
+
if _code_dir not in sys.path:
|
| 411 |
+
sys.path.insert(0, _code_dir)
|
| 412 |
+
from models.scorer import build_model
|
| 413 |
+
|
| 414 |
+
self.device = torch.device(device)
|
| 415 |
+
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
|
| 416 |
+
self.config = ckpt['config']
|
| 417 |
+
self.model = build_model(self.config)
|
| 418 |
+
self.model.load_state_dict(ckpt['model_state'])
|
| 419 |
+
self.model = self.model.to(self.device)
|
| 420 |
+
self.model.eval()
|
| 421 |
+
|
| 422 |
+
# ESM feature support
|
| 423 |
+
self.use_esm = self.config.get('esm_dim', 0) > 0
|
| 424 |
+
self.esm_dim = self.config.get('esm_dim', 0)
|
| 425 |
+
self.esm_dir = esm_dir or os.path.join(os.environ.get('ALLOGEN_ROOT', '.'), 'data/esm2_embeddings')
|
| 426 |
+
|
| 427 |
+
# Cache receptor data
|
| 428 |
+
self.receptors = {} # label -> {coords, mask, aa_idx, chi, esm_emb?}
|
| 429 |
+
|
| 430 |
+
def load_receptor(self, pdb_path, chain='A', label='holo',
|
| 431 |
+
esm_target=None, esm_key=None):
|
| 432 |
+
"""Pre-load and cache receptor structure, optionally with ESM embeddings.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
pdb_path: path to receptor PDB
|
| 436 |
+
chain: chain ID
|
| 437 |
+
label: cache key
|
| 438 |
+
esm_target: target name for ESM dir (e.g., 'abl' for data/esm2_embeddings/abl/)
|
| 439 |
+
esm_key: ESM embedding file key (e.g., '6XR7_A'). If None, auto-derived.
|
| 440 |
+
"""
|
| 441 |
+
import sys, os
|
| 442 |
+
_code_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 443 |
+
if _code_dir not in sys.path:
|
| 444 |
+
sys.path.insert(0, _code_dir)
|
| 445 |
+
from utils.pdb_utils import (
|
| 446 |
+
load_structure, get_residues, get_backbone_coords,
|
| 447 |
+
get_aa_indices, compute_chi_angles
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
model = load_structure(pdb_path)
|
| 451 |
+
chain_obj = model[chain]
|
| 452 |
+
residues = get_residues(chain_obj)
|
| 453 |
+
coords, mask = get_backbone_coords(residues)
|
| 454 |
+
aa_idx = get_aa_indices(residues)
|
| 455 |
+
chi = compute_chi_angles(residues, mask)
|
| 456 |
+
|
| 457 |
+
rec_data = {
|
| 458 |
+
'coords': torch.from_numpy(coords).float().to(self.device),
|
| 459 |
+
'mask': torch.from_numpy(mask).bool().to(self.device),
|
| 460 |
+
'aa_idx': torch.from_numpy(aa_idx).long().to(self.device),
|
| 461 |
+
'chi': torch.from_numpy(chi).float().to(self.device),
|
| 462 |
+
'residues': residues,
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
# Load ESM embeddings if model uses ESM
|
| 466 |
+
if self.use_esm and esm_target:
|
| 467 |
+
pdb_id = os.path.basename(pdb_path).replace('.pdb', '')
|
| 468 |
+
if esm_key is None:
|
| 469 |
+
esm_key = f'{pdb_id}_{chain}'
|
| 470 |
+
esm_path = os.path.join(self.esm_dir, esm_target, f'{esm_key}.pt')
|
| 471 |
+
if os.path.exists(esm_path):
|
| 472 |
+
esm_emb = torch.load(esm_path, map_location=self.device, weights_only=True)
|
| 473 |
+
# Truncate/pad to match residue count
|
| 474 |
+
n_res = len(residues)
|
| 475 |
+
if esm_emb.shape[0] > n_res:
|
| 476 |
+
esm_emb = esm_emb[:n_res]
|
| 477 |
+
elif esm_emb.shape[0] < n_res:
|
| 478 |
+
pad = torch.zeros(n_res - esm_emb.shape[0], esm_emb.shape[1],
|
| 479 |
+
device=self.device)
|
| 480 |
+
esm_emb = torch.cat([esm_emb, pad], dim=0)
|
| 481 |
+
rec_data['esm_emb'] = esm_emb.float()
|
| 482 |
+
else:
|
| 483 |
+
rec_data['esm_emb'] = torch.zeros(len(residues), self.esm_dim,
|
| 484 |
+
device=self.device)
|
| 485 |
+
|
| 486 |
+
self.receptors[label] = rec_data
|
| 487 |
+
|
| 488 |
+
def load_receptor_from_coords(self, coords, mask, aa_idx=None, chi=None,
|
| 489 |
+
label='path'):
|
| 490 |
+
"""
|
| 491 |
+
Load a receptor from raw backbone coords (not from PDB file).
|
| 492 |
+
|
| 493 |
+
Used for interpolated path frames that don't have PDB files.
|
| 494 |
+
If aa_idx is None, uses all-ALA (index 0). If chi is None, uses zeros.
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
coords: [N, 4, 3] numpy or torch backbone coords (N, CA, C, O)
|
| 498 |
+
mask: [N] numpy or torch bool
|
| 499 |
+
aa_idx: [N] numpy or torch int (default: all-ALA = 0)
|
| 500 |
+
chi: [N, 4] numpy or torch float (default: zeros)
|
| 501 |
+
label: str key for caching
|
| 502 |
+
"""
|
| 503 |
+
import numpy as np
|
| 504 |
+
|
| 505 |
+
# Convert numpy to torch if needed
|
| 506 |
+
if isinstance(coords, np.ndarray):
|
| 507 |
+
coords = torch.from_numpy(coords).float()
|
| 508 |
+
if isinstance(mask, np.ndarray):
|
| 509 |
+
mask = torch.from_numpy(mask).bool()
|
| 510 |
+
|
| 511 |
+
N = coords.shape[0]
|
| 512 |
+
|
| 513 |
+
if aa_idx is None:
|
| 514 |
+
aa_idx = torch.zeros(N, dtype=torch.long) # all-ALA
|
| 515 |
+
elif isinstance(aa_idx, np.ndarray):
|
| 516 |
+
aa_idx = torch.from_numpy(aa_idx).long()
|
| 517 |
+
|
| 518 |
+
if chi is None:
|
| 519 |
+
chi = torch.zeros(N, 4, dtype=coords.dtype)
|
| 520 |
+
elif isinstance(chi, np.ndarray):
|
| 521 |
+
chi = torch.from_numpy(chi).float()
|
| 522 |
+
|
| 523 |
+
self.receptors[label] = {
|
| 524 |
+
'coords': coords.to(self.device),
|
| 525 |
+
'mask': mask.to(self.device),
|
| 526 |
+
'aa_idx': aa_idx.to(self.device),
|
| 527 |
+
'chi': chi.to(self.device),
|
| 528 |
+
}
|
| 529 |
+
|
| 530 |
+
def score(self, binder_coords, binder_mask, binder_aa_idx=None,
|
| 531 |
+
binder_chi=None, receptor_label='holo', cutoff=8.0):
|
| 532 |
+
"""
|
| 533 |
+
Score binder against a cached receptor. Differentiable w.r.t. binder_coords.
|
| 534 |
+
|
| 535 |
+
Args:
|
| 536 |
+
binder_coords: [N_binder, 4, 3] tensor (can have requires_grad=True)
|
| 537 |
+
binder_mask: [N_binder] bool tensor
|
| 538 |
+
binder_aa_idx: [N_binder] int tensor (default: all UNK)
|
| 539 |
+
binder_chi: [N_binder, 4] tensor (default: zeros)
|
| 540 |
+
receptor_label: key into cached receptors
|
| 541 |
+
cutoff: interface distance cutoff
|
| 542 |
+
|
| 543 |
+
Returns:
|
| 544 |
+
score: scalar tensor in (0, 1), differentiable w.r.t. binder_coords
|
| 545 |
+
"""
|
| 546 |
+
rec = self.receptors[receptor_label]
|
| 547 |
+
N_binder = binder_coords.shape[0]
|
| 548 |
+
|
| 549 |
+
if binder_aa_idx is None:
|
| 550 |
+
binder_aa_idx = torch.full((N_binder,), 20, device=self.device, dtype=torch.long) # UNK
|
| 551 |
+
if binder_chi is None:
|
| 552 |
+
binder_chi = torch.zeros(N_binder, 4, device=self.device, dtype=binder_coords.dtype)
|
| 553 |
+
|
| 554 |
+
graph = build_differentiable_interface_graph(
|
| 555 |
+
rec_coords=rec['coords'],
|
| 556 |
+
rec_mask=rec['mask'],
|
| 557 |
+
rec_aa_idx=rec['aa_idx'],
|
| 558 |
+
rec_chi=rec['chi'],
|
| 559 |
+
binder_coords=binder_coords,
|
| 560 |
+
binder_mask=binder_mask,
|
| 561 |
+
binder_aa_idx=binder_aa_idx,
|
| 562 |
+
binder_chi=binder_chi,
|
| 563 |
+
cutoff=cutoff,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
if graph is None:
|
| 567 |
+
# No interface: return a zero score that stays attached to binder_coords, so backward() yields zero gradients rather than none
|
| 568 |
+
return binder_coords.sum() * 0.0
|
| 569 |
+
|
| 570 |
+
# Build ESM features if model uses ESM
|
| 571 |
+
esm_feats = None
|
| 572 |
+
if self.use_esm:
|
| 573 |
+
n_rec = graph['n_rec']
|
| 574 |
+
n_binder = graph['n_binder']
|
| 575 |
+
n_total = n_rec + n_binder
|
| 576 |
+
# Receptor ESM: use cached if available, else zeros
|
| 577 |
+
if 'esm_emb' in rec:
|
| 578 |
+
rec_esm = rec['esm_emb']
|
| 579 |
+
# Keep the full-length receptor embedding here; the interface rows are
|
| 580 |
+
# selected below with the same indices (and the same truncation) that
|
| 581 |
+
# build_differentiable_interface_graph used. They are recomputed because
|
| 582 |
+
# the graph dict does not return them.
|
| 583 |
+
rec_esm_full = rec_esm # [N_rec_total, 1280]
|
| 584 |
+
else:
|
| 585 |
+
rec_esm_full = torch.zeros(rec['coords'].shape[0], self.esm_dim,
|
| 586 |
+
device=self.device)
|
| 587 |
+
# Binder ESM: zeros (designed backbone, no sequence)
|
| 588 |
+
binder_esm = torch.zeros(binder_coords.shape[0], self.esm_dim,
|
| 589 |
+
device=self.device)
|
| 590 |
+
# We need interface indices to select — rebuild them
|
| 591 |
+
rec_ca = rec['coords'][:, 1, :]
|
| 592 |
+
binder_ca = binder_coords[:, 1, :]
|
| 593 |
+
dist_mat = torch.cdist(rec_ca.unsqueeze(0), binder_ca.unsqueeze(0)).squeeze(0)
|
| 594 |
+
dist_mat_c = dist_mat.clone()
|
| 595 |
+
dist_mat_c[~rec['mask'], :] = float('inf')
|
| 596 |
+
dist_mat_c[:, ~binder_mask] = float('inf')
|
| 597 |
+
rec_iface = (dist_mat_c < cutoff).any(dim=1)
|
| 598 |
+
binder_iface = (dist_mat_c < cutoff).any(dim=0)
|
| 599 |
+
rec_iface_idx = torch.where(rec_iface)[0][:n_rec]
|
| 600 |
+
binder_iface_idx = torch.where(binder_iface)[0][:n_binder]
|
| 601 |
+
|
| 602 |
+
rec_esm_iface = rec_esm_full[rec_iface_idx] # [n_rec, 1280]
|
| 603 |
+
binder_esm_iface = binder_esm[binder_iface_idx] # [n_binder, 1280]
|
| 604 |
+
esm_combined = torch.cat([rec_esm_iface, binder_esm_iface], dim=0) # [n_total, 1280]
|
| 605 |
+
esm_feats = esm_combined.unsqueeze(0) # [1, n_total, 1280]
|
| 606 |
+
|
| 607 |
+
score = self.model(graph['node_feats'], graph['edge_feats'], graph['node_mask'],
|
| 608 |
+
esm_feats=esm_feats)
|
| 609 |
+
return score.squeeze() # scalar
|
| 610 |
+
|
| 611 |
+
def selectivity_margin(self, binder_coords, binder_mask,
|
| 612 |
+
binder_aa_idx=None, binder_chi=None,
|
| 613 |
+
holo_label='holo', apo_label='apo', cutoff=8.0):
|
| 614 |
+
"""
|
| 615 |
+
Compute selectivity margin S = Q(holo, Y) - Q(apo, Y).
|
| 616 |
+
Differentiable w.r.t. binder_coords.
|
| 617 |
+
"""
|
| 618 |
+
q_holo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
|
| 619 |
+
holo_label, cutoff)
|
| 620 |
+
q_apo = self.score(binder_coords, binder_mask, binder_aa_idx, binder_chi,
|
| 621 |
+
apo_label, cutoff)
|
| 622 |
+
return q_holo - q_apo, q_holo, q_apo
|
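Note on the sequence-separation bins: features.py bins with np.digitize and differentiable_features.py bins with torch.bucketize. Integer separations often land exactly on a bin edge (0, ±8, ±16, ...), so the edge convention decides which bin they get. The sketch below is a stdlib-only illustration, not code from this repository: `linspace`, `sep_bin` and the `right_closed` flag are hypothetical names, and `bisect` stands in for the two library calls under the assumption that each is used with its default edge convention.

```python
from bisect import bisect_left, bisect_right


def linspace(lo, hi, n):
    """Evenly spaced values from lo to hi inclusive (stand-in for np.linspace)."""
    step = (hi - lo) / (n - 1)
    return [lo + k * step for k in range(n)]


def sep_bin(sep, n_bins=8, max_sep=32, right_closed=False):
    """Return the bin index used for the one-hot separation encoding.

    right_closed=False: bins are [edge_k, edge_{k+1}), the np.digitize default.
    right_closed=True:  bins are (edge_k, edge_{k+1}], the torch.bucketize default.
    (right_closed is a hypothetical flag for this sketch only.)
    """
    edges = linspace(-max_sep, max_sep, n_bins + 1)
    clipped = max(-max_sep, min(max_sep, sep))
    search = bisect_left if right_closed else bisect_right
    idx = search(edges, clipped) - 1
    return max(0, min(n_bins - 1, idx))


if __name__ == "__main__":
    # Print the bin index under both conventions for a few separations.
    for sep in (-8, -1, 0, 1, 8):
        print(sep, sep_bin(sep), sep_bin(sep, right_closed=True))
```

Separations strictly inside a bin get the same index under both conventions; only values sitting on an edge move by one bin, which is why the two feature paths need to agree on the convention.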
code/models/features.py
ADDED
|
@@ -0,0 +1,250 @@
|
| 1 |
+
"""
|
| 2 |
+
SE(3)-invariant feature extraction for interface graphs.
|
| 3 |
+
Node and edge features used by the Q_theta scorer.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
# Ensure utils is importable (for both direct and package imports)
|
| 11 |
+
_CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 12 |
+
if _CODE_DIR not in sys.path:
|
| 13 |
+
sys.path.insert(0, _CODE_DIR)
|
| 14 |
+
|
| 15 |
+
from utils.pdb_utils import (
|
| 16 |
+
rbf_encode, compute_backbone_frames, compute_torsion_angles,
|
| 17 |
+
get_aa_indices, compute_chi_angles, get_cb_positions, NUM_AA
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
# Feature dimensions
|
| 21 |
+
# one-hot AA (21) + backbone torsions (6) + chi1 sin/cos (2) + chi2 sin/cos (2) + chain indicator (1) = 32
|
| 22 |
+
NODE_DIM = NUM_AA + 6 + 4 + 1 # = 32
|
| 23 |
+
EDGE_DIM = 16 + 3 + 9 + 8 + 1 # RBF dist (16) + direction (3) + rel rotation (9) + seq sep (8) + same chain (1) = 37
|
| 24 |
+
|
| 25 |
+
MAX_SEQ_SEP = 32 # sequence separation is clipped to [-MAX_SEQ_SEP, MAX_SEQ_SEP] before binning
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def seq_sep_encode(sep, n_bins=8, max_sep=MAX_SEQ_SEP):
|
| 29 |
+
"""Bin-encode sequence separation."""
|
| 30 |
+
bins = np.linspace(-max_sep, max_sep, n_bins + 1)
|
| 31 |
+
sep_clipped = np.clip(sep, -max_sep, max_sep)
|
| 32 |
+
encoded = np.zeros(n_bins, dtype=np.float32)
|
| 33 |
+
bin_idx = np.digitize(sep_clipped, bins) - 1
|
| 34 |
+
bin_idx = np.clip(bin_idx, 0, n_bins - 1)
|
| 35 |
+
encoded[bin_idx] = 1.0
|
| 36 |
+
return encoded
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def extract_node_features(residues, coords, mask, torsion_angles, chi_angles, chain_id):
|
| 40 |
+
"""
|
| 41 |
+
Compute per-residue node features.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
residues: list of Bio.PDB residues
|
| 45 |
+
coords: [N, 4, 3] backbone coords
|
| 46 |
+
mask: [N] bool
|
| 47 |
+
torsion_angles: [N, 6] sin/cos of phi, psi, omega
|
| 48 |
+
chi_angles: [N, 4] sin/cos of chi1, chi2
|
| 49 |
+
chain_id: 0 = receptor, 1 = binder
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
node_feats: [N, NODE_DIM] (NODE_DIM = 32)
|
| 53 |
+
"""
|
| 54 |
+
N = len(residues)
|
| 55 |
+
aa_idx = get_aa_indices(residues)
|
| 56 |
+
|
| 57 |
+
# One-hot amino acid
|
| 58 |
+
aa_onehot = np.zeros((N, NUM_AA), dtype=np.float32)
|
| 59 |
+
for i in range(N):
|
| 60 |
+
if mask[i]:
|
| 61 |
+
aa_onehot[i, aa_idx[i]] = 1.0
|
| 62 |
+
|
| 63 |
+
# Chain indicator
|
| 64 |
+
chain_feat = np.full((N, 1), chain_id, dtype=np.float32)
|
| 65 |
+
|
| 66 |
+
# Concatenate
|
| 67 |
+
node_feats = np.concatenate([
|
| 68 |
+
aa_onehot, # [N, 21]
|
| 69 |
+
torsion_angles, # [N, 6]
|
| 70 |
+
chi_angles, # [N, 4]
|
| 71 |
+
chain_feat, # [N, 1]
|
| 72 |
+
], axis=-1)
|
| 73 |
+
|
| 74 |
+
return node_feats # [N, 32]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def extract_edge_features(coords_i, frames_i, coords_j, frames_j,
|
| 78 |
+
seq_idx_i, seq_idx_j, chain_i, chain_j, mask_i, mask_j):
|
| 79 |
+
"""
|
| 80 |
+
Compute SE(3)-invariant edge features between residue sets i and j.
|
| 81 |
+
Vectorized over all pairs.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
coords_i: [N_i, 4, 3] backbone coords of set i (full interface)
|
| 85 |
+
frames_i: (origins_i [N_i, 3], rotations_i [N_i, 3, 3])
|
| 86 |
+
coords_j: [N_j, 4, 3]
|
| 87 |
+
frames_j: (origins_j [N_j, 3], rotations_j [N_j, 3, 3])
|
| 88 |
+
seq_idx_i: [N_i] integer sequence indices (for sequence separation)
|
| 89 |
+
seq_idx_j: [N_j] integer sequence indices
|
| 90 |
+
chain_i: int (0 or 1)
|
| 91 |
+
chain_j: int (0 or 1)
|
| 92 |
+
mask_i: [N_i] bool
|
| 93 |
+
mask_j: [N_j] bool
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
edge_feats: [N_i, N_j, EDGE_DIM]
|
| 97 |
+
"""
|
| 98 |
+
N_i, N_j = len(coords_i), len(coords_j)
|
| 99 |
+
origins_i, rotations_i = frames_i
|
| 100 |
+
origins_j, rotations_j = frames_j
|
| 101 |
+
|
| 102 |
+
ca_i = origins_i # [N_i, 3]
|
| 103 |
+
ca_j = origins_j # [N_j, 3]
|
| 104 |
+
|
| 105 |
+
# --- Distance features ---
|
| 106 |
+
diff = ca_j[None, :, :] - ca_i[:, None, :] # [N_i, N_j, 3]
|
| 107 |
+
dist = np.sqrt((diff ** 2).sum(axis=-1)) # [N_i, N_j]
|
| 108 |
+
dist_rbf = rbf_encode(dist, d_min=0., d_max=20., n_bins=16) # [N_i, N_j, 16]
|
| 109 |
+
|
| 110 |
+
# --- Direction in local frame of i ---
|
| 111 |
+
# unit vector from i to j in global frame
|
| 112 |
+
unit_diff = diff / (dist[..., None] + 1e-8) # [N_i, N_j, 3]
|
| 113 |
+
# rotate by R_i^T to get local direction
|
| 114 |
+
# rotations_i: [N_i, 3, 3], unit_diff: [N_i, N_j, 3]
|
| 115 |
+
# local_dir[i,j] = R_i^T @ (ca_j - ca_i) / dist
|
| 116 |
+
local_dir = np.einsum('ikl,ijl->ijk', rotations_i, unit_diff) # [N_i, N_j, 3]
|
| 117 |
+
|
| 118 |
+
# --- Relative rotation: R_i^T R_j ---
|
| 119 |
+
# rotations_i: [N_i, 3, 3], rotations_j: [N_j, 3, 3]
|
| 120 |
+
# intended: rel_rot[i,j] = R_i^T @ R_j -> [N_i, N_j, 3, 3] -> flatten to [N_i, N_j, 9]; NOTE: 'ikl,jlm->ijkm' as written yields R_i @ R_j ('ilk,jlm->ijkm' gives R_i^T @ R_j)
|
| 121 |
+
rel_rot = np.einsum('ikl,jlm->ijkm', rotations_i, rotations_j) # [N_i, N_j, 3, 3]
|
| 122 |
+
rel_rot_flat = rel_rot.reshape(N_i, N_j, 9) # [N_i, N_j, 9]
|
| 123 |
+
|
| 124 |
+
# --- Sequence separation ---
|
| 125 |
+
sep = seq_idx_j[None, :] - seq_idx_i[:, None] # [N_i, N_j]
|
| 126 |
+
# Encode each pair with a Python loop over the flattened separations (N_i*N_j calls; not vectorized)
|
| 127 |
+
sep_flat = sep.reshape(-1)
|
| 128 |
+
sep_enc = np.array([seq_sep_encode(s) for s in sep_flat]) # [N_i*N_j, 8]
|
| 129 |
+
sep_enc = sep_enc.reshape(N_i, N_j, 8)
|
| 130 |
+
|
| 131 |
+
# Cross-chain pairs: zero the sequence-separation encoding by convention
|
| 132 |
+
if chain_i != chain_j:
|
| 133 |
+
sep_enc[:] = 0.0
|
| 134 |
+
|
| 135 |
+
# --- Same chain indicator ---
|
| 136 |
+
same_chain = float(chain_i == chain_j)
|
| 137 |
+
same_chain_feat = np.full((N_i, N_j, 1), same_chain, dtype=np.float32)
|
| 138 |
+
|
| 139 |
+
# --- Concatenate ---
|
| 140 |
+
edge_feats = np.concatenate([
|
| 141 |
+
dist_rbf, # [N_i, N_j, 16]
|
| 142 |
+
local_dir, # [N_i, N_j, 3]
|
| 143 |
+
rel_rot_flat, # [N_i, N_j, 9]
|
| 144 |
+
sep_enc, # [N_i, N_j, 8]
|
| 145 |
+
same_chain_feat # [N_i, N_j, 1]
|
| 146 |
+
], axis=-1) # [N_i, N_j, 37]
|
| 147 |
+
|
| 148 |
+
# Zero out edges involving masked residues
|
| 149 |
+
edge_feats[~mask_i, :, :] = 0.0
|
| 150 |
+
edge_feats[:, ~mask_j, :] = 0.0
|
| 151 |
+
|
| 152 |
+
return edge_feats.astype(np.float32)
|
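The invariance claim for the direction and relative-rotation features can be checked numerically. The sketch below is not part of the repository: it assumes frames store their local axes as columns of each rotation matrix (`compute_backbone_frames` is not shown here, so that convention is an assumption), applies a random rigid motion, and compares `R_i^T @ R_j` with the plain product `R_i @ R_j` that the subscripts `'ikl,jlm->ijkm'` correspond to.

```python
import numpy as np

def random_rotation(rng):
    """Draw a proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))          # fix column signs
    if np.linalg.det(q) < 0:             # flip one axis to land in SO(3)
        q[:, 0] = -q[:, 0]
    return q

def pair_features(origins, rotations):
    diff = origins[None, :, :] - origins[:, None, :]                 # [N, N, 3], vector i -> j
    local = np.einsum('ilk,ijl->ijk', rotations, diff)               # R_i^T @ diff
    rel_t = np.einsum('ilk,jlm->ijkm', rotations, rotations)         # R_i^T @ R_j
    rel_plain = np.einsum('ikl,jlm->ijkm', rotations, rotations)     # R_i @ R_j
    return local, rel_t, rel_plain

rng = np.random.default_rng(0)
N = 5
origins = rng.normal(size=(N, 3))                                    # stand-in CA positions
rotations = np.stack([random_rotation(rng) for _ in range(N)])       # columns = local axes (assumed)

# Global rigid motion: x -> G x + t, R -> G R
G = random_rotation(rng)
t = rng.normal(size=3)
origins_g = origins @ G.T + t
rotations_g = np.einsum('ab,nbc->nac', G, rotations)

local_a, rel_t_a, rel_plain_a = pair_features(origins, rotations)
local_b, rel_t_b, rel_plain_b = pair_features(origins_g, rotations_g)

local_invariant = np.allclose(local_a, local_b)
rel_t_invariant = np.allclose(rel_t_a, rel_t_b)
rel_plain_invariant = np.allclose(rel_plain_a, rel_plain_b)
print(local_invariant, rel_t_invariant, rel_plain_invariant)
```

Under the column convention, only the transposed contractions survive the rigid motion. If the frames instead store axes as rows, the invariant forms are `R_i @ diff` and `R_i @ R_j^T`.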
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def build_interface_graph(rec_residues, rec_coords, rec_mask,
|
| 156 |
+
binder_residues, binder_coords, binder_mask,
|
| 157 |
+
rec_interface_mask, binder_interface_mask,
|
| 158 |
+
max_nodes: int = 128):
|
| 159 |
+
"""
|
| 160 |
+
Build a joint interface graph combining receptor and binder interface residues.
|
| 161 |
+
|
| 162 |
+
Returns a dict with:
|
| 163 |
+
node_feats: [N_total, NODE_DIM]
|
| 164 |
+
edge_feats: [N_total, N_total, EDGE_DIM]
|
| 165 |
+
node_mask: [N_total] bool
|
| 166 |
+
n_rec: int (number of receptor interface nodes)
|
| 167 |
+
n_binder: int (number of binder interface nodes)
|
| 168 |
+
"""
|
| 169 |
+
# Select interface residues
|
| 170 |
+
rec_iface_idx = np.where(rec_interface_mask)[0]
|
| 171 |
+
binder_iface_idx = np.where(binder_interface_mask)[0]
|
| 172 |
+
|
| 173 |
+
# Truncate if too many
|
| 174 |
+
if len(rec_iface_idx) > max_nodes // 2:
|
| 175 |
+
rec_iface_idx = rec_iface_idx[:max_nodes // 2]
|
| 176 |
+
if len(binder_iface_idx) > max_nodes // 2:
|
| 177 |
+
binder_iface_idx = binder_iface_idx[:max_nodes // 2]
|
| 178 |
+
|
| 179 |
+
n_rec = len(rec_iface_idx)
|
| 180 |
+
n_binder = len(binder_iface_idx)
|
| 181 |
+
n_total = n_rec + n_binder
|
| 182 |
+
|
| 183 |
+
if n_total == 0:
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
# Extract coords for interface residues
|
| 187 |
+
rec_iface_coords = rec_coords[rec_iface_idx] # [n_rec, 4, 3]
|
| 188 |
+
binder_iface_coords = binder_coords[binder_iface_idx] # [n_binder, 4, 3]
|
| 189 |
+
rec_iface_mask = rec_mask[rec_iface_idx]
|
| 190 |
+
binder_iface_mask = binder_mask[binder_iface_idx]
|
| 191 |
+
|
| 192 |
+
# Compute backbone frames
|
| 193 |
+
rec_origins, rec_rotations = compute_backbone_frames(rec_iface_coords, rec_iface_mask)
|
| 194 |
+
binder_origins, binder_rotations = compute_backbone_frames(binder_iface_coords, binder_iface_mask)
|
| 195 |
+
|
| 196 |
+
# Compute torsion angles
|
| 197 |
+
# Proper phi/psi needs full-chain coords (sequence neighbours); here torsions are approximated from the interface subset only
|
| 198 |
+
rec_torsion = compute_torsion_angles(rec_iface_coords, rec_iface_mask)
|
| 199 |
+
binder_torsion = compute_torsion_angles(binder_iface_coords, binder_iface_mask)
|
| 200 |
+
|
| 201 |
+
# Extract residues
|
| 202 |
+
rec_iface_residues = [rec_residues[i] for i in rec_iface_idx]
|
| 203 |
+
binder_iface_residues = [binder_residues[i] for i in binder_iface_idx]
|
| 204 |
+
|
| 205 |
+
# Compute sidechain chi1/chi2 angles
|
| 206 |
+
rec_chi = compute_chi_angles(rec_iface_residues, rec_iface_mask)
|
| 207 |
+
binder_chi = compute_chi_angles(binder_iface_residues, binder_iface_mask)
|
| 208 |
+
|
| 209 |
+
# Node features
|
| 210 |
+
rec_node_feats = extract_node_features(
|
| 211 |
+
rec_iface_residues, rec_iface_coords, rec_iface_mask, rec_torsion, rec_chi, chain_id=0
|
| 212 |
+
) # [n_rec, NODE_DIM]
|
| 213 |
+
binder_node_feats = extract_node_features(
|
| 214 |
+
binder_iface_residues, binder_iface_coords, binder_iface_mask, binder_torsion, binder_chi, chain_id=1
|
| 215 |
+
) # [n_binder, NODE_DIM]
|
| 216 |
+
|
| 217 |
+
node_feats = np.concatenate([rec_node_feats, binder_node_feats], axis=0) # [N, NODE_DIM]
|
| 218 |
+
node_mask = np.concatenate([rec_iface_mask, binder_iface_mask], axis=0)
|
| 219 |
+
|
| 220 |
+
# Edge features (4 blocks: RR, RB, BR, BB)
|
| 221 |
+
all_coords = np.concatenate([rec_iface_coords, binder_iface_coords], axis=0)
|
| 222 |
+
all_mask = node_mask
|
| 223 |
+
all_origins = np.concatenate([rec_origins, binder_origins], axis=0)
|
| 224 |
+
all_rotations = np.concatenate([rec_rotations, binder_rotations], axis=0)
|
| 225 |
+
all_seq_idx = np.concatenate([rec_iface_idx, binder_iface_idx + len(rec_residues)], axis=0)
|
| 226 |
+
all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)
|
| 227 |
+
|
| 228 |
+
# Compute full NxN edge features
|
| 229 |
+
frames_all = (all_origins, all_rotations)
|
| 230 |
+
edge_feats = extract_edge_features(
|
| 231 |
+
all_coords, frames_all,
|
| 232 |
+
all_coords, frames_all,
|
| 233 |
+
all_seq_idx, all_seq_idx,
|
| 234 |
+
-1, -1, # chain handled via all_chain array below
|
| 235 |
+
all_mask, all_mask
|
| 236 |
+
) # [N, N, EDGE_DIM]
|
| 237 |
+
|
| 238 |
+
# Patch same_chain feature (last dim) using actual chain IDs
|
| 239 |
+
same_chain_feat = (all_chain[:, None] == all_chain[None, :]).astype(np.float32)
|
| 240 |
+
edge_feats[:, :, -1] = same_chain_feat * (all_mask[:, None] & all_mask[None, :])  # keep edges of masked residues zeroed
|
| 241 |
+
|
| 242 |
+
return {
|
| 243 |
+
'node_feats': node_feats.astype(np.float32), # [N, NODE_DIM]
|
| 244 |
+
'edge_feats': edge_feats.astype(np.float32), # [N, N, EDGE_DIM]
|
| 245 |
+
'node_mask': node_mask, # [N]
|
| 246 |
+
'n_rec': n_rec,
|
| 247 |
+
'n_binder': n_binder,
|
| 248 |
+
'rec_iface_idx': rec_iface_idx, # [n_rec] original residue indices
|
| 249 |
+
'binder_iface_idx': binder_iface_idx, # [n_binder] original residue indices
|
| 250 |
+
}
|
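The joint graph orders receptor nodes first and binder nodes second, so the patched `same_chain` channel has a 2x2 block layout (RR, RB, BR, BB). A minimal NumPy sketch with hypothetical sizes and a hypothetical `node_mask` illustrates that layout, and shows how a pairwise mask would keep edges of masked residues at zero:

```python
import numpy as np

n_rec, n_binder = 2, 3                                   # hypothetical interface sizes
all_chain = np.array([0] * n_rec + [1] * n_binder, dtype=np.int32)

# 1.0 where both residues are on the same chain, 0.0 otherwise
same_chain = (all_chain[:, None] == all_chain[None, :]).astype(np.float32)   # [N, N]

rr = same_chain[:n_rec, :n_rec]      # receptor-receptor block
rb = same_chain[:n_rec, n_rec:]      # receptor-binder block
br = same_chain[n_rec:, :n_rec]      # binder-receptor block
bb = same_chain[n_rec:, n_rec:]      # binder-binder block

# Hypothetical validity mask: node 3 lacks backbone atoms
node_mask = np.array([True, True, True, False, True])
pair_mask = node_mask[:, None] & node_mask[None, :]
same_chain_masked = same_chain * pair_mask               # row and column 3 are all zero

print(same_chain)
print(same_chain_masked)
```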
code/models/scorer.py
ADDED
|
@@ -0,0 +1,585 @@
|
| 1 |
+
"""
|
| 2 |
+
Q_theta: State-selectivity scorer for Allo-Designer.
|
| 3 |
+
|
| 4 |
+
Architecture: Dense Edge-Biased Graph Transformer
|
| 5 |
+
- Input: padded interface graph (node feats + pairwise edge feats)
|
| 6 |
+
- SE(3)-invariant features (all features from distances/angles in backbone frames)
|
| 7 |
+
- Output: Q_theta(X, Y) in (0,1) = probability-like compatibility/selectivity score
|
| 8 |
+
|
| 9 |
+
No torch_geometric dependency: uses dense attention with edge biases.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RBFLayer(nn.Module):
|
| 19 |
+
"""RBF embedding for edge distances: fixed centers, one learnable width (sigma)."""
|
| 20 |
+
def __init__(self, n_bins: int = 16, d_min: float = 0., d_max: float = 20.):
|
| 21 |
+
super().__init__()
|
| 22 |
+
centers = torch.linspace(d_min, d_max, n_bins)
|
| 23 |
+
self.register_buffer('centers', centers)
|
| 24 |
+
self.log_sigma = nn.Parameter(torch.zeros(1))
|
| 25 |
+
|
| 26 |
+
def forward(self, dist):
|
| 27 |
+
# dist: [...] -> [..., n_bins]
|
| 28 |
+
sigma = torch.exp(self.log_sigma)
|
| 29 |
+
return torch.exp(-((dist.unsqueeze(-1) - self.centers) ** 2) / (2 * sigma ** 2))
|
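To make the RBF encoding concrete, here is a NumPy mirror of the `forward` computation at initialization (`log_sigma = 0`, hence `sigma = 1`). It is an illustrative sketch, not the module itself. The probe distance 8.0 is chosen because it coincides with one of the 16 centers.

```python
import numpy as np

n_bins, d_min, d_max = 16, 0.0, 20.0
centers = np.linspace(d_min, d_max, n_bins)   # fixed centers, spaced 20/15 apart
log_sigma = 0.0                               # initial value of the learnable parameter
sigma = np.exp(log_sigma)                     # sigma = 1.0 at initialization

def rbf(dist):
    """Map distances of shape [...] to Gaussian activations of shape [..., n_bins]."""
    dist = np.asarray(dist, dtype=np.float64)
    return np.exp(-((dist[..., None] - centers) ** 2) / (2 * sigma ** 2))

enc = rbf(8.0)                    # 8.0 sits on centers[6]
peak_bin = int(enc.argmax())
batch_enc = rbf(np.zeros((3, 4))) # arbitrary leading shape is preserved
print(peak_bin, enc.round(3))
```

Because the centers are a registered buffer, only the shared width changes during training; all bins widen or narrow together.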
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class EdgeBiasedMHA(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Multi-Head Self-Attention with additive edge biases.
|
| 35 |
+
Implements the core equation:
|
| 36 |
+
A_ij = (Q_i K_j^T / sqrt(d)) + b_ij
|
| 37 |
+
where b_ij is computed from edge features.
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self, d_model: int, n_heads: int, d_edge: int, dropout: float = 0.1):
|
| 40 |
+
super().__init__()
|
| 41 |
+
assert d_model % n_heads == 0
|
| 42 |
+
self.n_heads = n_heads
|
| 43 |
+
self.d_head = d_model // n_heads
|
| 44 |
+
self.scale = math.sqrt(self.d_head)
|
| 45 |
+
|
| 46 |
+
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
|
| 47 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 48 |
+
self.edge_proj = nn.Linear(d_edge, n_heads) # edge features -> per-head bias
|
| 49 |
+
self.dropout = nn.Dropout(dropout)
|
| 50 |
+
|
| 51 |
+
def forward(self, x, edge_feats, mask=None):
|
| 52 |
+
"""
|
| 53 |
+
x: [B, N, d_model]
|
| 54 |
+
edge_feats: [B, N, N, d_edge]
|
| 55 |
+
mask: [B, N] bool (True = valid residue)
|
| 56 |
+
"""
|
| 57 |
+
B, N, D = x.shape
|
| 58 |
+
H = self.n_heads
|
| 59 |
+
|
| 60 |
+
# QKV projection
|
| 61 |
+
qkv = self.qkv_proj(x).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
|
| 62 |
+
q, k, v = qkv.unbind(0) # each [B, H, N, d_head]
|
| 63 |
+
|
| 64 |
+
# Scaled dot-product attention logits
|
| 65 |
+
attn_logits = (q @ k.transpose(-2, -1)) / self.scale # [B, H, N, N]
|
| 66 |
+
|
| 67 |
+
# Edge bias: [B, N, N, H] -> [B, H, N, N]
|
| 68 |
+
edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2) # [B, H, N, N]
|
| 69 |
+
attn_logits = attn_logits + edge_bias
|
| 70 |
+
|
| 71 |
+
# Padding mask: mask out padded positions
|
| 72 |
+
if mask is not None:
|
| 73 |
+
# mask: [B, N] True=valid; padding=False
|
| 74 |
+
padding = ~mask # [B, N] True=padding
|
| 75 |
+
attn_logits = attn_logits.masked_fill(
|
| 76 |
+
padding[:, None, None, :], # [B, 1, 1, N]
|
| 77 |
+
float('-inf')
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
attn_weights = self.dropout(F.softmax(attn_logits, dim=-1))
|
| 81 |
+
|
| 82 |
+
# Handle all-padding rows (NaN -> 0)
|
| 83 |
+
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
|
| 84 |
+
|
| 85 |
+
out = (attn_weights @ v) # [B, H, N, d_head]
|
| 86 |
+
out = out.transpose(1, 2).reshape(B, N, D) # [B, N, D]
|
| 87 |
+
return self.out_proj(out)
|
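The interplay of the additive edge bias, the key-padding mask, and the NaN cleanup can be seen in a single-head NumPy sketch. This is a simplified illustration (no learned projections, no dropout, no batch dimension), with all inputs made up for the example:

```python
import numpy as np

def edge_biased_attention(q, k, v, bias, key_mask):
    """Single-head attention with additive pairwise bias and key-padding mask.

    q, k: [N, d]; v: [N, d_v]; bias: [N, N]; key_mask: [N] bool (True = valid).
    """
    d = q.shape[-1]
    logits = q @ k.T / np.sqrt(d) + bias                      # A_ij = q_i.k_j / sqrt(d) + b_ij
    logits = np.where(key_mask[None, :], logits, -np.inf)     # padded keys get -inf
    with np.errstate(invalid='ignore'):                       # all -inf rows produce NaN
        shifted = logits - logits.max(axis=-1, keepdims=True)
        weights = np.exp(shifted)
        weights = weights / weights.sum(axis=-1, keepdims=True)
    weights = np.nan_to_num(weights, nan=0.0)                 # all-padding rows -> zeros
    return weights @ v, weights

N, d = 3, 2
q = np.zeros((N, d))              # zero queries/keys so the logits equal the bias
k = np.zeros((N, d))
v = np.eye(N)
bias = np.zeros((N, N))
bias[0, 1] = 5.0                  # strongly favour the edge 0 -> 1

key_mask = np.array([True, True, False])
out, w = edge_biased_attention(q, k, v, bias, key_mask)

# Degenerate case: every key is padding
_, w_empty = edge_biased_attention(q, k, v, bias, np.zeros(N, dtype=bool))
print(w.round(4))
```

Rows only become all-NaN when every key in the sample is padding, which is the case the `nan_to_num` call guards against.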
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class InterfaceTransformerLayer(nn.Module):
|
| 91 |
+
"""Single layer of edge-biased transformer with pre-norm."""
|
| 92 |
+
def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.attn = EdgeBiasedMHA(d_model, n_heads, d_edge, dropout)
|
| 95 |
+
self.ff = nn.Sequential(
|
| 96 |
+
nn.Linear(d_model, d_model * ff_mult),
|
| 97 |
+
nn.GELU(),
|
| 98 |
+
nn.Dropout(dropout),
|
| 99 |
+
nn.Linear(d_model * ff_mult, d_model),
|
| 100 |
+
)
|
| 101 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 102 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 103 |
+
self.drop = nn.Dropout(dropout)
|
| 104 |
+
|
| 105 |
+
def forward(self, x, edge_feats, mask=None):
|
| 106 |
+
x = x + self.drop(self.attn(self.norm1(x), edge_feats, mask))
|
| 107 |
+
x = x + self.drop(self.ff(self.norm2(x)))
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class GATLayer(nn.Module):
|
| 112 |
+
"""Multi-head GAT layer with pre-norm. edge_feats is accepted for a uniform layer interface but not used in attention."""
|
| 113 |
+
def __init__(self, d_model: int, n_heads: int, ff_mult: int = 4, dropout: float = 0.1):
|
| 114 |
+
super().__init__()
|
| 115 |
+
assert d_model % n_heads == 0
|
| 116 |
+
self.n_heads = n_heads
|
| 117 |
+
self.d_head = d_model // n_heads
|
| 118 |
+
|
| 119 |
+
self.W = nn.Linear(d_model, d_model, bias=False)
|
| 120 |
+
self.a_l = nn.Parameter(torch.randn(n_heads, self.d_head))
|
| 121 |
+
self.a_r = nn.Parameter(torch.randn(n_heads, self.d_head))
|
| 122 |
+
nn.init.xavier_uniform_(self.a_l.unsqueeze(0))
|
| 123 |
+
nn.init.xavier_uniform_(self.a_r.unsqueeze(0))
|
| 124 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 125 |
+
self.leaky_relu = nn.LeakyReLU(0.2)
|
| 126 |
+
self.attn_drop = nn.Dropout(dropout)
|
| 127 |
+
|
| 128 |
+
self.ff = nn.Sequential(
|
| 129 |
+
nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
|
| 130 |
+
nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
|
| 131 |
+
)
|
| 132 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 133 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 134 |
+
self.drop = nn.Dropout(dropout)
|
| 135 |
+
|
| 136 |
+
def forward(self, x, edge_feats, mask=None):
|
| 137 |
+
B, N, D = x.shape
|
| 138 |
+
H = self.n_heads
|
| 139 |
+
|
| 140 |
+
h = self.norm1(x)
|
| 141 |
+
Wh = self.W(h).view(B, N, H, self.d_head) # [B, N, H, d_head]
|
| 142 |
+
e_l = (Wh * self.a_l).sum(-1) # [B, N, H]
|
| 143 |
+
e_r = (Wh * self.a_r).sum(-1) # [B, N, H]
|
| 144 |
+
attn = self.leaky_relu(e_l.unsqueeze(2) + e_r.unsqueeze(1)) # [B, N, N, H]
|
| 145 |
+
attn = attn.permute(0, 3, 1, 2) # [B, H, N, N]
|
| 146 |
+
|
| 147 |
+
if mask is not None:
|
| 148 |
+
attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
|
| 149 |
+
|
| 150 |
+
attn = self.attn_drop(F.softmax(attn, dim=-1))
|
| 151 |
+
attn = torch.nan_to_num(attn, nan=0.0)
|
| 152 |
+
|
| 153 |
+
out = torch.einsum('bhnm,bmhd->bnhd', attn, Wh)
|
| 154 |
+
out = out.reshape(B, N, D)
|
| 155 |
+
x = x + self.drop(self.out_proj(out))
|
| 156 |
+
x = x + self.drop(self.ff(self.norm2(x)))
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class GCNLayer(nn.Module):
|
| 161 |
+
"""GCN layer with edge-weighted message passing and pre-norm."""
|
| 162 |
+
def __init__(self, d_model: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.msg_proj = nn.Linear(d_model, d_model, bias=False)
|
| 165 |
+
self.edge_weight = nn.Linear(d_edge, 1)
|
| 166 |
+
|
| 167 |
+
self.ff = nn.Sequential(
|
| 168 |
+
nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
|
| 169 |
+
nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
|
| 170 |
+
)
|
| 171 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 172 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 173 |
+
self.drop = nn.Dropout(dropout)
|
| 174 |
+
|
| 175 |
+
def forward(self, x, edge_feats, mask=None):
|
| 176 |
+
B, N, D = x.shape
|
| 177 |
+
h = self.norm1(x)
|
| 178 |
+
msg = self.msg_proj(h) # [B, N, D]
|
| 179 |
+
|
| 180 |
+
w = self.edge_weight(edge_feats).squeeze(-1) # [B, N, N]
|
| 181 |
+
if mask is not None:
|
| 182 |
+
w = w.masked_fill(~mask[:, None, :], float('-inf'))
|
| 183 |
+
w = F.softmax(w, dim=-1)
|
| 184 |
+
w = torch.nan_to_num(w, nan=0.0)
|
| 185 |
+
|
| 186 |
+
agg = torch.bmm(w, msg) # [B, N, D]
|
| 187 |
+
x = x + self.drop(agg)
|
| 188 |
+
x = x + self.drop(self.ff(self.norm2(x)))
|
| 189 |
+
return x
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class CrossChainTransformerLayer(nn.Module):
|
| 193 |
+
"""Cross-chain attention: each node attends only to nodes from the other chain."""
|
| 194 |
+
def __init__(self, d_model: int, n_heads: int, d_edge: int, ff_mult: int = 4, dropout: float = 0.1):
|
| 195 |
+
super().__init__()
|
| 196 |
+
assert d_model % n_heads == 0
|
| 197 |
+
self.n_heads = n_heads
|
| 198 |
+
self.d_head = d_model // n_heads
|
| 199 |
+
self.scale = math.sqrt(self.d_head)
|
| 200 |
+
|
| 201 |
+
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
|
| 202 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 203 |
+
self.edge_proj = nn.Linear(d_edge, n_heads)
|
| 204 |
+
self.attn_drop = nn.Dropout(dropout)
|
| 205 |
+
|
| 206 |
+
self.ff = nn.Sequential(
|
| 207 |
+
nn.Linear(d_model, d_model * ff_mult), nn.GELU(),
|
| 208 |
+
nn.Dropout(dropout), nn.Linear(d_model * ff_mult, d_model),
|
| 209 |
+
)
|
| 210 |
+
self.norm1 = nn.LayerNorm(d_model)
|
| 211 |
+
self.norm2 = nn.LayerNorm(d_model)
|
| 212 |
+
self.drop = nn.Dropout(dropout)
|
| 213 |
+
|
| 214 |
+
def forward(self, x, edge_feats, mask=None, chain_mask=None):
|
| 215 |
+
"""
|
| 216 |
+
x: [B, N, d_model]
|
| 217 |
+
edge_feats: [B, N, N, d_edge]
|
| 218 |
+
mask: [B, N] bool (True = valid)
|
| 219 |
+
chain_mask: [B, N] float (0=receptor, 1=binder)
|
| 220 |
+
"""
|
| 221 |
+
B, N, D = x.shape
|
| 222 |
+
H = self.n_heads
|
| 223 |
+
|
| 224 |
+
h = self.norm1(x)
|
| 225 |
+
qkv = self.qkv_proj(h).reshape(B, N, 3, H, self.d_head).permute(2, 0, 3, 1, 4)
|
| 226 |
+
q, k, v = qkv.unbind(0) # each [B, H, N, d_head]
|
| 227 |
+
|
| 228 |
+
attn_logits = (q @ k.transpose(-2, -1)) / self.scale # [B, H, N, N]
|
| 229 |
+
edge_bias = self.edge_proj(edge_feats).permute(0, 3, 1, 2) # [B, H, N, N]
|
| 230 |
+
attn_logits = attn_logits + edge_bias
|
| 231 |
+
|
| 232 |
+
# Mask padding
|
| 233 |
+
if mask is not None:
|
| 234 |
+
attn_logits = attn_logits.masked_fill(~mask[:, None, None, :], float('-inf'))
|
| 235 |
+
|
| 236 |
+
# Cross-chain mask: block same-chain attention
|
| 237 |
+
if chain_mask is not None:
|
| 238 |
+
same_chain = (chain_mask.unsqueeze(1) == chain_mask.unsqueeze(2)) # [B, N, N]
|
| 239 |
+
attn_logits = attn_logits.masked_fill(same_chain[:, None, :, :], float('-inf'))
|
| 240 |
+
|
| 241 |
+
attn_weights = self.attn_drop(F.softmax(attn_logits, dim=-1))
|
| 242 |
+
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
|
| 243 |
+
|
| 244 |
+
out = (attn_weights @ v).transpose(1, 2).reshape(B, N, D)
|
| 245 |
+
x = x + self.drop(self.out_proj(out))
|
| 246 |
+
x = x + self.drop(self.ff(self.norm2(x)))
|
| 247 |
+
return x
|
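The cross-chain restriction combines two masks: keys on the same chain as the query are blocked, and padded keys are blocked. A small NumPy sketch with a made-up four-node example shows which query-key pairs remain:

```python
import numpy as np

chain = np.array([0., 0., 1., 1.])             # 0 = receptor, 1 = binder
valid = np.array([True, True, True, False])    # last binder node is padding

same_chain = chain[:, None] == chain[None, :]  # [N, N], True where attention is blocked
allowed = (~same_chain) & valid[None, :]       # query i may attend to key j

expected = np.array([
    [False, False, True,  False],   # receptor node 0 -> valid binder node 2 only
    [False, False, True,  False],   # receptor node 1 -> valid binder node 2 only
    [True,  True,  False, False],   # binder node 2   -> both receptor nodes
    [True,  True,  False, False],   # padded binder row (excluded later by pooling)
])

# If every binder node were padding, receptor rows would have no keys at all
no_binder = (~same_chain) & np.array([True, True, False, False])[None, :]
print(allowed)
```

Rows with no allowed key produce NaN after the softmax and are set to zero by `nan_to_num`, so those nodes pass through the layer with only their residual and feed-forward update.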
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class EdgeUpdateLayer(nn.Module):
|
| 251 |
+
"""Updates edge features using node representations each layer.
|
| 252 |
+
Memory-efficient: projects nodes to a low dimension before the pairwise broadcast-and-concatenate."""
|
| 253 |
+
def __init__(self, d_model: int, d_edge: int, dropout: float = 0.1):
|
| 254 |
+
super().__init__()
|
| 255 |
+
d_proj = min(32, d_model // 4) # Low-dim projection to save memory
|
| 256 |
+
self.proj_i = nn.Linear(d_model, d_proj, bias=False)
|
| 257 |
+
self.proj_j = nn.Linear(d_model, d_proj, bias=False)
|
| 258 |
+
self.edge_mlp = nn.Sequential(
|
| 259 |
+
nn.Linear(2 * d_proj + d_edge, d_edge),
|
| 260 |
+
nn.GELU(),
|
| 261 |
+
nn.Dropout(dropout),
|
| 262 |
+
nn.Linear(d_edge, d_edge),
|
| 263 |
+
)
|
| 264 |
+
self.norm = nn.LayerNorm(d_edge)
|
| 265 |
+
|
| 266 |
+
def forward(self, h, e, mask=None):
|
| 267 |
+
B, N, D = h.shape
|
| 268 |
+
hi = self.proj_i(h).unsqueeze(2).expand(-1, -1, N, -1) # [B, N, N, d_proj]
|
| 269 |
+
hj = self.proj_j(h).unsqueeze(1).expand(-1, N, -1, -1) # [B, N, N, d_proj]
|
| 270 |
+
inp = torch.cat([hi, hj, self.norm(e)], dim=-1)
|
| 271 |
+
e = e + self.edge_mlp(inp)
|
| 272 |
+
return e
|
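A shape-level NumPy sketch of the pairwise input built by the edge update, using random stand-in weights (`W_i`, `W_j` are hypothetical) and omitting the LayerNorm on `e`. With the defaults `hidden_dim = 128` and edge width `hidden_dim // 2 = 64`, the low-dimensional projection keeps the per-pair input at 128 channels instead of 320:

```python
import numpy as np

N, d_model, d_edge = 5, 128, 64
d_proj = min(32, d_model // 4)                 # 32 for d_model = 128

rng = np.random.default_rng(0)
h = rng.normal(size=(N, d_model))              # node states
e = rng.normal(size=(N, N, d_edge))            # edge states
W_i = rng.normal(size=(d_model, d_proj))       # stand-in for proj_i
W_j = rng.normal(size=(d_model, d_proj))       # stand-in for proj_j

# Broadcast the projected "row" node and "column" node to every pair (i, j)
hi = np.broadcast_to((h @ W_i)[:, None, :], (N, N, d_proj))
hj = np.broadcast_to((h @ W_j)[None, :, :], (N, N, d_proj))
inp = np.concatenate([hi, hj, e], axis=-1)     # [N, N, 2*d_proj + d_edge]

naive_width = 2 * d_model + d_edge             # width without the projection
print(inp.shape, naive_width)
```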
| 273 |
+
|
| 274 |
+
|
| 275 |
+
class InterfaceGNN(nn.Module):
|
| 276 |
+
"""
|
| 277 |
+
Q_theta scorer: SE(3)-invariant dense graph transformer for interface scoring.
|
| 278 |
+
|
| 279 |
+
Input:
|
| 280 |
+
node_feats: [B, N, node_dim] per-residue features
|
| 281 |
+
edge_feats: [B, N, N, edge_dim] pairwise edge features
|
| 282 |
+
mask: [B, N] bool (True = valid residue, False = padding)
|
| 283 |
+
|
| 284 |
+
Output:
|
| 285 |
+
scores: [B] in (0, 1) = Q_theta(X, Y)
|
| 286 |
+
"""
|
| 287 |
+
def __init__(
|
| 288 |
+
self,
|
| 289 |
+
node_dim: int = 28,
|
| 290 |
+
edge_dim: int = 37,
|
| 291 |
+
hidden_dim: int = 128,
|
| 292 |
+
n_layers: int = 4,
|
| 293 |
+
n_heads: int = 8,
|
| 294 |
+
ff_mult: int = 4,
|
| 295 |
+
dropout: float = 0.1,
|
| 296 |
+
backbone: str = 'transformer',
|
| 297 |
+
pooling: str = 'meanmax', # 'meanmax' or 'attention'
|
| 298 |
+
edge_update: bool = False,
|
| 299 |
+
esm_dim: int = 0, # 0 = no ESM; >0 = ESM embedding dim to project
|
| 300 |
+
esm_proj_dim: int = 128, # projection dim for ESM features
|
| 301 |
+
esm_dropout: float = 0.0, # dropout on ESM projection
|
| 302 |
+
):
|
| 303 |
+
super().__init__()
|
| 304 |
+
actual_node_dim = node_dim + (esm_proj_dim if esm_dim > 0 else 0)
|
| 305 |
+
self.esm_dim = esm_dim
|
| 306 |
+
if esm_dim > 0:
|
| 307 |
+
layers = [
|
| 308 |
+
nn.Linear(esm_dim, esm_proj_dim),
|
| 309 |
+
nn.LayerNorm(esm_proj_dim),
|
| 310 |
+
nn.GELU(),
|
| 311 |
+
]
|
| 312 |
+
if esm_dropout > 0:
|
| 313 |
+
layers.append(nn.Dropout(esm_dropout))
|
| 314 |
+
self.esm_proj = nn.Sequential(*layers)
|
| 315 |
+
self.node_embed = nn.Sequential(
|
| 316 |
+
nn.Linear(actual_node_dim, hidden_dim),
|
| 317 |
+
nn.LayerNorm(hidden_dim),
|
| 318 |
+
nn.GELU(),
|
| 319 |
+
)
|
| 320 |
+
self.edge_embed = nn.Sequential(
|
| 321 |
+
nn.Linear(edge_dim, hidden_dim),
|
| 322 |
+
nn.GELU(),
|
| 323 |
+
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 324 |
+
)
|
| 325 |
+
d_edge_hidden = hidden_dim // 2
|
| 326 |
+
|
| 327 |
+
if backbone == 'transformer':
|
| 328 |
+
self.layers = nn.ModuleList([
|
| 329 |
+
InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout)
|
| 330 |
+
for _ in range(n_layers)
|
| 331 |
+
])
|
| 332 |
+
elif backbone == 'gat':
|
| 333 |
+
self.layers = nn.ModuleList([
|
| 334 |
+
GATLayer(hidden_dim, n_heads, ff_mult, dropout)
|
| 335 |
+
for _ in range(n_layers)
|
| 336 |
+
])
|
| 337 |
+
elif backbone == 'gcn':
|
| 338 |
+
self.layers = nn.ModuleList([
|
| 339 |
+
GCNLayer(hidden_dim, d_edge_hidden, ff_mult, dropout)
|
| 340 |
+
for _ in range(n_layers)
|
| 341 |
+
])
|
| 342 |
+
elif backbone == 'crosschain':
|
| 343 |
+
# Interleave self-attention and cross-chain attention
|
| 344 |
+
layers = []
|
| 345 |
+
for i in range(n_layers):
|
| 346 |
+
if i % 2 == 0:
|
| 347 |
+
layers.append(InterfaceTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
|
| 348 |
+
else:
|
| 349 |
+
layers.append(CrossChainTransformerLayer(hidden_dim, n_heads, d_edge_hidden, ff_mult, dropout))
|
| 350 |
+
self.layers = nn.ModuleList(layers)
|
| 351 |
+
else:
|
| 352 |
+
raise ValueError(f"Unknown backbone: {backbone}")
|
| 353 |
+
|
| 354 |
+
self.norm_out = nn.LayerNorm(hidden_dim)
|
| 355 |
+
|
| 356 |
+
# Edge update layers (optional)
|
| 357 |
+
self.edge_update = edge_update
|
| 358 |
+
if edge_update:
|
| 359 |
+
self.edge_update_layers = nn.ModuleList([
|
| 360 |
+
EdgeUpdateLayer(hidden_dim, d_edge_hidden, dropout)
|
| 361 |
+
for _ in range(n_layers)
|
| 362 |
+
])
|
| 363 |
+
|
| 364 |
+
# Pooling
|
| 365 |
+
self.pooling = pooling
|
| 366 |
+
if pooling == 'attention':
|
| 367 |
+
self.attn_pool = nn.Sequential(
|
| 368 |
+
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 369 |
+
nn.Tanh(),
|
| 370 |
+
nn.Linear(hidden_dim // 2, 1),
|
| 371 |
+
)
|
| 372 |
+
pool_dim = hidden_dim
|
| 373 |
+
else:
|
| 374 |
+
pool_dim = 2 * hidden_dim
|
| 375 |
+
|
| 376 |
+
# Scoring head
|
| 377 |
+
self.head = nn.Sequential(
|
| 378 |
+
nn.Linear(pool_dim, hidden_dim),
|
| 379 |
+
nn.GELU(),
|
| 380 |
+
nn.Dropout(dropout),
|
| 381 |
+
nn.Linear(hidden_dim, hidden_dim // 2),
|
| 382 |
+
nn.GELU(),
|
| 383 |
+
nn.Linear(hidden_dim // 2, 1),
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
def forward(self, node_feats, edge_feats, mask, esm_feats=None):
|
| 387 |
+
"""
|
| 388 |
+
node_feats: [B, N, node_dim]
|
| 389 |
+
edge_feats: [B, N, N, edge_dim]
|
| 390 |
+
mask: [B, N] bool
|
| 391 |
+
esm_feats: [B, N, esm_dim] optional ESM-2 embeddings
|
| 392 |
+
Returns: scores [B] in (0, 1)
|
| 393 |
+
"""
|
| 394 |
+
B, N, _ = node_feats.shape
|
| 395 |
+
|
| 396 |
+
# Extract chain mask for cross-chain attention (last dim = chain indicator)
|
| 397 |
+
chain_mask = node_feats[:, :, -1] # [B, N] float: 0=receptor, 1=binder
|
| 398 |
+
|
| 399 |
+
# Optionally concatenate projected ESM features
|
| 400 |
+
if self.esm_dim > 0 and esm_feats is not None:
|
| 401 |
+
esm_proj = self.esm_proj(esm_feats) # [B, N, esm_proj_dim]
|
| 402 |
+
node_feats = torch.cat([node_feats, esm_proj], dim=-1)
|
| 403 |
+
|
| 404 |
+
# Embed nodes and edges
|
| 405 |
+
h = self.node_embed(node_feats) # [B, N, hidden_dim]
|
| 406 |
+
e = self.edge_embed(edge_feats) # [B, N, N, hidden_dim//2]
|
| 407 |
+
|
| 408 |
+
# Graph transformer layers (with optional edge updates)
|
| 409 |
+
for i, layer in enumerate(self.layers):
|
| 410 |
+
if isinstance(layer, CrossChainTransformerLayer):
|
| 411 |
+
h = layer(h, e, mask, chain_mask=chain_mask)
|
| 412 |
+
else:
|
| 413 |
+
h = layer(h, e, mask)
|
| 414 |
+
if self.edge_update:
|
| 415 |
+
e = self.edge_update_layers[i](h, e, mask)
|
| 416 |
+
|
| 417 |
+
h = self.norm_out(h) # [B, N, hidden_dim]
|
| 418 |
+
|
| 419 |
+
# Pooling
|
| 420 |
+
mask_f = mask.float().unsqueeze(-1) # [B, N, 1]
|
| 421 |
+
|
| 422 |
+
if self.pooling == 'attention':
|
| 423 |
+
# Learned attention pooling
|
| 424 |
+
attn_logits = self.attn_pool(h).squeeze(-1) # [B, N]
|
| 425 |
+
attn_logits = attn_logits.masked_fill(~mask, float('-inf'))
|
| 426 |
+
attn_weights = F.softmax(attn_logits, dim=-1).unsqueeze(-1) # [B, N, 1]
|
| 427 |
+
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
|
| 428 |
+
h_pool = (h * attn_weights).sum(dim=1) # [B, hidden_dim]
|
| 429 |
+
else:
|
| 430 |
+
# Mean + max pooling
|
| 431 |
+
h_masked = h * mask_f
|
| 432 |
+
h_mean = h_masked.sum(dim=1) / (mask_f.sum(dim=1) + 1e-8)
|
| 433 |
+
h_max_input = h_masked + (1 - mask_f) * (-1e9)
|
| 434 |
+
h_max = h_max_input.max(dim=1).values
|
| 435 |
+
h_pool = torch.cat([h_mean, h_max], dim=-1) # [B, 2*hidden_dim]
|
| 436 |
+
|
| 437 |
+
# Score
|
| 438 |
+
logits = self.head(h_pool).squeeze(-1) # [B]
|
| 439 |
+
scores = torch.sigmoid(logits) # [B] in (0, 1)
|
| 440 |
+
return scores
|
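The masked mean + max pooling can be checked in isolation. The NumPy sketch below works on one sample without a batch dimension and uses made-up data. It fills the padded rows with garbage to show that they do not leak into the pooled vector:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 6, 4
h = rng.normal(size=(N, D))
mask = np.array([True, True, True, True, False, False])
h[~mask] = 100.0                                # garbage in padded rows

mask_f = mask.astype(np.float64)[:, None]       # [N, 1]
h_masked = h * mask_f                           # padded rows -> 0
h_mean = h_masked.sum(axis=0) / (mask_f.sum(axis=0) + 1e-8)
h_max = (h_masked + (1.0 - mask_f) * (-1e9)).max(axis=0)   # padded rows -> -1e9
h_pool = np.concatenate([h_mean, h_max])        # [2 * D]

print(h_pool.round(3))
```

The `-1e9` offset works as long as at least one node is valid; a fully padded sample would return `-1e9` in the max half.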
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class AlloDesignerScorer(nn.Module):
|
| 444 |
+
"""
|
| 445 |
+
Full Q_theta model wrapper with loss computation.
|
| 446 |
+
|
| 447 |
+
Implements the two-stage training objective:
|
| 448 |
+
Phase 1: DockQ regression (MSE loss)
|
| 449 |
+
Phase 2: Selectivity margin ranking (contrastive loss)
|
| 450 |
+
|
| 451 |
+
The selectivity margin from the paper (Eq. 3):
|
| 452 |
+
S_theta(Y; X+, N) = logit(Q(X+, Y)) - log sum_X- exp(logit(Q(X-, Y)))
|
| 453 |
+
"""
|
| 454 |
+
def __init__(self, node_dim=28, edge_dim=37, hidden_dim=128,
|
| 455 |
+
n_layers=4, n_heads=8, dropout=0.1, backbone='transformer',
|
| 456 |
+
pooling='meanmax', edge_update=False, esm_dim=0,
|
| 457 |
+
esm_proj_dim=128, esm_dropout=0.0):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.gnn = InterfaceGNN(node_dim, edge_dim, hidden_dim, n_layers, n_heads,
|
| 460 |
+
dropout=dropout, backbone=backbone,
|
| 461 |
+
pooling=pooling, edge_update=edge_update,
|
| 462 |
+
esm_dim=esm_dim, esm_proj_dim=esm_proj_dim,
|
| 463 |
+
esm_dropout=esm_dropout)
|
| 464 |
+
|
| 465 |
+
def forward(self, node_feats, edge_feats, mask, esm_feats=None):
|
| 466 |
+
return self.gnn(node_feats, edge_feats, mask, esm_feats=esm_feats)
|
| 467 |
+
|
| 468 |
+
def compute_dockq_loss(self, scores, dockq_labels):
|
| 469 |
+
"""Phase 1: MSE regression loss against DockQ labels."""
|
| 470 |
+
return F.mse_loss(scores, dockq_labels.float())
|
| 471 |
+
|
| 472 |
+
def compute_selectivity_loss(self, pos_scores, neg_scores_list, margin: float = 0.2):
|
| 473 |
+
"""
|
| 474 |
+
Phase 2: Selectivity margin loss.
|
| 475 |
+
|
| 476 |
+
For each binder Y:
|
| 477 |
+
pos_score = Q(X+, Y)
|
| 478 |
+
neg_scores = [Q(X-, Y) for X- in N]
|
| 479 |
+
|
| 480 |
+
Returns -mean(S_theta) + L_margin, where
|
| 481 |
+
S_theta = logit(pos_score) - log sum exp(logit(neg_scores))
|
| 482 |
+
|
| 483 |
+
L_margin is a hinge margin loss on the raw scores:
|
| 484 |
+
L_margin = mean(max(0, margin - (pos_score - neg_score)))
|
| 485 |
+
"""
|
| 486 |
+
# logit = log(p / (1-p))
|
| 487 |
+
eps = 1e-6
|
| 488 |
+
pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
|
| 489 |
+
|
| 490 |
+
# neg_scores_list: list of [B] tensors
|
| 491 |
+
neg_logits = torch.stack([
|
| 492 |
+
torch.log(s.clamp(eps, 1 - eps) / (1 - s).clamp(eps))
|
| 493 |
+
for s in neg_scores_list
|
| 494 |
+
], dim=-1) # [B, n_neg]
|
| 495 |
+
|
| 496 |
+
# InfoNCE-style selectivity margin
|
| 497 |
+
log_denom = torch.logsumexp(neg_logits, dim=-1) # [B]
|
| 498 |
+
selectivity = pos_logit - log_denom # [B]
|
| 499 |
+
selectivity_loss = -selectivity.mean()
|
| 500 |
+
|
| 501 |
+
# Hinge margin loss on raw scores (averaged over all negatives and the batch)
|
| 502 |
+
margin_losses = []
|
| 503 |
+
for neg_scores in neg_scores_list:
|
| 504 |
+
margin_losses.append(F.relu(margin - (pos_scores - neg_scores)))
|
| 505 |
+
margin_loss = torch.stack(margin_losses, dim=-1).mean()
|
| 506 |
+
|
| 507 |
+
return selectivity_loss + margin_loss
|
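A worked scalar example helps when reading the selectivity margin. Since `exp(logit(q)) = q / (1 - q)`, the margin is the log of the positive state's odds divided by the summed odds of the negative states. The pure-Python sketch below uses made-up scores. Its `logit` helper clamps like `torch.logit`, which is close to, but not identical with, the clamping in the method above.

```python
import math

def logit(p, eps=1e-6):
    """log(p / (1 - p)) with p clamped to [eps, 1 - eps]."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))

def selectivity_margin(pos, negs):
    """S = logit(pos) - logsumexp(logit(neg) for each negative state)."""
    neg_logits = [logit(q) for q in negs]
    m = max(neg_logits)                                        # stable logsumexp
    log_denom = m + math.log(sum(math.exp(x - m) for x in neg_logits))
    return logit(pos) - log_denom

def hinge_margin(pos, negs, margin=0.2):
    """mean over negatives of max(0, margin - (pos - neg))."""
    return sum(max(0.0, margin - (pos - q)) for q in negs) / len(negs)

s = selectivity_margin(0.9, [0.2, 0.1])
odds_ratio_form = math.log((0.9 / 0.1) / (0.2 / 0.8 + 0.1 / 0.9))

h_ok = hinge_margin(0.9, [0.2, 0.1])    # gaps 0.7 and 0.8 exceed the margin
h_bad = hinge_margin(0.5, [0.45])       # gap 0.05 falls short by 0.15
print(s, h_ok, h_bad)
```

Note that `-mean(S)` has no lower bound on its own, because the positive logit does not appear in the denominator.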
| 508 |
+
|
| 509 |
+
def compute_path_selectivity_loss(self, pos_scores, neg_scores_list,
|
| 510 |
+
path_scores_list, path_taus,
|
| 511 |
+
margin=0.2, path_lambda=0.5):
|
| 512 |
+
"""
|
| 513 |
+
Extended selectivity loss with path monotonicity regularization.
|
| 514 |
+
|
| 515 |
+
Args:
|
| 516 |
+
pos_scores: [B] Q(X1, Y) -- goal state scores
|
| 517 |
+
neg_scores_list: list of [B] -- Q(X0, Y), Q(X_cryptic, Y), etc.
|
| 518 |
+
path_scores_list: list of [B] -- Q(X_tau, Y) for each path frame
|
| 519 |
+
path_taus: list of float -- tau values for each path frame (sorted)
|
| 520 |
+
margin: margin for ranking loss
|
| 521 |
+
path_lambda: weight for path monotonicity loss
|
| 522 |
+
|
| 523 |
+
Returns:
|
| 524 |
+
total_loss: selectivity loss + path_lambda * monotonicity loss
|
| 525 |
+
loss_dict: breakdown of loss components
|
| 526 |
+
"""
|
| 527 |
+
# Standard selectivity loss (unchanged)
|
| 528 |
+
select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list, margin)
|
| 529 |
+
|
| 530 |
+
# Path monotonicity loss: ensure Q increases with tau
|
| 531 |
+
loss_monotone = torch.tensor(0.0, device=pos_scores.device)
|
| 532 |
+
if path_scores_list and path_lambda > 0:
|
| 533 |
+
small_margin = 0.05
|
| 534 |
+
# Consecutive path frames should be monotonically increasing
|
| 535 |
+
for i in range(len(path_scores_list) - 1):
|
| 536 |
+
loss_monotone = loss_monotone + F.relu(
|
| 537 |
+
path_scores_list[i] - path_scores_list[i + 1] + small_margin
|
| 538 |
+
).mean()
|
| 539 |
+
# Last path frame should be less than positive (holo) score
|
| 540 |
+
loss_monotone = loss_monotone + F.relu(
|
| 541 |
+
path_scores_list[-1] - pos_scores + margin
|
| 542 |
+
).mean()
|
| 543 |
+
# First path frame should be greater than negative (apo) score
|
| 544 |
+
if neg_scores_list:
|
| 545 |
+
loss_monotone = loss_monotone + F.relu(
|
| 546 |
+
neg_scores_list[0] - path_scores_list[0] + small_margin
|
| 547 |
+
).mean()
|
| 548 |
+
|
| 549 |
+
total = select_loss + path_lambda * loss_monotone
|
| 550 |
+
return total, {
|
| 551 |
+
'loss_selectivity': select_loss.item(),
|
| 552 |
+
'loss_path_monotone': loss_monotone.item(),
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
def compute_combined_loss(self, pos_scores, neg_scores_list, dockq_labels,
|
| 556 |
+
lambda_rank: float = 1.0):
|
| 557 |
+
"""Combined Phase 1 + Phase 2 loss."""
|
| 558 |
+
# Regression loss on the positive (goal-state) scores only; negatives enter via the selectivity term
|
| 559 |
+
dockq_loss = self.compute_dockq_loss(pos_scores, dockq_labels)
|
| 560 |
+
|
| 561 |
+
# Selectivity loss
|
| 562 |
+
select_loss = self.compute_selectivity_loss(pos_scores, neg_scores_list)
|
| 563 |
+
|
| 564 |
+
return dockq_loss + lambda_rank * select_loss, {
|
| 565 |
+
'loss_dockq': dockq_loss.item(),
|
| 566 |
+
'loss_selectivity': select_loss.item(),
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def build_model(config: dict) -> AlloDesignerScorer:
|
| 571 |
+
"""Build the Q_theta scorer from a config dict."""
|
| 572 |
+
return AlloDesignerScorer(
|
| 573 |
+
node_dim=config.get('node_dim', 32),
|
| 574 |
+
edge_dim=config.get('edge_dim', 37),
|
| 575 |
+
hidden_dim=config.get('hidden_dim', 128),
|
| 576 |
+
n_layers=config.get('n_layers', 4),
|
| 577 |
+
n_heads=config.get('n_heads', 8),
|
| 578 |
+
dropout=config.get('dropout', 0.1),
|
| 579 |
+
backbone=config.get('backbone', 'transformer'),
|
| 580 |
+
pooling=config.get('pooling', 'meanmax'),
|
| 581 |
+
edge_update=config.get('edge_update', False),
|
| 582 |
+
esm_dim=config.get('esm_dim', 0),
|
| 583 |
+
esm_proj_dim=config.get('esm_proj_dim', 128),
|
| 584 |
+
esm_dropout=config.get('esm_dropout', 0.0),
|
| 585 |
+
)
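The selectivity term in `compute_selectivity_loss` above can be checked by hand on a single binder. The following is a minimal pure-Python sketch, not the repository's implementation: the helper names `logit`, `logsumexp`, and `selectivity_and_margin` are hypothetical, and the scores are made-up numbers chosen only to illustrate the arithmetic.

```python
import math

def logit(p, eps=1e-6):
    """Clamped log-odds, mirroring the eps clamp used in the loss."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))

def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def selectivity_and_margin(pos, negs, margin=0.2):
    """Return (S_theta, mean hinge) for one binder.

    pos  : score on the goal state, Q(X+, Y)
    negs : scores on the off-target states, [Q(X-, Y), ...]
    """
    s_theta = logit(pos) - logsumexp([logit(n) for n in negs])
    hinge = sum(max(0.0, margin - (pos - n)) for n in negs) / len(negs)
    return s_theta, hinge

# Illustrative scores (assumed values, not taken from any dataset).
s_theta, hinge = selectivity_and_margin(pos=0.9, negs=[0.2, 0.1])
loss = -s_theta + hinge  # per-sample analogue of selectivity_loss + margin_loss
print(f"S_theta={s_theta:.4f}  hinge={hinge:.4f}  loss={loss:.4f}")
```

With a well-separated positive score the hinge term vanishes and only the log-odds contrast contributes; when positive and negative scores coincide, the hinge equals the margin.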
|
code/requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
| 1 |
+
# Core
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
|
| 5 |
+
# Protein structure
|
| 6 |
+
biopython>=1.80
|
| 7 |
+
|
| 8 |
+
# ML utilities
|
| 9 |
+
scipy>=1.10.0
|
| 10 |
+
scikit-learn>=1.3.0
|
| 11 |
+
|
| 12 |
+
# Experiment tracking
|
| 13 |
+
wandb>=0.12.0
|
| 14 |
+
|
| 15 |
+
# Config
|
| 16 |
+
pyyaml>=6.0
|
| 17 |
+
|
| 18 |
+
# Visualization
|
| 19 |
+
matplotlib>=3.7.0
|
| 20 |
+
|
| 21 |
+
# Optional accelerations
|
| 22 |
+
einops>=0.6.0
|
code/scripts/README.md
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
# `code/scripts/` — entry points
|
| 2 |
+
|
| 3 |
+
This public release ships only the inference and sampling code for Q_θ.
|
| 4 |
+
|
| 5 |
+
| File / dir | Purpose |
|
| 6 |
+
|---|---|
|
| 7 |
+
| `evaluate.py` | Score binders in a pre-built `*.pkl` test set with a Q_θ checkpoint; reports Spearman ρ, AUC, selectivity gap. |
|
| 8 |
+
| `rescore.py` | Re-score raw PDB designs (binder + holo + apo) with Q_θ. |
|
| 9 |
+
| `pxdesign_guidance/` | PXDesign-prior guidance with Q_θ (Langevin / SMC / TDS / classifier). |
|
| 10 |
+
|
| 11 |
+
Training, baseline scoring (ProteinMPNN / ESM-IF / Rosetta / DFIRE / energy panel), guidance for RFdiffusion / Proteina-ComplexA, and paper-figure aggregation are **not** shipped; the inference path above is the only supported surface for the public release.
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Deploying Q_θ with other base models
|
| 16 |
+
|
| 17 |
+
Q_θ provides two interfaces:
|
| 18 |
+
|
| 19 |
+
1. **Re-ranker (best-of-K).** Given K candidate binders from any prior, score each with `S(Y) = Q_θ(X¹, Y) − Q_θ(X⁰, Y)` and keep the highest-scoring ones. No gradient signal is needed, and the prior is left unmodified.
|
| 20 |
+
2. **Gradient signal for guidance.** Compute `∇_Y S(Y)` via `DifferentiableQTheta` (in `code/models/differentiable_features.py`) and inject into the prior's sampler (Langevin step, SMC weight, TDS twist, classifier guidance score).
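As a rough illustration of interface (1), the re-ranking step reduces to sorting candidates by the holo-minus-apo score difference. This is a minimal sketch assuming the per-candidate scores have already been computed; `rerank_best_of_k` is a hypothetical helper and the score values are invented for the example.

```python
def rerank_best_of_k(q_holo, q_apo, top_n=1):
    """Rank candidates by S(Y) = Q(X1, Y) - Q(X0, Y), highest first.

    q_holo, q_apo : per-candidate scores on the holo and apo receptor states.
    Returns (indices of the top_n candidates, list of S values).
    """
    if len(q_holo) != len(q_apo):
        raise ValueError("need one holo and one apo score per candidate")
    s = [h - a for h, a in zip(q_holo, q_apo)]
    order = sorted(range(len(s)), key=lambda i: s[i], reverse=True)
    return order[:top_n], s

# Invented scores for four candidates.
q_holo = [0.80, 0.95, 0.60, 0.90]
q_apo = [0.10, 0.85, 0.05, 0.30]
top, s = rerank_best_of_k(q_holo, q_apo, top_n=2)
print(top, [round(v, 2) for v in s])
```

Candidate 1 has the highest holo score but also scores highly on the apo state, so it ranks last on selectivity; ranking by the difference rather than by the holo score alone is the point of this interface.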
|
| 21 |
+
|
| 22 |
+
The `pxdesign_guidance/` subdir is a worked example of interface (2) wrapping PXDesign. To plug Q_θ into another prior, mirror that pattern:
|
| 23 |
+
|
| 24 |
+
### RFdiffusion
|
| 25 |
+
|
| 26 |
+
1. Clone RFdiffusion: <https://github.com/RosettaCommons/RFdiffusion>.
|
| 27 |
+
2. Follow its install + checkpoint download.
|
| 28 |
+
3. In RFdiffusion's diffusion loop, after each denoising step, materialize the predicted backbone, build the holo/apo graph inputs expected by `DifferentiableQTheta`, and either:
|
| 29 |
+
- Apply a Langevin nudge: `x ← x + η · ∇_x S(x)`.
|
| 30 |
+
- Add a classifier-guidance term to the denoiser's `x_{t-1}` mean: `μ' = μ + s · σ² · ∇_x log p(y|x)`, where `log p(y|x) ≈ S(x)` (Q_θ is treated as the log-likelihood of "is a good binder").
|
| 31 |
+
4. Reference template: `pxdesign_guidance/guided_pxdesign.py`.
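The two update rules in step 3 can be sketched on a toy problem. The example below is an assumption-laden illustration, not RFdiffusion or Q_θ code: a quadratic stand-in `toy_score` replaces `S(x)` so the gradient is available in closed form, and all function names are hypothetical. In a real integration the gradient would presumably come from autograd through `DifferentiableQTheta`.

```python
def toy_score(x, target):
    """Stand-in for S(x): negative squared distance to a target point."""
    return -sum((xi - ti) ** 2 for xi, ti in zip(x, target))

def toy_score_grad(x, target):
    """Closed-form gradient of toy_score with respect to x."""
    return [-2.0 * (xi - ti) for xi, ti in zip(x, target)]

def langevin_nudge(x, grad, eta):
    """x <- x + eta * grad S(x); the noise term of a full Langevin step is omitted."""
    return [xi + eta * gi for xi, gi in zip(x, grad)]

def guided_mean(mu, grad, sigma2, scale):
    """mu' = mu + s * sigma^2 * grad log p(y|x), with log p(y|x) approximated by S(x)."""
    return [mi + scale * sigma2 * gi for mi, gi in zip(mu, grad)]

target = [1.0, 0.0, 0.0]   # toy optimum of the stand-in score
x = [0.0, 0.0, 0.0]        # current sample (a single 3D point for brevity)
g = toy_score_grad(x, target)

x_new = langevin_nudge(x, g, eta=0.1)
mu_new = guided_mean(x, g, sigma2=0.25, scale=1.0)
print(x_new, mu_new, toy_score(x_new, target) > toy_score(x, target))
```

Both rules move the sample along the score gradient; they differ only in how the step size is set (a free `η` versus the guidance scale times the denoiser variance).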
|
| 32 |
+
|
| 33 |
+
### Proteina-ComplexA
|
| 34 |
+
|
| 35 |
+
1. Clone Proteina: <https://github.com/proteinabio/proteina-complexa> (or the released artifact).
|
| 36 |
+
2. Use its ComplexA mode that emits binder coords conditioned on a receptor.
|
| 37 |
+
3. Same plug pattern as RFdiffusion — wrap the sampler with `DifferentiableQTheta` for guidance, or run unguided and re-rank with `evaluate.py` / `rescore.py`.
|
| 38 |
+
|
| 39 |
+
### Any backbone prior
|
| 40 |
+
|
| 41 |
+
The only contract Q_θ enforces:
|
| 42 |
+
|
| 43 |
+
- Receptor input is a PDB with holo and apo coordinates.
|
| 44 |
+
- Binder input is a PDB (or coords) with chain id distinct from receptor's.
|
| 45 |
+
- For guidance, expose differentiable Cα + backbone coordinates so `∇_x S(x)` flows.
|
| 46 |
+
|
| 47 |
+
See `DifferentiableQTheta` in `code/models/differentiable_features.py` for the exact interface: `load_receptor(holo_path, apo_path, …)`, `score(design_path, binder_chain, state)`, and `differentiable_score(coords, …)`.
|
| 48 |
+
|
| 49 |
+
---
|
| 50 |
+
|
| 51 |
+
## Why other guidance scripts aren't shipped
|
| 52 |
+
|
| 53 |
+
The RFdiffusion / Proteina guidance variants in our internal tree depend on those projects' unreleased CIF formats and patched samplers, and we don't want to ship modified third-party code. The PXDesign variants we do ship use only PXDesign's public API and are self-contained.
|
| 54 |
+
|
| 55 |
+
For citation / reproduction context, see the paper §4 (guidance methods).
|
code/scripts/evaluate.py
ADDED
|
@@ -0,0 +1,332 @@
|
| 1 |
+
"""
|
| 2 |
+
Evaluation script for the trained Q_theta scorer.
|
| 3 |
+
|
| 4 |
+
Computes:
|
| 5 |
+
1. Selectivity metrics (gap, ranking accuracy, AUC)
|
| 6 |
+
2. DockQ correlation (Spearman)
|
| 7 |
+
3. Score distributions (violin plots)
|
| 8 |
+
4. Best-of-K analysis (as function of K)
|
| 9 |
+
5. Per-target breakdown
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python code/scripts/evaluate.py \
|
| 13 |
+
--target cam \
|
| 14 |
+
--checkpoint checkpoints/Q_theta_phase2.pt \
|
| 15 |
+
--data_dir data/processed \
|
| 16 |
+
--gpu 7
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import argparse
|
| 22 |
+
import logging
|
| 23 |
+
import json
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import matplotlib
|
| 27 |
+
matplotlib.use('Agg')
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
from scipy.stats import spearmanr, pearsonr
|
| 30 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
| 31 |
+
|
| 32 |
+
_CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
| 33 |
+
if _CODE_DIR not in sys.path:
|
| 34 |
+
sys.path.insert(0, _CODE_DIR)
|
| 35 |
+
|
| 36 |
+
from models.scorer import build_model
|
| 37 |
+
from data.dataset import TwoStateComplexDataset, collate_fn
|
| 38 |
+
from torch.utils.data import DataLoader
|
| 39 |
+
|
| 40 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def compute_best_of_k(pos_scores, K_values=None, threshold=0.7):
|
| 45 |
+
"""
|
| 46 |
+
Simulate best-of-K selection: what fraction of draws contain at least one good binder?
|
| 47 |
+
Assumes pos_scores are from a distribution of candidate binders for goal state X+.
|
| 48 |
+
"""
|
| 49 |
+
if K_values is None:
|
| 50 |
+
K_values = [1, 2, 5, 10, 20, 50, 100]
|
| 51 |
+
results = {}
|
| 52 |
+
n = len(pos_scores)
|
| 53 |
+
n_trials = 1000
|
| 54 |
+
|
| 55 |
+
for K in K_values:
|
| 56 |
+
successes = 0
|
| 57 |
+
for _ in range(n_trials):
|
| 58 |
+
idxs = np.random.choice(n, size=min(K, n), replace=False)
|
| 59 |
+
best_score = pos_scores[idxs].max()
|
| 60 |
+
if best_score >= threshold:
|
| 61 |
+
successes += 1
|
| 62 |
+
results[K] = successes / n_trials
|
| 63 |
+
|
| 64 |
+
return results
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def compute_selectivity_margin(pos_scores, neg_scores):
|
| 68 |
+
"""Compute per-sample selectivity margin S_theta."""
|
| 69 |
+
eps = 1e-6
|
| 70 |
+
pos_logit = np.log(pos_scores.clip(eps, 1-eps) / (1-pos_scores).clip(eps))
|
| 71 |
+
neg_logit = np.log(neg_scores.clip(eps, 1-eps) / (1-neg_scores).clip(eps))
|
| 72 |
+
selectivity = pos_logit - neg_logit  # single negative: log-sum-exp reduces to the logit itself
|
| 73 |
+
return selectivity
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def plot_score_distributions(pos_scores, neg_scores, decoy_scores=None,
|
| 77 |
+
title='Score Distributions', outpath=None):
|
| 78 |
+
"""Violin plot of score distributions for different complex types."""
|
| 79 |
+
fig, ax = plt.subplots(figsize=(8, 6))
|
| 80 |
+
|
| 81 |
+
data = [pos_scores, neg_scores]
|
| 82 |
+
labels = ['Positive\n(X+, Y)', 'Negative\n(X0, Y)']
|
| 83 |
+
colors = ['#2196F3', '#F44336']
|
| 84 |
+
|
| 85 |
+
if decoy_scores is not None and len(decoy_scores) > 0:
|
| 86 |
+
data.append(decoy_scores)
|
| 87 |
+
labels.append('Decoys\n(X+, Y~)')
|
| 88 |
+
colors.append('#FF9800')
|
| 89 |
+
|
| 90 |
+
parts = ax.violinplot(data, positions=range(len(data)), showmedians=True)
|
| 91 |
+
for i, (pc, c) in enumerate(zip(parts['bodies'], colors)):
|
| 92 |
+
pc.set_facecolor(c)
|
| 93 |
+
pc.set_alpha(0.7)
|
| 94 |
+
|
| 95 |
+
ax.set_xticks(range(len(data)))
|
| 96 |
+
ax.set_xticklabels(labels)
|
| 97 |
+
ax.set_ylabel('Q_theta Score', fontsize=12)
|
| 98 |
+
ax.set_title(title, fontsize=14)
|
| 99 |
+
ax.set_ylim(0, 1)
|
| 100 |
+
ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Decision boundary')
|
| 101 |
+
ax.legend()
|
| 102 |
+
|
| 103 |
+
# Add mean + std annotations
|
| 104 |
+
for i, (d, c) in enumerate(zip(data, colors)):
|
| 105 |
+
ax.text(i, 0.02, f'μ={d.mean():.2f}\nσ={d.std():.2f}',
|
| 106 |
+
ha='center', fontsize=9, color=c)
|
| 107 |
+
|
| 108 |
+
plt.tight_layout()
|
| 109 |
+
if outpath:
|
| 110 |
+
plt.savefig(outpath, dpi=150, bbox_inches='tight')
|
| 111 |
+
logger.info(f"Saved plot to {outpath}")
|
| 112 |
+
plt.close()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def plot_roc_curve(labels, scores, title='ROC Curve', outpath=None):
|
| 116 |
+
"""Plot ROC curve for positive vs negative classification."""
|
| 117 |
+
fpr, tpr, _ = roc_curve(labels, scores)
|
| 118 |
+
auc = roc_auc_score(labels, scores)
|
| 119 |
+
|
| 120 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
| 121 |
+
ax.plot(fpr, tpr, 'b-', lw=2, label=f'AUC = {auc:.3f}')
|
| 122 |
+
ax.plot([0, 1], [0, 1], 'k--', lw=1)
|
| 123 |
+
ax.set_xlabel('False Positive Rate')
|
| 124 |
+
ax.set_ylabel('True Positive Rate')
|
| 125 |
+
ax.set_title(title)
|
| 126 |
+
ax.legend()
|
| 127 |
+
plt.tight_layout()
|
| 128 |
+
if outpath:
|
| 129 |
+
plt.savefig(outpath, dpi=150, bbox_inches='tight')
|
| 130 |
+
plt.close()
|
| 131 |
+
return auc
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def plot_best_of_k(results, outpath=None):
|
| 135 |
+
"""Plot best-of-K success rate as a function of K."""
|
| 136 |
+
Ks = sorted(results.keys())
|
| 137 |
+
success_rates = [results[K] for K in Ks]
|
| 138 |
+
|
| 139 |
+
fig, ax = plt.subplots(figsize=(8, 5))
|
| 140 |
+
ax.semilogx(Ks, success_rates, 'b-o', lw=2, markersize=8)
|
| 141 |
+
ax.set_xlabel('K (number of candidates)', fontsize=12)
|
| 142 |
+
ax.set_ylabel('Success rate (best score >= threshold)', fontsize=12)
|
| 143 |
+
ax.set_title('Best-of-K Analysis', fontsize=14)
|
| 144 |
+
ax.set_ylim(0, 1.05)
|
| 145 |
+
ax.grid(True, alpha=0.3)
|
| 146 |
+
ax.axhline(0.8, color='red', linestyle='--', alpha=0.5, label='80% success')
|
| 147 |
+
ax.legend()
|
| 148 |
+
plt.tight_layout()
|
| 149 |
+
if outpath:
|
| 150 |
+
plt.savefig(outpath, dpi=150, bbox_inches='tight')
|
| 151 |
+
plt.close()
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@torch.no_grad()
|
| 155 |
+
def evaluate(model, loader, device):
|
| 156 |
+
"""Run model on a dataset and collect all predictions."""
|
| 157 |
+
model.eval()
|
| 158 |
+
all_scores, all_labels, all_types, all_pdbs = [], [], [], []
|
| 159 |
+
|
| 160 |
+
for batch in loader:
|
| 161 |
+
esm_feats = batch['esm_feats'].to(device) if 'esm_feats' in batch else None
|
| 162 |
+
scores = model(
|
| 163 |
+
batch['node_feats'].to(device),
|
| 164 |
+
batch['edge_feats'].to(device),
|
| 165 |
+
batch['node_mask'].to(device),
|
| 166 |
+
esm_feats=esm_feats,
|
| 167 |
+
)
|
| 168 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
| 169 |
+
all_labels.extend(batch['label'].numpy().tolist())
|
| 170 |
+
all_types.extend(batch['type'])
|
| 171 |
+
all_pdbs.extend(batch['pdb'])
|
| 172 |
+
|
| 173 |
+
return (np.array(all_scores), np.array(all_labels),
|
| 174 |
+
np.array(all_types), np.array(all_pdbs))
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main():
|
| 178 |
+
parser = argparse.ArgumentParser(description='Evaluate Allo-Designer Q_theta scorer')
|
| 179 |
+
parser.add_argument('--target', default='cam',
|
| 180 |
+
help='Target name (cam, abl, era, or any custom target with data in data/processed/)')
|
| 181 |
+
parser.add_argument('--all_targets', action='store_true',
|
| 182 |
+
help='Reserved for multi-target evaluation; not acted on by this script, which evaluates --target only')
|
| 183 |
+
parser.add_argument('--checkpoint', required=True, help='Path to model checkpoint')
|
| 184 |
+
parser.add_argument('--data_dir', default='data/processed')
|
| 185 |
+
parser.add_argument('--split', choices=['val', 'test'], default='test')
|
| 186 |
+
parser.add_argument('--batch_size', type=int, default=32)
|
| 187 |
+
parser.add_argument('--gpu', type=int, default=7, help='CUDA device index (default 7; pass --gpu 0 on single-GPU machines)')
|
| 188 |
+
parser.add_argument('--outdir', default='results')
|
| 189 |
+
parser.add_argument('--bok_threshold', type=float, default=0.7,
|
| 190 |
+
help='Score threshold for best-of-K (default 0.7; use per-target value for calibrated results)')
|
| 191 |
+
parser.add_argument('--esm_dir', default=None,
|
| 192 |
+
help='Path to ESM-2 embedding cache (auto-detected at <data_dir>/esm2_embeddings if omitted)')
|
| 193 |
+
parser.add_argument('--no_wandb', action='store_true', help='(ignored; here for CLI compatibility)')
|
| 194 |
+
args = parser.parse_args()
|
| 195 |
+
|
| 196 |
+
# Auto-detect ESM dir under data_dir
|
| 197 |
+
if args.esm_dir is None:
|
| 198 |
+
cand = os.path.join(args.data_dir, 'esm2_embeddings')
|
| 199 |
+
if os.path.isdir(cand):
|
| 200 |
+
args.esm_dir = cand
|
| 201 |
+
|
| 202 |
+
device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
|
| 203 |
+
os.makedirs(args.outdir, exist_ok=True)
|
| 204 |
+
os.makedirs(f'{args.outdir}/figures', exist_ok=True)
|
| 205 |
+
os.makedirs(f'{args.outdir}/tables', exist_ok=True)
|
| 206 |
+
|
| 207 |
+
# Load model
|
| 208 |
+
state = torch.load(args.checkpoint, map_location=device)
|
| 209 |
+
config = state.get('config', {})
|
| 210 |
+
model = build_model(config).to(device)
|
| 211 |
+
model.load_state_dict(state['model_state'])
|
| 212 |
+
logger.info(f"Loaded model from {args.checkpoint}")
|
| 213 |
+
|
| 214 |
+
# Load dataset
|
| 215 |
+
data_path = os.path.join(args.data_dir, args.target, f'{args.split}.pkl')
|
| 216 |
+
if not os.path.exists(data_path):
|
| 217 |
+
logger.error(f"Data not found: {data_path}")
|
| 218 |
+
sys.exit(1)
|
| 219 |
+
|
| 220 |
+
dataset = TwoStateComplexDataset(data_path, max_nodes=128,
|
| 221 |
+
esm_dir=args.esm_dir, target_name=args.target)
|
| 222 |
+
loader = DataLoader(
|
| 223 |
+
dataset, batch_size=args.batch_size, shuffle=False,
|
| 224 |
+
num_workers=2, collate_fn=collate_fn
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Run evaluation
|
| 228 |
+
logger.info(f"Evaluating on {len(dataset)} samples...")
|
| 229 |
+
scores, labels, types, pdbs = evaluate(model, loader, device)
|
| 230 |
+
|
| 231 |
+
# Separate by type
|
| 232 |
+
pos_mask = (types == 'positive')
|
| 233 |
+
neg_apo_mask = (types == 'negative_apo')
|
| 234 |
+
decoy_mask = np.array(['decoy' in t for t in types])
|
| 235 |
+
|
| 236 |
+
pos_scores = scores[pos_mask]
|
| 237 |
+
neg_scores = scores[neg_apo_mask]
|
| 238 |
+
decoy_scores = scores[decoy_mask]
|
| 239 |
+
|
| 240 |
+
logger.info(f"\n{'='*50}")
|
| 241 |
+
logger.info(f"Results for {args.target} ({args.split})")
|
| 242 |
+
logger.info(f"{'='*50}")
|
| 243 |
+
logger.info(f"Positive samples: {pos_mask.sum()}")
|
| 244 |
+
logger.info(f"Negative (apo) samples: {neg_apo_mask.sum()}")
|
| 245 |
+
logger.info(f"Decoy samples: {decoy_mask.sum()}")
|
| 246 |
+
|
| 247 |
+
# --- Core metrics ---
|
| 248 |
+
metrics = {}
|
| 249 |
+
|
| 250 |
+
# 1. Spearman correlation with DockQ labels
|
| 251 |
+
sp, p_val = spearmanr(scores, labels)
|
| 252 |
+
metrics['spearman_all'] = float(sp)
|
| 253 |
+
metrics['spearman_pval'] = float(p_val)
|
| 254 |
+
logger.info(f"\nSpearman(Q_theta, DockQ): {sp:.3f} (p={p_val:.3e})")
|
| 255 |
+
|
| 256 |
+
# 2. Selectivity gap (positive vs negative_apo)
|
| 257 |
+
if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0:
|
| 258 |
+
gap = float(pos_scores.mean() - neg_scores.mean())
|
| 259 |
+
ranking_acc = float((pos_scores.mean() > neg_scores).mean() if len(neg_scores) > 0 else 0.5)
|
| 260 |
+
metrics['selectivity_gap'] = gap
|
| 261 |
+
metrics['pos_score_mean'] = float(pos_scores.mean())
|
| 262 |
+
metrics['neg_score_mean'] = float(neg_scores.mean())
|
| 263 |
+
metrics['pos_score_std'] = float(pos_scores.std())
|
| 264 |
+
metrics['neg_score_std'] = float(neg_scores.std())
|
| 265 |
+
logger.info(f"Selectivity gap (pos - neg): {gap:.3f}")
|
| 266 |
+
logger.info(f" Pos: {pos_scores.mean():.3f} ± {pos_scores.std():.3f}")
|
| 267 |
+
logger.info(f" Neg: {neg_scores.mean():.3f} ± {neg_scores.std():.3f}")
|
| 268 |
+
|
| 269 |
+
# 3. AUC for positive vs negative
|
| 270 |
+
if pos_mask.sum() > 0 and neg_apo_mask.sum() > 0:
|
| 271 |
+
pn_scores = np.concatenate([pos_scores, neg_scores])
|
| 272 |
+
pn_labels = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
|
| 273 |
+
auc = roc_auc_score(pn_labels, pn_scores)
|
| 274 |
+
metrics['auc_pos_vs_neg'] = float(auc)
|
| 275 |
+
logger.info(f"AUC (pos vs neg_apo): {auc:.3f}")
|
| 276 |
+
|
| 277 |
+
# ROC curve
|
| 278 |
+
plot_roc_curve(
|
| 279 |
+
pn_labels, pn_scores,
|
| 280 |
+
title=f'ROC: Positive vs Negative Apo ({args.target.upper()})',
|
| 281 |
+
outpath=f'{args.outdir}/figures/roc_{args.target}_{args.split}.png'
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# 4. AUC for quality classification (DockQ > 0.5)
|
| 285 |
+
binary = (labels > 0.5).astype(int)
|
| 286 |
+
if binary.sum() > 0 and binary.sum() < len(binary):
|
| 287 |
+
auc_quality = roc_auc_score(binary, scores)
|
| 288 |
+
metrics['auc_quality'] = float(auc_quality)
|
| 289 |
+
logger.info(f"AUC (quality>0.5): {auc_quality:.3f}")
|
| 290 |
+
|
| 291 |
+
# 5. Best-of-K analysis
|
| 292 |
+
if len(pos_scores) > 0:
|
| 293 |
+
bok_results = compute_best_of_k(pos_scores, K_values=[1, 2, 5, 10, 20, 50],
|
| 294 |
+
threshold=args.bok_threshold)
|
| 295 |
+
metrics['best_of_k'] = {str(K): float(v) for K, v in bok_results.items()}
|
| 296 |
+
logger.info(f"\nBest-of-K success rates:")
|
| 297 |
+
for K, rate in bok_results.items():
|
| 298 |
+
logger.info(f" K={K:3d}: {rate:.3f}")
|
| 299 |
+
plot_best_of_k(
|
| 300 |
+
bok_results,
|
| 301 |
+
outpath=f'{args.outdir}/figures/best_of_k_{args.target}_{args.split}.png'
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# 6. Score distributions plot
|
| 305 |
+
plot_score_distributions(
|
| 306 |
+
pos_scores if len(pos_scores) > 0 else np.array([]),
|
| 307 |
+
neg_scores if len(neg_scores) > 0 else np.array([]),
|
| 308 |
+
decoy_scores if len(decoy_scores) > 0 else None,
|
| 309 |
+
title=f'Q_theta Score Distributions ({args.target.upper()})',
|
| 310 |
+
outpath=f'{args.outdir}/figures/score_dist_{args.target}_{args.split}.png'
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Save metrics
|
| 314 |
+
out_json = f'{args.outdir}/tables/eval_{args.target}_{args.split}.json'
|
| 315 |
+
with open(out_json, 'w') as f:
|
| 316 |
+
json.dump(metrics, f, indent=2)
|
| 317 |
+
logger.info(f"\nSaved metrics to {out_json}")
|
| 318 |
+
|
| 319 |
+
# Print summary table
|
| 320 |
+
logger.info(f"\n{'='*50}")
|
| 321 |
+
logger.info("SUMMARY TABLE")
|
| 322 |
+
logger.info(f"{'='*50}")
|
| 323 |
+
logger.info(f"{'Metric':<30} {'Value':>10}")
|
| 324 |
+
logger.info(f"{'-'*42}")
|
| 325 |
+
for k, v in metrics.items():
|
| 326 |
+
if isinstance(v, float):
|
| 327 |
+
logger.info(f"{k:<30} {v:>10.4f}")
|
| 328 |
+
logger.info(f"{'='*50}")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
if __name__ == '__main__':
|
| 332 |
+
main()
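The Monte Carlo estimate in `compute_best_of_k` has a closed form, which can serve as a sanity check. Because each trial draws `min(K, n)` candidates without replacement and succeeds when at least one score reaches the threshold, the success rate is one minus a hypergeometric "no good candidate drawn" probability. The sketch below is an illustration under that reading of the function; `exact_best_of_k` is a hypothetical helper and the scores are invented.

```python
from math import comb

def exact_best_of_k(scores, K, threshold=0.7):
    """Exact probability that a draw of min(K, n) candidates, taken without
    replacement, contains at least one score >= threshold."""
    n = len(scores)
    k = min(K, n)
    g = sum(1 for s in scores if s >= threshold)  # number of "good" candidates
    # comb(n - g, k) is 0 when k > n - g, so the rate saturates at 1.0.
    return 1.0 - comb(n - g, k) / comb(n, k)

# Invented scores: two of five candidates clear the 0.7 threshold.
scores = [0.9, 0.8, 0.3, 0.2, 0.1]
rates = {K: exact_best_of_k(scores, K) for K in (1, 2, 5)}
print(rates)
```

With 1000 trials, the sampled rates from `compute_best_of_k` should agree with these values to within a few percentage points.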
|
code/scripts/pxdesign_guidance/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
# PXDesign + Q_theta guidance integration
|
code/scripts/pxdesign_guidance/convert_cif_to_pdb.py
ADDED
|
@@ -0,0 +1,132 @@
|
| 1 |
+
"""
|
| 2 |
+
Convert PXDesign CIF outputs to PDB format for evaluation pipeline.
|
| 3 |
+
|
| 4 |
+
PXDesign outputs .cif files with:
|
| 5 |
+
- Chain IDs like A0/B0 (multi-char, not PDB-compatible)
|
| 6 |
+
- Non-standard residue name 'xpb' for designed binder residues
|
| 7 |
+
|
| 8 |
+
This script converts them to PDB format with:
|
| 9 |
+
- Single-char chain IDs (A, B)
|
| 10 |
+
- Residue name 'xpb' rewritten to 'GLY' (backbone-only binder residues)
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python code/scripts/pxdesign_guidance/convert_cif_to_pdb.py
|
| 14 |
+
"""
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from glob import glob
|
| 18 |
+
|
| 19 |
+
from Bio.PDB import MMCIFParser, PDBIO, Select
|
| 20 |
+
|
| 21 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
_PROJECT_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '../../..'))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ChainRenamer(Select):
|
| 26 |
+
"""Pass-through Select filter (accepts every chain, residue, and atom); the chain renaming itself happens in convert_cif_to_pdb."""
|
| 27 |
+
def __init__(self, chain_map):
|
| 28 |
+
self.chain_map = chain_map
|
| 29 |
+
|
| 30 |
+
def accept_chain(self, chain):
|
| 31 |
+
return 1
|
| 32 |
+
|
| 33 |
+
def accept_residue(self, residue):
|
| 34 |
+
return 1
|
| 35 |
+
|
| 36 |
+
def accept_atom(self, atom):
|
| 37 |
+
return 1
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def convert_cif_to_pdb(cif_path, pdb_path):
|
| 41 |
+
"""Convert a single CIF file to PDB format."""
|
| 42 |
+
parser = MMCIFParser(QUIET=True)
|
| 43 |
+
structure = parser.get_structure('s', cif_path)
|
| 44 |
+
model = structure[0]
|
| 45 |
+
|
| 46 |
+
# Build chain ID mapping (A0->A, B0->B, etc.)
|
| 47 |
+
chain_map = {}
|
| 48 |
+
used_ids = set()
|
| 49 |
+
for chain in model.get_chains():
|
| 50 |
+
old_id = chain.id
|
| 51 |
+
# Use first character
|
| 52 |
+
new_id = old_id[0] if old_id else 'A'
|
| 53 |
+
# Avoid duplicates
|
| 54 |
+
while new_id in used_ids:
|
| 55 |
+
new_id = chr(ord(new_id) + 1)
|
| 56 |
+
used_ids.add(new_id)
|
| 57 |
+
chain_map[old_id] = new_id
|
| 58 |
+
|
| 59 |
+
# Rename chains and fix non-standard residue names
|
| 60 |
+
chains_to_rename = list(model.get_chains())
|
| 61 |
+
for chain in chains_to_rename:
|
| 62 |
+
old_id = chain.id
|
| 63 |
+
new_id = chain_map.get(old_id, old_id)
|
| 64 |
+
if old_id != new_id:
|
| 65 |
+
chain.id = new_id
|
| 66 |
+
# Rename 'xpb' residues to 'GLY' (backbone-only binder residues)
|
| 67 |
+
for residue in chain.get_residues():
|
| 68 |
+
if residue.resname.strip().lower() == 'xpb':
|
| 69 |
+
residue.resname = 'GLY'
|
| 70 |
+
|
| 71 |
+
# Write PDB
|
| 72 |
+
io = PDBIO()
|
| 73 |
+
io.set_structure(structure)
|
| 74 |
+
io.save(pdb_path)
|
| 75 |
+
return True
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def convert_directory(src_dir, method_name):
|
| 79 |
+
"""Convert all CIF files in a directory tree to PDB."""
|
| 80 |
+
cif_files = sorted(glob(os.path.join(src_dir, '**/*.cif'), recursive=True))
|
| 81 |
+
cif_files = [f for f in cif_files if 'sample' in os.path.basename(f).lower()]
|
| 82 |
+
|
| 83 |
+
if not cif_files:
|
| 84 |
+
print(f" No CIF files found in {src_dir}")
|
| 85 |
+
return 0
|
| 86 |
+
|
| 87 |
+
# Create converted_pdbs directory
|
| 88 |
+
converted_dir = os.path.join(src_dir, 'converted_pdbs')
|
| 89 |
+
os.makedirs(converted_dir, exist_ok=True)
|
| 90 |
+
|
| 91 |
+
n_converted = 0
|
| 92 |
+
for cif_path in cif_files:
|
| 93 |
+
basename = os.path.basename(cif_path).replace('.cif', '.pdb')
|
| 94 |
+
# For TDS/SMC with round subdirs, include round info
|
| 95 |
+
rel_path = os.path.relpath(cif_path, src_dir)
|
| 96 |
+
parts = rel_path.split(os.sep)
|
| 97 |
+
if any(p.startswith('round_') for p in parts):
|
| 98 |
+
round_part = [p for p in parts if p.startswith('round_')][0]
|
| 99 |
+
basename = f"{round_part}_{basename}"
|
| 100 |
+
|
| 101 |
+
pdb_path = os.path.join(converted_dir, basename)
|
| 102 |
+
try:
|
| 103 |
+
convert_cif_to_pdb(cif_path, pdb_path)
|
| 104 |
+
n_converted += 1
|
| 105 |
+
except Exception as e:
|
| 106 |
+
print(f" Failed {cif_path}: {e}")
|
| 107 |
+
|
| 108 |
+
print(f" Converted {n_converted}/{len(cif_files)} CIF -> PDB in {converted_dir}")
|
| 109 |
+
return n_converted
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def main():
|
| 113 |
+
methods = {
|
| 114 |
+
'pxdesign_guided': os.path.join(_PROJECT_DIR, 'results/pxdesign_guided'),
|
| 115 |
+
'pxdesign_tds': os.path.join(_PROJECT_DIR, 'results/pxdesign_tds'),
|
| 116 |
+
'pxdesign_smc': os.path.join(_PROJECT_DIR, 'results/pxdesign_smc'),
|
| 117 |
+
}
|
| 118 |
+
# Langevin outputs are already PDB (post-hoc refinement)
|
| 119 |
+
|
| 120 |
+
total = 0
|
| 121 |
+
for name, src_dir in methods.items():
|
| 122 |
+
print(f"\n{name}:")
|
| 123 |
+
if os.path.exists(src_dir):
|
| 124 |
+
total += convert_directory(src_dir, name)
|
| 125 |
+
else:
|
| 126 |
+
print(f" Directory not found: {src_dir}")
|
| 127 |
+
|
| 128 |
+
print(f"\nTotal converted: {total}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == '__main__':
|
| 132 |
+
main()
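The chain-ID mapping inside `convert_cif_to_pdb` can be exercised on its own, without Biopython. The sketch below lifts that loop into a standalone function (`build_chain_map` is a hypothetical name) to show how collisions are resolved when two multi-character IDs share a first character.

```python
def build_chain_map(chain_ids):
    """Map multi-character chain IDs to single characters.

    Takes the first character of each ID and, on collision, steps to the
    next code point until an unused character is found.
    """
    chain_map, used = {}, set()
    for old_id in chain_ids:
        new_id = old_id[0] if old_id else 'A'
        while new_id in used:
            new_id = chr(ord(new_id) + 1)
        used.add(new_id)
        chain_map[old_id] = new_id
    return chain_map

print(build_chain_map(['A0', 'B0', 'A1']))
print(build_chain_map(['Z0', 'Z1']))
```

In the first call `A1` collides with `A` and then `B`, so it lands on `C`. The second call shows a limitation of the simple increment: stepping past `Z` yields `[`, which is not a conventional chain ID, so a stricter variant might draw from an explicit alphabet instead.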
|
code/scripts/pxdesign_guidance/guided_pxdesign.py
ADDED
|
@@ -0,0 +1,408 @@
|
| 1 |
+
"""
|
| 2 |
+
PXDesign + Q_theta Classifier Guidance.
|
| 3 |
+
|
| 4 |
+
Monkey-patches PXDesign's diffusion sampling loop to inject Q_theta selectivity
|
| 5 |
+
gradient after each denoising step. This steers the diffusion trajectory toward
|
| 6 |
+
binder backbones that are conformationally selective.
|
| 7 |
+
|
| 8 |
+
The patched diffusion loop:
|
| 9 |
+
x_denoised = denoise_net(x_noisy, t_hat, ...)
|
| 10 |
+
grad = ∇_{x_denoised}[Q(holo,Y) - Q(apo,Y)] # <-- INJECTED
|
| 11 |
+
x_denoised = x_denoised + scale(t) * grad # <-- INJECTED
|
| 12 |
+
delta = (x_noisy - x_denoised) / t_hat
|
| 13 |
+
x_l = x_noisy + eta * dt * delta
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python code/scripts/pxdesign_guidance/guided_pxdesign.py \
|
| 17 |
+
--input experiments/pxdesign_cam/output/cam_binder.json \
|
| 18 |
+
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
|
| 19 |
+
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
|
| 20 |
+
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
|
| 21 |
+
--guidance_scale 1.0 \
|
| 22 |
+
--N_sample 50 --N_step 400 \
|
| 23 |
+
--gpu 0
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import logging
|
| 31 |
+
import time
|
| 32 |
+
import shutil
|
| 33 |
+
from typing import Callable, Optional, Union
|
| 34 |
+
from functools import partial
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
# ── Paths ────────────────────────────────────────────────────────────────────
|
| 43 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 44 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
|
| 45 |
+
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
|
| 46 |
+
_PXDESIGN_DIR = os.environ.get('PXDESIGN_DIR', '')
|
| 47 |
+
|
| 48 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 49 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 50 |
+
if _PXDESIGN_DIR and _PXDESIGN_DIR not in sys.path:
|
| 51 |
+
sys.path.insert(0, _PXDESIGN_DIR)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def guided_sample_diffusion(
|
| 55 |
+
denoise_net: Callable,
|
| 56 |
+
input_feature_dict: dict,
|
| 57 |
+
s_inputs: torch.Tensor,
|
| 58 |
+
s_trunk: torch.Tensor,
|
| 59 |
+
z_trunk: torch.Tensor,
|
| 60 |
+
noise_schedule: torch.Tensor,
|
| 61 |
+
N_sample: int = 1,
|
| 62 |
+
gamma0: float = 0.8,
|
| 63 |
+
gamma_min: float = 1.0,
|
| 64 |
+
noise_scale_lambda: float = 1.003,
|
| 65 |
+
step_scale_eta: Union[float, dict] = {"type": "const", "min": 1.5, "max": 1.5},
|
| 66 |
+
diffusion_chunk_size: Optional[int] = None,
|
| 67 |
+
inplace_safe: bool = False,
|
| 68 |
+
attn_chunk_size: Optional[int] = None,
|
| 69 |
+
# Guidance parameters (injected via partial)
|
| 70 |
+
guidance_module=None,
|
| 71 |
+
guidance_scale: float = 1.0,
|
| 72 |
+
guidance_start: float = 0.8,
|
| 73 |
+
guidance_end: float = 0.1,
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
Modified PXDesign sample_diffusion with Q_theta classifier guidance.
|
| 77 |
+
|
| 78 |
+
Same as original generator.sample_diffusion but with gradient injection
|
| 79 |
+
after each denoising step. The gradient is scaled by a schedule that
|
| 80 |
+
applies stronger guidance at high noise levels (early steps).
|
| 81 |
+
"""
|
| 82 |
+
from protenix.model.utils import centre_random_augmentation
|
| 83 |
+
|
| 84 |
+
N_atom = input_feature_dict["atom_to_token_idx"].size(-1)
|
| 85 |
+
batch_shape = s_inputs.shape[:-2]
|
| 86 |
+
device = s_inputs.device
|
| 87 |
+
dtype = s_inputs.dtype
|
| 88 |
+
|
| 89 |
+
logger.info(f"Guided sampling: scale={guidance_scale}, "
|
| 90 |
+
f"window=[{guidance_end:.1f}, {guidance_start:.1f}]")
|
| 91 |
+
|
| 92 |
+
def _chunk_sample_diffusion_guided(chunk_n_sample, inplace_safe):
|
| 93 |
+
x_l = noise_schedule[0] * torch.randn(
|
| 94 |
+
size=(*batch_shape, chunk_n_sample, N_atom, 3),
|
| 95 |
+
device=device, dtype=dtype
|
| 96 |
+
)
|
| 97 |
+
T = len(noise_schedule)
|
| 98 |
+
|
| 99 |
+
for step_t, (c_tau_last, c_tau) in enumerate(
|
| 100 |
+
zip(noise_schedule[:-1], noise_schedule[1:])
|
| 101 |
+
):
|
| 102 |
+
# Centre random augmentation
|
| 103 |
+
x_l = (
|
| 104 |
+
centre_random_augmentation(x_input_coords=x_l, N_sample=1)
|
| 105 |
+
.squeeze(dim=-3)
|
| 106 |
+
.to(dtype)
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Predictor step: add noise
|
| 110 |
+
gamma = float(gamma0) if c_tau > gamma_min else 0
|
| 111 |
+
t_hat = c_tau_last * (gamma + 1)
|
| 112 |
+
delta_noise_level = torch.sqrt(t_hat**2 - c_tau_last**2)
|
| 113 |
+
x_noisy = x_l + noise_scale_lambda * delta_noise_level * torch.randn(
|
| 114 |
+
size=x_l.shape, device=device, dtype=dtype
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
# Reshape t_hat for network
|
| 118 |
+
t_hat_tensor = (
|
| 119 |
+
t_hat.reshape((1,) * (len(batch_shape) + 1))
|
| 120 |
+
.expand(*batch_shape, chunk_n_sample)
|
| 121 |
+
.to(dtype)
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Denoise
|
| 125 |
+
x_denoised = denoise_net(
|
| 126 |
+
x_noisy=x_noisy,
|
| 127 |
+
t_hat_noise_level=t_hat_tensor,
|
| 128 |
+
input_feature_dict=input_feature_dict,
|
| 129 |
+
s_inputs=s_inputs,
|
| 130 |
+
s_trunk=s_trunk,
|
| 131 |
+
z_trunk=z_trunk,
|
| 132 |
+
chunk_size=attn_chunk_size,
|
| 133 |
+
inplace_safe=inplace_safe,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
# ── Q_theta guidance injection ──────────────────────────────
|
| 137 |
+
if guidance_module is not None:
|
| 138 |
+
# Compute progress fraction (0=start/high noise, 1=end/low noise)
|
| 139 |
+
progress = step_t / (T - 1) if T > 1 else 1.0
|
| 140 |
+
|
| 141 |
+
# Apply guidance only within the specified window
|
| 142 |
+
if guidance_end <= (1.0 - progress) <= guidance_start:
|
| 143 |
+
# Handle batch dimensions
|
| 144 |
+
x_for_grad = x_denoised
|
| 145 |
+
if x_for_grad.dim() > 3:
|
| 146 |
+
x_for_grad = x_for_grad.squeeze(0)
|
| 147 |
+
|
| 148 |
+
# Scale: stronger at high noise, weaker near convergence
|
| 149 |
+
noise_fraction = 1.0 - progress
|
| 150 |
+
scale = guidance_scale * noise_fraction
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
# Compute gradients for at most the first 4 samples; any remaining samples reuse their mean
|
| 154 |
+
n_guide = min(chunk_n_sample, 4)
|
| 155 |
+
grad_accum = torch.zeros_like(x_for_grad)
|
| 156 |
+
|
| 157 |
+
for si in range(n_guide):
|
| 158 |
+
grad, margin = guidance_module.compute_guidance_gradient(
|
| 159 |
+
x_for_grad, input_feature_dict,
|
| 160 |
+
t_hat=t_hat, sample_idx=si
|
| 161 |
+
)
|
| 162 |
+
grad_accum[si] = grad[si] if grad.shape[0] > si else grad[0]
|
| 163 |
+
|
| 164 |
+
# Broadcast gradient to remaining samples
|
| 165 |
+
if n_guide < chunk_n_sample and n_guide > 0:
|
| 166 |
+
avg_grad = grad_accum[:n_guide].mean(dim=0, keepdim=True)
|
| 167 |
+
grad_accum[n_guide:] = avg_grad.expand(
|
| 168 |
+
chunk_n_sample - n_guide, -1, -1)
|
| 169 |
+
|
| 170 |
+
# Normalize gradient to prevent explosion
|
| 171 |
+
grad_norm = grad_accum.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 172 |
+
grad_normalized = grad_accum / grad_norm
|
| 173 |
+
avg_norm = grad_norm.mean().item()
|
| 174 |
+
|
| 175 |
+
# Apply guidance
|
| 176 |
+
if avg_norm > 1e-6:
|
| 177 |
+
# Scale by average gradient magnitude to keep step size reasonable
|
| 178 |
+
x_denoised = x_denoised + scale * avg_norm * grad_normalized
|
| 179 |
+
|
| 180 |
+
if step_t % 50 == 0:
|
| 181 |
+
logger.info(
|
| 182 |
+
f" Step {step_t}/{T}: margin={margin:.3f}, "
|
| 183 |
+
f"grad_norm={avg_norm:.4f}, scale={scale:.3f}")
|
| 184 |
+
except Exception as e:
|
| 185 |
+
if step_t % 100 == 0:
|
| 186 |
+
logger.warning(f" Step {step_t}: guidance failed: {e}")
|
| 187 |
+
# ── End guidance ────────────────────────────────────────────
|
| 188 |
+
|
| 189 |
+
# Euler step
|
| 190 |
+
delta = (x_noisy - x_denoised) / t_hat_tensor[..., None, None]
|
| 191 |
+
dt = c_tau - t_hat_tensor
|
| 192 |
+
if isinstance(step_scale_eta, (int, float)):
|
| 193 |
+
eta = step_scale_eta
|
| 194 |
+
elif step_scale_eta["type"] == "const":
|
| 195 |
+
assert step_scale_eta["min"] == step_scale_eta["max"]
|
| 196 |
+
eta = step_scale_eta["min"]
|
| 197 |
+
else:
|
| 198 |
+
eta_min, eta_max = step_scale_eta["min"], step_scale_eta["max"]
|
| 199 |
+
if step_scale_eta["type"] == "linear":
|
| 200 |
+
eta = eta_min + (eta_max - eta_min) * (step_t / T)
|
| 201 |
+
elif step_scale_eta["type"] == "poly":
|
| 202 |
+
eta = eta_min + (eta_max - eta_min) * (step_t / T) ** 2
|
| 203 |
+
elif step_scale_eta["type"] == "cos":
|
| 204 |
+
eta = eta_min + 0.5 * (eta_max - eta_min) * (
|
| 205 |
+
1 - np.cos(np.pi * step_t / T))
|
| 206 |
+
elif step_scale_eta["type"] == "piecewise":
|
| 207 |
+
eta = eta_min if step_t / T < 0.5 else eta_max
|
| 208 |
+
elif step_scale_eta["type"] == "piecewise_65":
|
| 209 |
+
eta = eta_min if step_t / T < 0.65 else eta_max
|
| 210 |
+
elif step_scale_eta["type"] == "piecewise_70":
|
| 211 |
+
eta = eta_min if step_t / T < 0.70 else eta_max
|
| 212 |
+
else:
|
| 213 |
+
raise ValueError("Unsupported eta schedule!")
|
| 214 |
+
x_l = x_noisy + eta * dt[..., None, None] * delta
|
| 215 |
+
|
| 216 |
+
return x_l
|
| 217 |
+
|
| 218 |
+
# Chunked sampling
|
| 219 |
+
if diffusion_chunk_size is None:
|
| 220 |
+
x_l = _chunk_sample_diffusion_guided(N_sample, inplace_safe=inplace_safe)
|
| 221 |
+
else:
|
| 222 |
+
x_l = []
|
| 223 |
+
no_chunks = N_sample // diffusion_chunk_size + (
|
| 224 |
+
N_sample % diffusion_chunk_size != 0)
|
| 225 |
+
for i in range(no_chunks):
|
| 226 |
+
chunk_n_sample = (
|
| 227 |
+
diffusion_chunk_size
|
| 228 |
+
if i < no_chunks - 1
|
| 229 |
+
else N_sample - i * diffusion_chunk_size
|
| 230 |
+
)
|
| 231 |
+
chunk_x_l = _chunk_sample_diffusion_guided(
|
| 232 |
+
chunk_n_sample, inplace_safe=inplace_safe)
|
| 233 |
+
x_l.append(chunk_x_l)
|
| 234 |
+
x_l = torch.cat(x_l, -3)
|
| 235 |
+
|
| 236 |
+
return x_l
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def run_guided_pxdesign(args):
|
| 240 |
+
"""Run PXDesign with Q_theta classifier guidance."""
|
| 241 |
+
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
|
| 242 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
| 243 |
+
|
| 244 |
+
# Import PXDesign components
|
| 245 |
+
from pxdesign.runner.inference import InferenceRunner, main as pxdesign_main
|
| 246 |
+
from pxdesign.utils.infer import (
|
| 247 |
+
get_configs, convert_to_bioassembly_dict, download_inference_cache, derive_seed
|
| 248 |
+
)
|
| 249 |
+
from pxdesign.utils.inputs import process_input_file
|
| 250 |
+
from protenix.config import save_config
|
| 251 |
+
from protenix.utils.seed import seed_everything
|
| 252 |
+
from protenix.utils.torch_utils import autocasting_disable_decorator
|
| 253 |
+
|
| 254 |
+
from qtheta_pxdesign import QThetaPXDesignGuidance
|
| 255 |
+
|
| 256 |
+
# Set up output directory
|
| 257 |
+
outdir = args.outdir if os.path.isabs(args.outdir) else os.path.join(_ALLO_ROOT, args.outdir)
|
| 258 |
+
os.makedirs(outdir, exist_ok=True)
|
| 259 |
+
|
| 260 |
+
# Build PXDesign CLI arguments
|
| 261 |
+
pxdesign_argv = [
|
| 262 |
+
'--dump_dir', outdir,
|
| 263 |
+
'--input', args.input,
|
| 264 |
+
'--dtype', 'bf16',
|
| 265 |
+
'--N_sample', str(args.N_sample),
|
| 266 |
+
'--N_step', str(args.N_step),
|
| 267 |
+
]
|
| 268 |
+
|
| 269 |
+
configs = get_configs(pxdesign_argv)
|
| 270 |
+
configs.input_json_path = process_input_file(
|
| 271 |
+
configs.input_json_path, out_dir=outdir)
|
| 272 |
+
download_inference_cache(configs)
|
| 273 |
+
|
| 274 |
+
# Convert inputs
|
| 275 |
+
save_config(configs, os.path.join(outdir, "config.yaml"))
|
| 276 |
+
with open(configs.input_json_path, "r") as f:
|
| 277 |
+
orig_inputs = json.load(f)
|
| 278 |
+
for x in orig_inputs:
|
| 279 |
+
convert_to_bioassembly_dict(x, outdir)
|
| 280 |
+
configs.input_json_path = os.path.join(outdir, "input_tasks.json")
|
| 281 |
+
with open(configs.input_json_path, "w") as f:
|
| 282 |
+
json.dump(orig_inputs, f, indent=4)
|
| 283 |
+
|
| 284 |
+
# Create runner
|
| 285 |
+
runner = InferenceRunner(configs)
|
| 286 |
+
|
| 287 |
+
# Initialize Q_theta guidance
|
| 288 |
+
guidance = QThetaPXDesignGuidance(
|
| 289 |
+
checkpoint=args.qtheta_checkpoint if os.path.isabs(args.qtheta_checkpoint) else os.path.join(_ALLO_ROOT, args.qtheta_checkpoint),
|
| 290 |
+
ref_holo=args.ref_holo if os.path.isabs(args.ref_holo) else os.path.join(_ALLO_ROOT, args.ref_holo),
|
| 291 |
+
ref_apo=args.ref_apo if os.path.isabs(args.ref_apo) else os.path.join(_ALLO_ROOT, args.ref_apo),
|
| 292 |
+
ref_chain=args.ref_chain,
|
| 293 |
+
device='cuda:0', # After CUDA_VISIBLE_DEVICES remapping
|
| 294 |
+
esm_target=args.esm_target,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Monkey-patch the sample_diffusion function
|
| 298 |
+
from pxdesign.model import generator as pxdesign_generator
|
| 299 |
+
import pxdesign.model.pxdesign as pxdesign_model
|
| 300 |
+
|
| 301 |
+
# Create guided version with guidance params bound
|
| 302 |
+
guided_fn = partial(
|
| 303 |
+
guided_sample_diffusion,
|
| 304 |
+
guidance_module=guidance,
|
| 305 |
+
guidance_scale=args.guidance_scale,
|
| 306 |
+
guidance_start=args.guidance_start,
|
| 307 |
+
guidance_end=args.guidance_end,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
# Patch the module-level function in generator.py
|
| 311 |
+
pxdesign_generator.sample_diffusion = guided_fn
|
| 312 |
+
|
| 313 |
+
# CRITICAL: pxdesign.py does `from pxdesign.model.generator import sample_diffusion`
|
| 314 |
+
# which creates a local binding in pxdesign.model.pxdesign namespace.
|
| 315 |
+
# We must patch that local binding too, otherwise the ProtenixDesign.sample_diffusion()
|
| 316 |
+
# method will still call the original unpatched function.
|
| 317 |
+
pxdesign_model.sample_diffusion = guided_fn
|
| 318 |
+
|
| 319 |
+
logger.info("PXDesign diffusion loop patched with Q_theta guidance")
|
| 320 |
+
|
| 321 |
+
# Run inference
|
| 322 |
+
seeds = [derive_seed(time.time_ns())] if not configs.seeds else configs.seeds
|
| 323 |
+
for seed in seeds:
|
| 324 |
+
logger.info(f"Running guided inference with seed {seed}")
|
| 325 |
+
seed_everything(seed=seed, deterministic=False)
|
| 326 |
+
runner._inference(seed)
|
| 327 |
+
|
| 328 |
+
# Score all generated designs
|
| 329 |
+
logger.info("Scoring generated designs...")
|
| 330 |
+
from glob import glob
|
| 331 |
+
|
| 332 |
+
pdb_dir = outdir
|
| 333 |
+
pdbs = []
|
| 334 |
+
for ext in ('*.pdb', '*.cif'):
|
| 335 |
+
pdbs.extend(glob(os.path.join(pdb_dir, '**/' + ext), recursive=True))
|
| 336 |
+
pdbs = sorted([p for p in pdbs if 'sample' in os.path.basename(p).lower()])
|
| 337 |
+
|
| 338 |
+
results = []
|
| 339 |
+
for i, pdb_path in enumerate(pdbs):
|
| 340 |
+
design_id = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
|
| 341 |
+
result = guidance.score_design(pdb_path)
|
| 342 |
+
if result is not None:
|
| 343 |
+
result['design_id'] = design_id
|
| 344 |
+
result['pdb_path'] = pdb_path
|
| 345 |
+
results.append(result)
|
| 346 |
+
logger.info(
|
| 347 |
+
f"[{i+1}/{len(pdbs)}] {design_id}: "
|
| 348 |
+
f"Q+={result['q_holo']:.3f} Q-={result['q_apo']:.3f} "
|
| 349 |
+
f"S={result['margin']:+.3f}")
|
| 350 |
+
|
| 351 |
+
# Save results
|
| 352 |
+
if results:
|
| 353 |
+
results.sort(key=lambda x: x['margin'], reverse=True)
|
| 354 |
+
margins = np.array([r['margin'] for r in results])
|
| 355 |
+
|
| 356 |
+
summary = {
|
| 357 |
+
'method': 'PXDesign + Classifier Guidance',
|
| 358 |
+
'n_designs': len(results),
|
| 359 |
+
'guidance_scale': args.guidance_scale,
|
| 360 |
+
'guidance_window': [args.guidance_end, args.guidance_start],
|
| 361 |
+
'margin_mean': float(margins.mean()),
|
| 362 |
+
'margin_std': float(margins.std()),
|
| 363 |
+
'frac_positive': float((margins > 0).mean()),
|
| 364 |
+
'q_holo_mean': float(np.mean([r['q_holo'] for r in results])),
|
| 365 |
+
'q_apo_mean': float(np.mean([r['q_apo'] for r in results])),
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
with open(os.path.join(outdir, 'guided_scores.json'), 'w') as f:
|
| 369 |
+
json.dump(results, f, indent=2)
|
| 370 |
+
with open(os.path.join(outdir, 'guided_summary.json'), 'w') as f:
|
| 371 |
+
json.dump(summary, f, indent=2)
|
| 372 |
+
|
| 373 |
+
logger.info(f"\n{'='*60}")
|
| 374 |
+
logger.info(f"PXDesign + Classifier Guidance Results ({len(results)} designs)")
|
| 375 |
+
logger.info(f" Margin: {margins.mean():.3f} ± {margins.std():.3f}")
|
| 376 |
+
logger.info(f" Fraction S > 0: {(margins > 0).mean():.1%}")
|
| 377 |
+
logger.info(f" Q(holo) mean: {summary['q_holo_mean']:.3f}")
|
| 378 |
+
logger.info(f"{'='*60}")
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def main():
|
| 382 |
+
parser = argparse.ArgumentParser(description='PXDesign + Q_theta Classifier Guidance')
|
| 383 |
+
parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
|
| 384 |
+
help='PXDesign input JSON')
|
| 385 |
+
parser.add_argument('--qtheta_checkpoint',
|
| 386 |
+
default='results/checkpoints_cam_v3/best_phase2.pt')
|
| 387 |
+
parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
|
| 388 |
+
parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
|
| 389 |
+
parser.add_argument('--ref_chain', default='A')
|
| 390 |
+
parser.add_argument('--guidance_scale', type=float, default=1.0,
|
| 391 |
+
help='Guidance gradient scale')
|
| 392 |
+
parser.add_argument('--guidance_start', type=float, default=0.8,
|
| 393 |
+
help='Start guidance at this noise fraction (high noise)')
|
| 394 |
+
parser.add_argument('--guidance_end', type=float, default=0.1,
|
| 395 |
+
help='Stop guidance at this noise fraction (low noise)')
|
| 396 |
+
parser.add_argument('--N_sample', type=int, default=50)
|
| 397 |
+
parser.add_argument('--N_step', type=int, default=400)
|
| 398 |
+
parser.add_argument('--gpu', type=int, default=0)
|
| 399 |
+
parser.add_argument('--outdir', default='results/pxdesign_guided')
|
| 400 |
+
parser.add_argument('--esm_target', default='cam',
|
| 401 |
+
help='Subdir under data/esm2_embeddings (e.g., adk, cam)')
|
| 402 |
+
args = parser.parse_args()
|
| 403 |
+
|
| 404 |
+
run_guided_pxdesign(args)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
if __name__ == '__main__':
|
| 408 |
+
main()
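The guidance window and scale used in `guided_sample_diffusion` can be read in isolation. The following is a minimal sketch, not part of the committed file: `guidance_weight` is a hypothetical helper name, and it only mirrors the `progress`, window check, and `scale = guidance_scale * noise_fraction` logic shown above.

```python
def guidance_weight(step_t, T, guidance_scale=1.0,
                    guidance_start=0.8, guidance_end=0.1):
    """Hypothetical helper mirroring the window/scale logic of the guided loop.

    step_t : index of the current denoising step (0 = first, highest noise)
    T      : length of the noise schedule
    Returns the multiplier applied to the guidance gradient, or 0.0 when the
    step falls outside [guidance_end, guidance_start].
    """
    # Same progress definition as the loop: 0 at the start, approaching 1 at the end.
    progress = step_t / (T - 1) if T > 1 else 1.0
    noise_fraction = 1.0 - progress

    # Guidance is active only while the remaining noise fraction is in the window.
    if guidance_end <= noise_fraction <= guidance_start:
        # Linear decay: stronger at high noise, weaker near convergence.
        return guidance_scale * noise_fraction
    return 0.0


if __name__ == "__main__":
    T = 11
    for step in range(T - 1):
        print(step, guidance_weight(step, T))
```

Under the default window, the earliest steps (noise fraction above 0.8) and the final steps (noise fraction below 0.1) are left unguided, so the gradient only acts in the middle of the trajectory.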
|
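The "CRITICAL" comment in `run_guided_pxdesign` describes a general Python behavior: `from module import name` copies the binding into the importing module, so patching only the defining module leaves the importer unchanged. The sketch below illustrates this with two throwaway modules; `fake_generator` and `fake_model` are hypothetical stand-ins, not PXDesign modules.

```python
import sys
import types

# Stand-in for the module that defines the function.
gen = types.ModuleType("fake_generator")
exec("def sample_diffusion():\n    return 'original'\n", gen.__dict__)
sys.modules["fake_generator"] = gen

# Stand-in for a module that does `from ... import sample_diffusion`.
model = types.ModuleType("fake_model")
exec(
    "from fake_generator import sample_diffusion\n"
    "def run():\n"
    "    return sample_diffusion()\n",
    model.__dict__,
)


def guided():
    return "guided"


# Patching only the defining module does not affect the importer's binding.
gen.sample_diffusion = guided
after_first_patch = model.run()

# Patching the importer's own namespace is what changes the call.
model.sample_diffusion = guided
after_second_patch = model.run()

print(after_first_patch, after_second_patch)
```

This is why the script assigns `guided_fn` to both `pxdesign_generator.sample_diffusion` and `pxdesign_model.sample_diffusion`.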
code/scripts/pxdesign_guidance/iterative_refinement.py
ADDED
|
@@ -0,0 +1,338 @@
|
| 1 |
+
"""
|
| 2 |
+
Iterative Refinement via Langevin Noise-Refine Cycles.
|
| 3 |
+
|
| 4 |
+
Inspired by ProDifEvo (Uehara et al., ICML 2025): repeatedly perturb and
|
| 5 |
+
refine structures through Q_theta gradient ascent. Each cycle adds noise
|
| 6 |
+
for diversity, then refines with Langevin dynamics toward higher selectivity.
|
| 7 |
+
|
| 8 |
+
This allows designs to escape local optima and explore better selectivity
|
| 9 |
+
regions that single-shot generation cannot reach.
|
| 10 |
+
|
| 11 |
+
Pipeline:
|
| 12 |
+
1. Start from existing PXDesign outputs (seed structures)
|
| 13 |
+
2. Align binder to reference receptor frames
|
| 14 |
+
3. Run Langevin refinement with Q_theta gradient
|
| 15 |
+
4. Score the refined output
|
| 16 |
+
5. Repeat for K iterations, keeping best designs
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
python code/scripts/pxdesign_guidance/iterative_refinement.py \
|
| 20 |
+
--input_dir results/pxdesign_guided/converted_pdbs \
|
| 21 |
+
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
|
| 22 |
+
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
|
| 23 |
+
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
|
| 24 |
+
--n_iterations 3 --n_designs 10 \
|
| 25 |
+
--gpu 6
|
| 26 |
+
"""
|
| 27 |
+
import os
|
| 28 |
+
import sys
|
| 29 |
+
import json
|
| 30 |
+
import logging
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from glob import glob
|
| 34 |
+
|
| 35 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 39 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
|
| 40 |
+
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
|
| 41 |
+
|
| 42 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 43 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def score_designs(pdb_paths, guidance):
|
| 47 |
+
"""Score a list of PDB paths with Q_theta."""
|
| 48 |
+
results = []
|
| 49 |
+
for pdb_path in pdb_paths:
|
| 50 |
+
result = guidance.score_design(pdb_path)
|
| 51 |
+
if result is not None:
|
| 52 |
+
result['pdb_path'] = pdb_path
|
| 53 |
+
result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
|
| 54 |
+
results.append(result)
|
| 55 |
+
return results
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_langevin_cycle(pdb_paths, guidance, n_steps=50, step_size=0.005,
|
| 59 |
+
iteration=0, outdir='results/iterative_refinement'):
|
| 60 |
+
"""Run Langevin refinement cycle on binder backbone coords using Q_theta.
|
| 61 |
+
|
| 62 |
+
Uses guidance.dq (DifferentiableQTheta) for differentiable scoring.
|
| 63 |
+
Aligns binder to holo/apo reference frames for dual-state scoring.
|
| 64 |
+
"""
|
| 65 |
+
from utils.pdb_utils import (load_structure, get_residues, get_backbone_coords,
|
| 66 |
+
get_aa_indices, align_structures)
|
| 67 |
+
|
| 68 |
+
refined_results = []
|
| 69 |
+
os.makedirs(outdir, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
for pdb_path in pdb_paths:
|
| 72 |
+
try:
|
| 73 |
+
model = load_structure(pdb_path)
|
| 74 |
+
chains = {c.id: c for c in model.get_chains()}
|
| 75 |
+
|
| 76 |
+
binder_chain = None
|
| 77 |
+
for cid in sorted(chains.keys()):
|
| 78 |
+
if cid != 'A':
|
| 79 |
+
binder_chain = cid
|
| 80 |
+
break
|
| 81 |
+
if binder_chain is None:
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
rec_res = get_residues(chains['A'])
|
| 85 |
+
if not rec_res:
|
| 86 |
+
rec_res = get_residues(chains['A'], only_standard=False)
|
| 87 |
+
binder_res = get_residues(chains[binder_chain])
|
| 88 |
+
if not binder_res:
|
| 89 |
+
binder_res = get_residues(chains[binder_chain], only_standard=False)
|
| 90 |
+
if len(binder_res) < 5:
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
binder_coords, binder_mask = get_backbone_coords(binder_res)
|
| 94 |
+
rec_coords, _ = get_backbone_coords(rec_res)
|
| 95 |
+
|
| 96 |
+
try:
|
| 97 |
+
aa_idx = get_aa_indices(binder_res)
|
| 98 |
+
except Exception:
|
| 99 |
+
aa_idx = np.zeros(len(binder_res), dtype=np.int64)
|
| 100 |
+
|
| 101 |
+
# Compute alignment transforms
|
| 102 |
+
rec_ca = rec_coords[:, 1, :]
|
| 103 |
+
ref_holo_ca = guidance.ref_holo_ca.cpu().numpy()
|
| 104 |
+
ref_apo_ca = guidance.ref_apo_ca.cpu().numpy()
|
| 105 |
+
n_h = min(len(rec_ca), len(ref_holo_ca))
|
| 106 |
+
n_a = min(len(rec_ca), len(ref_apo_ca))
|
| 107 |
+
if n_h < 5 or n_a < 5:
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
_, R_h = align_structures(rec_ca[:n_h], ref_holo_ca[:n_h])
|
| 111 |
+
center_h = rec_ca[:n_h].mean(0)
|
| 112 |
+
ref_center_h = ref_holo_ca[:n_h].mean(0)
|
| 113 |
+
aligned_holo = (binder_coords.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
|
| 114 |
+
aligned_holo = aligned_holo.reshape(-1, 4, 3)
|
| 115 |
+
|
| 116 |
+
_, R_a = align_structures(rec_ca[:n_a], ref_apo_ca[:n_a])
|
| 117 |
+
center_a = rec_ca[:n_a].mean(0)
|
| 118 |
+
ref_center_a = ref_apo_ca[:n_a].mean(0)
|
| 119 |
+
|
| 120 |
+
device = guidance.device
|
| 121 |
+
dq = guidance.dq
|
| 122 |
+
|
| 123 |
+
# Precompute alignment tensors (detached constants)
|
| 124 |
+
R_h_t = torch.from_numpy(R_h).float().to(device)
|
| 125 |
+
R_a_t = torch.from_numpy(R_a).float().to(device)
|
| 126 |
+
center_h_t = torch.from_numpy(center_h).float().to(device)
|
| 127 |
+
ref_center_h_t = torch.from_numpy(ref_center_h).float().to(device)
|
| 128 |
+
center_a_t = torch.from_numpy(center_a).float().to(device)
|
| 129 |
+
ref_center_a_t = torch.from_numpy(ref_center_a).float().to(device)
|
| 130 |
+
|
| 131 |
+
# Work in holo-aligned frame
|
| 132 |
+
coords_t = torch.from_numpy(aligned_holo.copy()).float().to(device)
|
| 133 |
+
mask_t = torch.from_numpy(binder_mask).bool().to(device)
|
| 134 |
+
aa_t = torch.from_numpy(aa_idx).long().to(device)
|
| 135 |
+
|
| 136 |
+
# Add noise for diversity (constant, small)
|
| 137 |
+
noise = torch.randn_like(coords_t) * 0.05
|
| 138 |
+
coords_t = coords_t + noise
|
| 139 |
+
|
| 140 |
+
best_margin = -float('inf')
|
| 141 |
+
best_coords = coords_t.clone()
|
| 142 |
+
|
| 143 |
+
def project_bond_lengths(coords, target_dist=3.8, n_iters=5):
|
| 144 |
+
"""Project CA-CA distances to target_dist via SHAKE-like iteration."""
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
for _ in range(n_iters):
|
| 147 |
+
ca = coords[:, 1, :].clone()
|
| 148 |
+
for i in range(len(ca) - 1):
|
| 149 |
+
delta = ca[i+1] - ca[i]
|
| 150 |
+
d = delta.norm()
|
| 151 |
+
if d < 1e-6:
|
| 152 |
+
continue
|
| 153 |
+
correction = 0.5 * (d - target_dist) / d * delta
|
| 154 |
+
coords[i, :, :] += correction.unsqueeze(0)
|
| 155 |
+
coords[i+1, :, :] -= correction.unsqueeze(0)
|
| 156 |
+
return coords
|
| 157 |
+
|
| 158 |
+
for step in range(n_steps):
|
| 159 |
+
coords_t = coords_t.detach().requires_grad_(True)
|
| 160 |
+
|
| 161 |
+
with torch.enable_grad():
|
| 162 |
+
q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
|
| 163 |
+
receptor_label='holo')
|
| 164 |
+
|
| 165 |
+
# Transform holo-frame → original → apo-frame
|
| 166 |
+
flat_t = coords_t.reshape(-1, 3)
|
| 167 |
+
original = (flat_t - ref_center_h_t) @ R_h_t + center_h_t
|
| 168 |
+
apo_aligned = (original - center_a_t) @ R_a_t.T + ref_center_a_t
|
| 169 |
+
coords_apo = apo_aligned.reshape(-1, 4, 3)
|
| 170 |
+
|
| 171 |
+
q_apo = dq.score(coords_apo, mask_t, binder_aa_idx=aa_t,
|
| 172 |
+
receptor_label='apo')
|
| 173 |
+
margin = q_holo - q_apo
|
| 174 |
+
margin.backward()
|
| 175 |
+
|
| 176 |
+
grad = coords_t.grad
|
| 177 |
+
if grad is None or torch.isnan(grad).any():
|
| 178 |
+
continue
|
| 179 |
+
|
| 180 |
+
grad_norm = grad.norm().clamp(min=1e-8)
|
| 181 |
+
|
| 182 |
+
if margin.item() > best_margin:
|
| 183 |
+
best_margin = margin.item()
|
| 184 |
+
best_coords = coords_t.detach().clone()
|
| 185 |
+
|
| 186 |
+
if step % 10 == 0:
|
| 187 |
+
logger.info(f" [{os.path.basename(pdb_path)}] Step {step}: "
|
| 188 |
+
f"Q+={q_holo.item():.3f} Q-={q_apo.item():.3f} "
|
| 189 |
+
f"S={margin.item():.3f} |g|={grad_norm.item():.4f}")
|
| 190 |
+
|
| 191 |
+
with torch.no_grad():
|
| 192 |
+
coords_t = coords_t + step_size * grad / grad_norm
|
| 193 |
+
# Annealed Langevin noise (small)
|
| 194 |
+
noise_scale = step_size * 0.05 * (1 - step / n_steps)
|
| 195 |
+
coords_t = coords_t + noise_scale * torch.randn_like(coords_t)
|
| 196 |
+
# Hard projection: enforce CA-CA = 3.8A
|
| 197 |
+
coords_t = project_bond_lengths(coords_t)
|
| 198 |
+
|
| 199 |
+
# Write refined backbone PDB
|
| 200 |
+
final_coords = best_coords.detach().cpu().numpy()
|
| 201 |
+
basename = os.path.basename(pdb_path).replace('.pdb', '')
|
| 202 |
+
out_path = os.path.join(outdir, f'{basename}_iter{iteration}.pdb')
|
| 203 |
+
|
| 204 |
+
atom_names = [' N ', ' CA ', ' C ', ' O ']
|
| 205 |
+
elements = ['N', 'C', 'C', 'O']
|
| 206 |
+
with open(out_path, 'w') as f:
|
| 207 |
+
atom_num = 1
|
| 208 |
+
for i in range(len(final_coords)):
|
| 209 |
+
if not binder_mask[i]:
|
| 210 |
+
continue
|
| 211 |
+
for j, (aname, elem) in enumerate(zip(atom_names, elements)):
|
| 212 |
+
x, y, z = final_coords[i, j]
|
| 213 |
+
f.write(f"ATOM {atom_num:5d} {aname} ALA B{i+1:4d} "
|
| 214 |
+
f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 {elem}\n")
|
| 215 |
+
atom_num += 1
|
| 216 |
+
f.write("END\n")
|
| 217 |
+
|
| 218 |
+
# Score refined design
|
| 219 |
+
result = guidance.score_design(out_path)
|
| 220 |
+
if result is not None:
|
| 221 |
+
result['pdb_path'] = out_path
|
| 222 |
+
result['iteration'] = iteration
|
| 223 |
+
result['best_margin_during_opt'] = best_margin
|
| 224 |
+
refined_results.append(result)
|
| 225 |
+
logger.info(f" -> Refined: S={result['margin']:.3f} "
|
| 226 |
+
f"(best during opt: {best_margin:.3f})")
|
| 227 |
+
|
| 228 |
+
except Exception as e:
|
| 229 |
+
logger.warning(f"Failed to refine {pdb_path}: {e}")
|
| 230 |
+
import traceback
|
| 231 |
+
traceback.print_exc()
|
| 232 |
+
|
| 233 |
+
return refined_results
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def main():
|
| 237 |
+
import argparse
|
| 238 |
+
parser = argparse.ArgumentParser()
|
| 239 |
+
parser.add_argument('--input_dir',
|
| 240 |
+
default='results/pxdesign_guided/converted_pdbs')
|
| 241 |
+
parser.add_argument('--qtheta_checkpoint',
|
| 242 |
+
default='results/checkpoints_cam_v3/best_phase2.pt')
|
| 243 |
+
parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
|
| 244 |
+
parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
|
| 245 |
+
parser.add_argument('--ref_chain', default='A')
|
| 246 |
+
parser.add_argument('--n_iterations', type=int, default=4,
|
| 247 |
+
help='Number of refine cycles')
|
| 248 |
+
parser.add_argument('--n_designs', type=int, default=20,
|
| 249 |
+
help='Number of designs to refine')
|
| 250 |
+
parser.add_argument('--n_steps', type=int, default=50,
|
| 251 |
+
help='Langevin steps per iteration')
|
| 252 |
+
parser.add_argument('--step_size', type=float, default=0.005)
|
| 253 |
+
parser.add_argument('--gpu', type=int, default=0)
|
| 254 |
+
parser.add_argument('--outdir', default='results/iterative_refinement')
|
| 255 |
+
args = parser.parse_args()
|
| 256 |
+
|
| 257 |
+
os.chdir(_ALLO_ROOT)
|
| 258 |
+
|
| 259 |
+
from scripts.pxdesign_guidance.qtheta_pxdesign import QThetaPXDesignGuidance
|
| 260 |
+
|
| 261 |
+
outdir = args.outdir
|
| 262 |
+
os.makedirs(outdir, exist_ok=True)
|
| 263 |
+
|
| 264 |
+
# Initialize scorer
|
| 265 |
+
guidance = QThetaPXDesignGuidance(
|
| 266 |
+
checkpoint=args.qtheta_checkpoint,
|
| 267 |
+
ref_holo=args.ref_holo,
|
| 268 |
+
ref_apo=args.ref_apo,
|
| 269 |
+
ref_chain=args.ref_chain,
|
| 270 |
+
device=f'cuda:{args.gpu}',
|
| 271 |
+
)
|
| 272 |
+
guidance._lazy_init()
|
| 273 |
+
|
| 274 |
+
# Collect input designs
|
| 275 |
+
input_pdbs = sorted(glob(os.path.join(args.input_dir, '*.pdb')))[:args.n_designs]
|
| 276 |
+
logger.info(f"Selected {len(input_pdbs)} designs for iterative refinement")
|
| 277 |
+
|
| 278 |
+
# Score initial designs
|
| 279 |
+
logger.info("Scoring initial designs...")
|
| 280 |
+
initial_results = score_designs(input_pdbs, guidance)
|
| 281 |
+
initial_margins = [r['margin'] for r in initial_results]
|
| 282 |
+
logger.info(f"Initial: S={np.mean(initial_margins):.3f}\u00b1{np.std(initial_margins):.3f}")
|
| 283 |
+
|
| 284 |
+
all_iteration_results = {'initial': initial_results}
|
| 285 |
+
|
| 286 |
+
# Iterative refinement
|
| 287 |
+
current_pdbs = input_pdbs
|
| 288 |
+
for iteration in range(args.n_iterations):
|
| 289 |
+
logger.info(f"\n{'='*50}")
|
| 290 |
+
logger.info(f"Iteration {iteration + 1}/{args.n_iterations}")
|
| 291 |
+
logger.info(f"{'='*50}")
|
| 292 |
+
|
| 293 |
+
iter_results = run_langevin_cycle(
|
| 294 |
+
current_pdbs, guidance,
|
| 295 |
+
n_steps=args.n_steps,
|
| 296 |
+
step_size=args.step_size,
|
| 297 |
+
iteration=iteration,
|
| 298 |
+
outdir=outdir,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
if iter_results:
|
| 302 |
+
margins = [r['margin'] for r in iter_results]
|
| 303 |
+
logger.info(f"Iteration {iteration + 1}: S={np.mean(margins):.3f}\u00b1{np.std(margins):.3f}")
|
| 304 |
+
all_iteration_results[f'iteration_{iteration}'] = iter_results
|
| 305 |
+
|
| 306 |
+
# Use refined designs as input for next iteration
|
| 307 |
+
current_pdbs = [r['pdb_path'] for r in iter_results]
|
| 308 |
+
|
| 309 |
+
# Summary
|
| 310 |
+
logger.info(f"\n{'='*60}")
|
| 311 |
+
logger.info("Iterative Refinement Summary")
|
| 312 |
+
logger.info(f"{'='*60}")
|
| 313 |
+
for key, results in all_iteration_results.items():
|
| 314 |
+
if results:
|
| 315 |
+
margins = [r['margin'] for r in results]
|
| 316 |
+
logger.info(f"{key:15s}: S={np.mean(margins):.3f}\u00b1{np.std(margins):.3f}, "
|
| 317 |
+
f"N={len(results)}, S>0={100*np.mean([m>0 for m in margins]):.0f}%")
|
| 318 |
+
|
| 319 |
+
# Save results
|
| 320 |
+
out_path = os.path.join(outdir, 'iterative_refinement_summary.json')
|
| 321 |
+
summary = {}
|
| 322 |
+
for key, results in all_iteration_results.items():
|
| 323 |
+
if results:
|
| 324 |
+
margins = [r['margin'] for r in results]
|
| 325 |
+
summary[key] = {
|
| 326 |
+
'n': len(results),
|
| 327 |
+
'margin_mean': float(np.mean(margins)),
|
| 328 |
+
'margin_std': float(np.std(margins)),
|
| 329 |
+
'margin_max': float(np.max(margins)),
|
| 330 |
+
'frac_positive': float(np.mean([m > 0 for m in margins])),
|
| 331 |
+
}
|
| 332 |
+
with open(out_path, 'w') as f:
|
| 333 |
+
json.dump(summary, f, indent=2)
|
| 334 |
+
logger.info(f"\nSaved to {out_path}")
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
if __name__ == '__main__':
|
| 338 |
+
main()
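The refinement cycle in the file above takes a normalized gradient step, adds annealed noise, and then calls `project_bond_lengths`, which is defined outside this excerpt. As a rough illustration only, and assuming the projection simply walks along the chain and rescales each consecutive CA-CA vector to 3.8 Å, a NumPy sketch might look like the following. The name `project_ca_bonds` and its signature are hypothetical; the repository's own function may work differently (for example on full `[N, 4, 3]` backbones).

```python
import numpy as np

def project_ca_bonds(ca, target=3.8):
    """Hypothetical sketch: rescale consecutive CA-CA distances to `target` (in Å).

    ca: [N, 3] array of CA coordinates. Returns a new [N, 3] array.
    Residue 0 stays fixed; each later residue is placed along the
    original bond direction at exactly `target` from its predecessor.
    """
    ca = np.asarray(ca, dtype=float)
    out = ca.copy()
    for i in range(1, len(ca)):
        bond = ca[i] - ca[i - 1]          # original bond direction
        norm = np.linalg.norm(bond)
        if norm < 1e-8:                   # degenerate bond: fall back to the x axis
            bond, norm = np.array([1.0, 0.0, 0.0]), 1.0
        out[i] = out[i - 1] + target * bond / norm
    return out

# Toy "backbone": a random walk with uneven step lengths.
rng = np.random.default_rng(0)
chain = np.cumsum(rng.normal(size=(10, 3)) * 2.0, axis=0)
fixed = project_ca_bonds(chain)
bond_lengths = np.linalg.norm(np.diff(fixed, axis=0), axis=1)
```

A sequential projection like this preserves bond directions but lets positional changes accumulate along the chain, which is one reason the step size and noise in the loop are kept small.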
|
code/scripts/pxdesign_guidance/langevin_pxdesign.py
ADDED
|
@@ -0,0 +1,374 @@
| 1 |
+
"""
|
| 2 |
+
PXDesign + Langevin Refinement.
|
| 3 |
+
|
| 4 |
+
Post-hoc gradient ascent on existing PXDesign binder backbones using Q_theta
|
| 5 |
+
selectivity gradient:
|
| 6 |
+
x_{t+1} = x_t + η · ∇_x[Q(holo,Y) - Q(apo,Y)] + σ · √(2η) · ε, ε ~ N(0, I), σ = --noise_scale (default 0 = plain gradient ascent)
|
| 7 |
+
|
| 8 |
+
Takes PXDesign outputs (which have full sidechains), extracts backbone coords,
|
| 9 |
+
refines them via Langevin dynamics, and outputs refined backbone-only PDBs.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python code/scripts/pxdesign_guidance/langevin_pxdesign.py \
|
| 13 |
+
--designs_dir experiments/pxdesign_cam/output/ \
|
| 14 |
+
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
|
| 15 |
+
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
|
| 16 |
+
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
|
| 17 |
+
--n_steps 100 --step_size 0.01 \
|
| 18 |
+
--gpu 0
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
from glob import glob
|
| 29 |
+
|
| 30 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 34 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
|
| 35 |
+
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
|
| 36 |
+
|
| 37 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 38 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 39 |
+
|
| 40 |
+
from utils.pdb_utils import (
|
| 41 |
+
load_structure, get_residues, get_backbone_coords,
|
| 42 |
+
get_aa_indices, align_structures
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def write_backbone_pdb(coords, mask, out_path, chain='B'):
|
| 47 |
+
"""Write backbone PDB (N, CA, C, O) from [N, 4, 3] numpy coords."""
|
| 48 |
+
atom_names = [' N ', ' CA ', ' C ', ' O ']
|
| 49 |
+
elements = ['N', 'C', 'C', 'O']
|
| 50 |
+
with open(out_path, 'w') as f:
|
| 51 |
+
atom_idx = 1
|
| 52 |
+
for i in range(len(coords)):
|
| 53 |
+
if not mask[i]:
|
| 54 |
+
continue
|
| 55 |
+
for j, (aname, elem) in enumerate(zip(atom_names, elements)):
|
| 56 |
+
x, y, z = coords[i, j, :]
|
| 57 |
+
f.write(
|
| 58 |
+
f"ATOM {atom_idx:5d} {aname:4s} ALA {chain}{i+1:4d} "
|
| 59 |
+
f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00 0.00 {elem}\n"
|
| 60 |
+
)
|
| 61 |
+
atom_idx += 1
|
| 62 |
+
f.write("END\n")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def find_pxdesign_pdbs(designs_dir):
|
| 66 |
+
"""Find all PXDesign output PDB files."""
|
| 67 |
+
pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
|
| 68 |
+
pdbs = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
|
| 69 |
+
or 'design' in os.path.basename(p).lower()
|
| 70 |
+
or 'rank' in os.path.basename(p).lower()]
|
| 71 |
+
if not pdbs:
|
| 72 |
+
pdbs = sorted(glob(os.path.join(designs_dir, '**/*.pdb'), recursive=True))
|
| 73 |
+
return pdbs
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def langevin_refine(dq, binder_coords_init, binder_mask, binder_aa_idx,
|
| 77 |
+
rec_coords, rec_mask, ref_holo_ca, ref_apo_ca,
|
| 78 |
+
n_steps=100, step_size=0.01, noise_scale=0.0,
|
| 79 |
+
device='cuda:0'):
|
| 80 |
+
"""
|
| 81 |
+
Langevin refinement of binder backbone coordinates.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
dq: DifferentiableQTheta scorer
|
| 85 |
+
binder_coords_init: [N_binder, 4, 3] numpy — initial binder backbone
|
| 86 |
+
binder_mask: [N_binder] numpy bool
|
| 87 |
+
binder_aa_idx: [N_binder] numpy int
|
| 88 |
+
rec_coords: [N_rec, 4, 3] numpy — receptor backbone
|
| 89 |
+
rec_mask: [N_rec] numpy bool
|
| 90 |
+
ref_holo_ca: [N_ref, 3] torch — holo reference CA
|
| 91 |
+
ref_apo_ca: [N_ref, 3] torch — apo reference CA
|
| 92 |
+
n_steps: int
|
| 93 |
+
step_size: float (η)
|
| 94 |
+
noise_scale: float (for stochastic Langevin, 0 = gradient ascent)
|
| 95 |
+
device: str
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
best_coords: [N_binder, 4, 3] numpy — refined coords
|
| 99 |
+
trajectory: list of dicts with step info
|
| 100 |
+
"""
|
| 101 |
+
device = torch.device(device)
|
| 102 |
+
|
| 103 |
+
# Convert to tensors
|
| 104 |
+
x = torch.from_numpy(binder_coords_init.copy()).float().to(device)
|
| 105 |
+
mask_t = torch.from_numpy(binder_mask).bool().to(device)
|
| 106 |
+
aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
|
| 107 |
+
rec_ca = torch.from_numpy(rec_coords[:, 1, :]).float().to(device)
|
| 108 |
+
|
| 109 |
+
best_margin = -float('inf')
|
| 110 |
+
best_coords = binder_coords_init.copy()
|
| 111 |
+
best_q_holo = 0.0
|
| 112 |
+
best_q_apo = 0.0
|
| 113 |
+
trajectory = []
|
| 114 |
+
|
| 115 |
+
for step in range(n_steps):
|
| 116 |
+
x_grad = x.clone().requires_grad_(True)
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
with torch.enable_grad():
|
| 120 |
+
# Align to holo reference
|
| 121 |
+
n_align_h = min(len(rec_ca), len(ref_holo_ca))
|
| 122 |
+
if n_align_h < 5:
|
| 123 |
+
break
|
| 124 |
+
from qtheta_pxdesign import differentiable_kabsch
|
| 125 |
+
R_h, t_h = differentiable_kabsch(rec_ca[:n_align_h].detach(),
|
| 126 |
+
ref_holo_ca[:n_align_h].detach())
|
| 127 |
+
R_h, t_h = R_h.detach(), t_h.detach()
|
| 128 |
+
aligned_holo = x_grad.reshape(-1, 3) @ R_h.T + t_h
|
| 129 |
+
aligned_holo = aligned_holo.reshape(-1, 4, 3)
|
| 130 |
+
|
| 131 |
+
q_holo = dq.score(aligned_holo, mask_t, binder_aa_idx=aa_t,
|
| 132 |
+
receptor_label='holo')
|
| 133 |
+
|
| 134 |
+
# Align to apo reference
|
| 135 |
+
n_align_a = min(len(rec_ca), len(ref_apo_ca))
|
| 136 |
+
R_a, t_a = differentiable_kabsch(rec_ca[:n_align_a].detach(),
|
| 137 |
+
ref_apo_ca[:n_align_a].detach())
|
| 138 |
+
R_a, t_a = R_a.detach(), t_a.detach()
|
| 139 |
+
aligned_apo = x_grad.reshape(-1, 3) @ R_a.T + t_a
|
| 140 |
+
aligned_apo = aligned_apo.reshape(-1, 4, 3)
|
| 141 |
+
|
| 142 |
+
q_apo = dq.score(aligned_apo, mask_t, binder_aa_idx=aa_t,
|
| 143 |
+
receptor_label='apo')
|
| 144 |
+
|
| 145 |
+
margin = q_holo - q_apo
|
| 146 |
+
margin.backward()
|
| 147 |
+
|
| 148 |
+
grad = x_grad.grad
|
| 149 |
+
if grad is None or torch.isnan(grad).any():
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
# Gradient ascent step
|
| 153 |
+
x = x + step_size * grad
|
| 154 |
+
|
| 155 |
+
# Optional noise for stochastic Langevin
|
| 156 |
+
if noise_scale > 0:
|
| 157 |
+
x = x + noise_scale * np.sqrt(2 * step_size) * torch.randn_like(x)
|
| 158 |
+
|
| 159 |
+
current_margin = margin.item()
|
| 160 |
+
step_info = {
|
| 161 |
+
'step': step,
|
| 162 |
+
'q_holo': q_holo.item(),
|
| 163 |
+
'q_apo': q_apo.item(),
|
| 164 |
+
'margin': current_margin,
|
| 165 |
+
'grad_norm': grad.norm().item(),
|
| 166 |
+
}
|
| 167 |
+
trajectory.append(step_info)
|
| 168 |
+
|
| 169 |
+
if current_margin > best_margin:
|
| 170 |
+
best_margin = current_margin
|
| 171 |
+
best_coords = x_grad.detach().cpu().numpy()  # pre-update coords that produced this margin
|
| 172 |
+
best_q_holo = q_holo.item()
|
| 173 |
+
best_q_apo = q_apo.item()
|
| 174 |
+
|
| 175 |
+
if step % 20 == 0:
|
| 176 |
+
logger.info(
|
| 177 |
+
f" Step {step:3d}: Q+={q_holo.item():.3f} Q-={q_apo.item():.3f} "
|
| 178 |
+
f"S={current_margin:+.3f} |∇|={grad.norm().item():.4f}")
|
| 179 |
+
|
| 180 |
+
except Exception as e:
|
| 181 |
+
logger.debug(f" Step {step}: {e}")
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
return best_coords, trajectory, best_margin, best_q_holo, best_q_apo
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def main():
|
| 188 |
+
parser = argparse.ArgumentParser(description='PXDesign + Langevin Refinement')
|
| 189 |
+
parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/')
|
| 190 |
+
parser.add_argument('--qtheta_checkpoint',
|
| 191 |
+
default='results/checkpoints_cam_v3/best_phase2.pt')
|
| 192 |
+
parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
|
| 193 |
+
parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
|
| 194 |
+
parser.add_argument('--ref_chain', default='A')
|
| 195 |
+
parser.add_argument('--n_steps', type=int, default=100)
|
| 196 |
+
parser.add_argument('--step_size', type=float, default=0.01)
|
| 197 |
+
parser.add_argument('--noise_scale', type=float, default=0.0,
|
| 198 |
+
help='Noise scale for stochastic Langevin (0=gradient ascent)')
|
| 199 |
+
parser.add_argument('--gpu', type=int, default=0)
|
| 200 |
+
parser.add_argument('--outdir', default='results/pxdesign_langevin')
|
| 201 |
+
args = parser.parse_args()
|
| 202 |
+
|
| 203 |
+
os.chdir(_ALLO_ROOT)
|
| 204 |
+
|
| 205 |
+
device = f'cuda:{args.gpu}'
|
| 206 |
+
|
| 207 |
+
from models.differentiable_features import DifferentiableQTheta
|
| 208 |
+
|
| 209 |
+
# Load scorer
|
| 210 |
+
dq = DifferentiableQTheta(args.qtheta_checkpoint, device=device)
|
| 211 |
+
dq.load_receptor(args.ref_holo, chain=args.ref_chain, label='holo')
|
| 212 |
+
dq.load_receptor(args.ref_apo, chain=args.ref_chain, label='apo')
|
| 213 |
+
|
| 214 |
+
# Load reference CA coords
|
| 215 |
+
holo_model = load_structure(args.ref_holo)
|
| 216 |
+
holo_res = get_residues(holo_model[args.ref_chain])
|
| 217 |
+
holo_coords, _ = get_backbone_coords(holo_res)
|
| 218 |
+
ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(device)
|
| 219 |
+
|
| 220 |
+
apo_model = load_structure(args.ref_apo)
|
| 221 |
+
apo_res = get_residues(apo_model[args.ref_chain])
|
| 222 |
+
apo_coords, _ = get_backbone_coords(apo_res)
|
| 223 |
+
ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(device)
|
| 224 |
+
|
| 225 |
+
# Find designs
|
| 226 |
+
pdbs = find_pxdesign_pdbs(args.designs_dir)
|
| 227 |
+
logger.info(f"Found {len(pdbs)} PXDesign outputs to refine")
|
| 228 |
+
|
| 229 |
+
outdir = args.outdir
|
| 230 |
+
os.makedirs(outdir, exist_ok=True)
|
| 231 |
+
|
| 232 |
+
all_results = []
|
| 233 |
+
for i, pdb_path in enumerate(pdbs):
|
| 234 |
+
design_id = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
|
| 235 |
+
logger.info(f"\n[{i+1}/{len(pdbs)}] Refining {design_id}...")
|
| 236 |
+
|
| 237 |
+
try:
|
| 238 |
+
model = load_structure(pdb_path)
|
| 239 |
+
chains = {c.get_id(): c for c in model.get_chains()}
|
| 240 |
+
chain_ids = sorted(chains.keys())
|
| 241 |
+
|
| 242 |
+
# Identify chains
|
| 243 |
+
ref_len = len(holo_res)
|
| 244 |
+
rec_chain_id, binder_chain_id = None, None
|
| 245 |
+
for cid in chain_ids:
|
| 246 |
+
cres = get_residues(chains[cid])
|
| 247 |
+
if abs(len(cres) - ref_len) < ref_len * 0.3:
|
| 248 |
+
rec_chain_id = cid
|
| 249 |
+
else:
|
| 250 |
+
binder_chain_id = cid
|
| 251 |
+
|
| 252 |
+
if rec_chain_id is None or binder_chain_id is None:
|
| 253 |
+
if len(chain_ids) >= 2:
|
| 254 |
+
rec_chain_id, binder_chain_id = chain_ids[0], chain_ids[1]
|
| 255 |
+
else:
|
| 256 |
+
logger.warning(f"Skipping {design_id}: cannot identify chains")
|
| 257 |
+
continue
|
| 258 |
+
|
| 259 |
+
rec_res = get_residues(chains[rec_chain_id])
|
| 260 |
+
binder_res = get_residues(chains[binder_chain_id])
|
| 261 |
+
|
| 262 |
+
rec_coords_np, rec_mask = get_backbone_coords(rec_res)
|
| 263 |
+
binder_coords_np, binder_mask = get_backbone_coords(binder_res)
|
| 264 |
+
aa_idx = get_aa_indices(binder_res)
|
| 265 |
+
|
| 266 |
+
# Score before refinement
|
| 267 |
+
rec_ca = rec_coords_np[:, 1, :]
|
| 268 |
+
n_align = min(len(rec_ca), len(holo_coords[:, 1, :]))
|
| 269 |
+
_, R_h = align_structures(rec_ca[:n_align], holo_coords[:n_align, 1, :])
|
| 270 |
+
center_h = rec_ca[:n_align].mean(0)
|
| 271 |
+
ref_center_h = holo_coords[:n_align, 1, :].mean(0)
|
| 272 |
+
|
| 273 |
+
aligned_init = (binder_coords_np.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
|
| 274 |
+
aligned_init = aligned_init.reshape(-1, 4, 3)
|
| 275 |
+
with torch.no_grad():
|
| 276 |
+
q_h_init = dq.score(
|
| 277 |
+
torch.from_numpy(aligned_init).float().to(device),
|
| 278 |
+
torch.from_numpy(binder_mask).bool().to(device),
|
| 279 |
+
binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
|
| 280 |
+
receptor_label='holo').item()
|
| 281 |
+
|
| 282 |
+
n_align_a = min(len(rec_ca), len(apo_coords[:, 1, :]))
|
| 283 |
+
_, R_a = align_structures(rec_ca[:n_align_a], apo_coords[:n_align_a, 1, :])
|
| 284 |
+
center_a = rec_ca[:n_align_a].mean(0)
|
| 285 |
+
ref_center_a = apo_coords[:n_align_a, 1, :].mean(0)
|
| 286 |
+
aligned_init_a = (binder_coords_np.reshape(-1, 3) - center_a) @ R_a.T + ref_center_a
|
| 287 |
+
aligned_init_a = aligned_init_a.reshape(-1, 4, 3)
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
q_a_init = dq.score(
|
| 290 |
+
torch.from_numpy(aligned_init_a).float().to(device),
|
| 291 |
+
torch.from_numpy(binder_mask).bool().to(device),
|
| 292 |
+
binder_aa_idx=torch.from_numpy(aa_idx).long().to(device),
|
| 293 |
+
receptor_label='apo').item()
|
| 294 |
+
|
| 295 |
+
margin_init = q_h_init - q_a_init
|
| 296 |
+
|
| 297 |
+
# Run Langevin refinement
|
| 298 |
+
refined_coords, trajectory, best_margin, best_qh, best_qa = langevin_refine(
|
| 299 |
+
dq, binder_coords_np, binder_mask, aa_idx,
|
| 300 |
+
rec_coords_np, rec_mask, ref_holo_ca, ref_apo_ca,
|
| 301 |
+
n_steps=args.n_steps, step_size=args.step_size,
|
| 302 |
+
noise_scale=args.noise_scale, device=device,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Use best-margin values (matching the saved best_coords PDB)
|
| 306 |
+
margin_final = best_margin if trajectory else margin_init
|
| 307 |
+
|
| 308 |
+
# Save refined PDB
|
| 309 |
+
out_pdb = os.path.join(outdir, f'{design_id}_refined.pdb')
|
| 310 |
+
write_backbone_pdb(refined_coords, binder_mask, out_pdb)
|
| 311 |
+
|
| 312 |
+
result = {
|
| 313 |
+
'design_id': design_id,
|
| 314 |
+
'pdb_path': pdb_path,
|
| 315 |
+
'refined_pdb': out_pdb,
|
| 316 |
+
'q_holo_init': q_h_init,
|
| 317 |
+
'q_apo_init': q_a_init,
|
| 318 |
+
'margin_init': margin_init,
|
| 319 |
+
'q_holo_final': best_qh if trajectory else q_h_init,
|
| 320 |
+
'q_apo_final': best_qa if trajectory else q_a_init,
|
| 321 |
+
'margin_final': margin_final,
|
| 322 |
+
'margin_delta': margin_final - margin_init,
|
| 323 |
+
'n_steps_converged': len(trajectory),
|
| 324 |
+
'n_res': len(binder_res),
|
| 325 |
+
}
|
| 326 |
+
all_results.append(result)
|
| 327 |
+
|
| 328 |
+
logger.info(
|
| 329 |
+
f" {design_id}: S_init={margin_init:+.3f} -> S_final={margin_final:+.3f} "
|
| 330 |
+
f"(Δ={margin_final - margin_init:+.3f})")
|
| 331 |
+
|
| 332 |
+
except Exception as e:
|
| 333 |
+
logger.warning(f"Failed to refine {design_id}: {e}")
|
| 334 |
+
continue
|
| 335 |
+
|
| 336 |
+
# Summary
|
| 337 |
+
if all_results:
|
| 338 |
+
all_results.sort(key=lambda x: x['margin_final'], reverse=True)
|
| 339 |
+
margins_init = np.array([r['margin_init'] for r in all_results])
|
| 340 |
+
margins_final = np.array([r['margin_final'] for r in all_results])
|
| 341 |
+
deltas = margins_final - margins_init
|
| 342 |
+
|
| 343 |
+
summary = {
|
| 344 |
+
'method': 'PXDesign + Langevin',
|
| 345 |
+
'n_designs': len(all_results),
|
| 346 |
+
'n_steps': args.n_steps,
|
| 347 |
+
'step_size': args.step_size,
|
| 348 |
+
'margin_init_mean': float(margins_init.mean()),
|
| 349 |
+
'margin_final_mean': float(margins_final.mean()),
|
| 350 |
+
'margin_delta_mean': float(deltas.mean()),
|
| 351 |
+
'frac_improved': float((deltas > 0).mean()),
|
| 352 |
+
'frac_positive_init': float((margins_init > 0).mean()),
|
| 353 |
+
'frac_positive_final': float((margins_final > 0).mean()),
|
| 354 |
+
'q_holo_final_mean': float(np.mean([r['q_holo_final'] for r in all_results])),
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
with open(os.path.join(outdir, 'langevin_scores.json'), 'w') as f:
|
| 358 |
+
json.dump(all_results, f, indent=2)
|
| 359 |
+
with open(os.path.join(outdir, 'langevin_summary.json'), 'w') as f:
|
| 360 |
+
json.dump(summary, f, indent=2)
|
| 361 |
+
|
| 362 |
+
logger.info(f"\n{'='*60}")
|
| 363 |
+
logger.info(f"PXDesign + Langevin Results ({len(all_results)} designs)")
|
| 364 |
+
logger.info(f" Margin init: {margins_init.mean():.3f} ± {margins_init.std():.3f}")
|
| 365 |
+
logger.info(f" Margin final: {margins_final.mean():.3f} ± {margins_final.std():.3f}")
|
| 366 |
+
logger.info(f" Δ margin: {deltas.mean():+.3f} ± {deltas.std():.3f}")
|
| 367 |
+
logger.info(f" % improved: {(deltas > 0).mean():.1%}")
|
| 368 |
+
logger.info(f" S>0 init/final: {(margins_init > 0).mean():.1%} / "
|
| 369 |
+
f"{(margins_final > 0).mean():.1%}")
|
| 370 |
+
logger.info(f"{'='*60}")
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if __name__ == '__main__':
|
| 374 |
+
main()
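The update rule used by `langevin_refine` above can be illustrated on a toy problem. The sketch below is not the repository's code: it replaces the Q_theta margin with a simple 1-D quadratic, and the names `langevin_ascent`, `toy_score`, and `toy_grad` are made up for this example. It shows the step `x <- x + η·grad + σ·√(2η)·ε` and records the best point before the update, so the stored point and its score always belong together.

```python
import math
import random

def langevin_ascent(grad_fn, score_fn, x0, n_steps=100, step_size=0.01,
                    noise_scale=0.0, seed=0):
    """Toy 1-D Langevin ascent; returns (best_x, best_score).

    With noise_scale == 0 this reduces to plain gradient ascent.
    """
    rng = random.Random(seed)
    x = float(x0)
    best_x, best_score = x, -math.inf
    for _ in range(n_steps):
        score = score_fn(x)
        if score > best_score:
            # Record x *before* it is updated, so it matches `score`.
            best_x, best_score = x, score
        x = x + step_size * grad_fn(x)
        if noise_scale > 0:
            x += noise_scale * math.sqrt(2 * step_size) * rng.gauss(0.0, 1.0)
    return best_x, best_score

def toy_score(x):
    """Stand-in for the selectivity margin S; maximum at x = 2."""
    return -(x - 2.0) ** 2

def toy_grad(x):
    """Analytic gradient of toy_score."""
    return -2.0 * (x - 2.0)

best_x, best_score = langevin_ascent(toy_grad, toy_score, x0=0.0,
                                     n_steps=500, step_size=0.05)
```

Passing a positive `noise_scale` turns the same loop into a stochastic search; the best-so-far bookkeeping is then what keeps a noisy late step from degrading the returned result.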
|
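The scripts in this commit align the receptor onto the holo and apo references with a Kabsch fit before scoring. As a hedged, standalone illustration of that step (in NumPy rather than PyTorch, and with a hypothetical function name `kabsch`), the sketch below builds a known rigid transform and checks that the SVD-based fit recovers it. It assumes row-vector points and the convention `aligned = mobile @ R.T + t`.

```python
import numpy as np

def kabsch(mobile, target):
    """Return (R, t) such that mobile @ R.T + t best matches target (least squares)."""
    mobile = np.asarray(mobile, dtype=float)
    target = np.asarray(target, dtype=float)
    mc, tc = mobile.mean(axis=0), target.mean(axis=0)
    H = (mobile - mc).T @ (target - tc)        # 3x3 cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))     # -1 would indicate a reflection
    D = np.diag([1.0, 1.0, d])                 # flip the last axis if needed
    R = Vt.T @ D @ U.T
    t = tc - mc @ R.T
    return R, t

# Build a known rotation about z plus a translation, then recover it.
rng = np.random.default_rng(0)
pts = rng.normal(size=(20, 3))
angle = 0.7
R_true = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                   [np.sin(angle),  np.cos(angle), 0.0],
                   [0.0,            0.0,           1.0]])
t_true = np.array([1.0, -2.0, 0.5])
moved = pts @ R_true.T + t_true

R, t = kabsch(pts, moved)
aligned = pts @ R.T + t
```

The determinant correction matters for noisy or near-planar point sets, where the unconstrained SVD solution can be a reflection rather than a proper rotation.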
code/scripts/pxdesign_guidance/qtheta_pxdesign.py
ADDED
|
@@ -0,0 +1,477 @@
| 1 |
+
"""
|
| 2 |
+
Core Q_theta guidance module for PXDesign integration.
|
| 3 |
+
|
| 4 |
+
Provides differentiable Q_theta scoring for PXDesign's atom coordinate format.
|
| 5 |
+
Key responsibilities:
|
| 6 |
+
- Extract binder backbone (N, CA, C, O) from PXDesign's flat atom array
|
| 7 |
+
- Align binder to reference receptor frames via differentiable Kabsch
|
| 8 |
+
- Compute selectivity gradient ∇[Q(holo,Y) - Q(apo,Y)] w.r.t. atom coords
|
| 9 |
+
- Works in pxdesign env (PyTorch 2.3.1) using pure-PyTorch scorer (no e3nn)
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
guidance = QThetaPXDesignGuidance(
|
| 13 |
+
checkpoint='results/checkpoints_cam_v3/best_phase2.pt',
|
| 14 |
+
ref_holo='data/pdbs/cam_holo/3CLN.pdb',
|
| 15 |
+
ref_apo='data/pdbs/cam_apo/1CFD.pdb',
|
| 16 |
+
ref_chain='A',
|
| 17 |
+
device='cuda:0',
|
| 18 |
+
)
|
| 19 |
+
# Inside PXDesign diffusion loop:
|
| 20 |
+
grad = guidance.compute_guidance_gradient(x_denoised, input_feature_dict, t_hat)
|
| 21 |
+
x_denoised = x_denoised + scale * grad
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import sys
|
| 26 |
+
import logging
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Add Allo-Designer code directory to path
|
| 33 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
| 34 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 35 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def differentiable_kabsch(mobile, target):
|
| 39 |
+
"""
|
| 40 |
+
Differentiable Kabsch alignment using SVD.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
mobile: [N, 3] tensor (points to align FROM)
|
| 44 |
+
target: [N, 3] tensor (points to align TO)
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
R: [3, 3] rotation matrix
|
| 48 |
+
t: [3] translation vector
|
| 49 |
+
Such that aligned = mobile @ R.T + t, which equals (mobile - mobile_center) @ R.T + target_center
|
| 50 |
+
"""
|
| 51 |
+
mobile_center = mobile.mean(dim=0)
|
| 52 |
+
target_center = target.mean(dim=0)
|
| 53 |
+
|
| 54 |
+
mobile_centered = mobile - mobile_center
|
| 55 |
+
target_centered = target - target_center
|
| 56 |
+
|
| 57 |
+
H = mobile_centered.T @ target_centered # [3, 3]
|
| 58 |
+
U, S, Vh = torch.linalg.svd(H)
|
| 59 |
+
V = Vh.T
|
| 60 |
+
|
| 61 |
+
# Ensure proper rotation (det > 0)
|
| 62 |
+
d = torch.det(V @ U.T)
|
| 63 |
+
sign_matrix = torch.diag(torch.tensor([1.0, 1.0, torch.sign(d)],
|
| 64 |
+
device=mobile.device, dtype=mobile.dtype))
|
| 65 |
+
R = V @ sign_matrix @ U.T # [3, 3]
|
| 66 |
+
t = target_center - mobile_center @ R.T # [3]
|
| 67 |
+
|
| 68 |
+
return R, t
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class QThetaPXDesignGuidance:
|
| 72 |
+
"""
|
| 73 |
+
Q_theta guidance for PXDesign diffusion process.
|
| 74 |
+
|
| 75 |
+
Lazily initializes the scorer and reference structures on first use.
|
| 76 |
+
Handles extraction of binder backbone from PXDesign's flat atom array
|
| 77 |
+
and alignment to reference receptor frames.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, checkpoint, ref_holo, ref_apo, ref_chain='A',
|
| 81 |
+
device='cuda:0', cutoff=8.0, esm_target='cam'):
|
| 82 |
+
self.checkpoint = checkpoint
|
| 83 |
+
self.ref_holo = ref_holo
|
| 84 |
+
self.ref_apo = ref_apo
|
| 85 |
+
self.ref_chain = ref_chain
|
| 86 |
+
self.device = torch.device(device)
|
| 87 |
+
self.cutoff = cutoff
|
| 88 |
+
self.esm_target = esm_target
|
| 89 |
+
|
| 90 |
+
self._initialized = False
|
| 91 |
+
self.dq = None
|
| 92 |
+
self.ref_holo_ca = None
|
| 93 |
+
self.ref_apo_ca = None
|
| 94 |
+
|
| 95 |
+
def _lazy_init(self):
|
| 96 |
+
"""Initialize Q_theta scorer and load reference structures."""
|
| 97 |
+
if self._initialized:
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
from models.differentiable_features import DifferentiableQTheta
|
| 101 |
+
from utils.pdb_utils import load_structure, get_residues, get_backbone_coords
|
| 102 |
+
|
| 103 |
+
logger.info(f"Loading Q_theta checkpoint: {self.checkpoint}")
|
| 104 |
+
self.dq = DifferentiableQTheta(self.checkpoint, device=str(self.device))
|
| 105 |
+
self.dq.load_receptor(self.ref_holo, chain=self.ref_chain, label='holo',
|
| 106 |
+
esm_target=self.esm_target)
|
| 107 |
+
self.dq.load_receptor(self.ref_apo, chain=self.ref_chain, label='apo',
|
| 108 |
+
esm_target=self.esm_target)
|
| 109 |
+
|
| 110 |
+
# Cache reference CA coords for alignment
|
| 111 |
+
holo_model = load_structure(self.ref_holo)
|
| 112 |
+
holo_res = get_residues(holo_model[self.ref_chain])
|
| 113 |
+
holo_coords, _ = get_backbone_coords(holo_res)
|
| 114 |
+
self.ref_holo_ca = torch.from_numpy(holo_coords[:, 1, :]).float().to(self.device)
|
| 115 |
+
|
| 116 |
+
apo_model = load_structure(self.ref_apo)
|
| 117 |
+
apo_res = get_residues(apo_model[self.ref_chain])
|
| 118 |
+
apo_coords, _ = get_backbone_coords(apo_res)
|
| 119 |
+
self.ref_apo_ca = torch.from_numpy(apo_coords[:, 1, :]).float().to(self.device)
|
| 120 |
+
|
| 121 |
+
self._initialized = True
|
| 122 |
+
logger.info(f"Q_theta guidance initialized: holo={len(holo_res)} res, apo={len(apo_res)} res")
|
| 123 |
+
|
| 124 |
+
def extract_binder_backbone(self, x_coords, input_feature_dict):
|
| 125 |
+
"""
|
| 126 |
+
Extract binder backbone atoms (N, CA, C, O) from PXDesign's flat atom array.
|
| 127 |
+
|
| 128 |
+
PXDesign stores all atoms in a flat [N_atom, 3] array. Entity annotations
|
| 129 |
+
identify which atoms belong to the designed binder (entity_id=2 typically,
|
| 130 |
+
or the last entity). We extract backbone atoms for each binder residue.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
x_coords: [N_sample, N_atom, 3] — current coordinates from diffusion
|
| 134 |
+
input_feature_dict: dict with atom_to_token_idx, entity_id, etc.
|
| 135 |
+
|
| 136 |
+
Returns (dict, or None if no binder backbone atoms are found):
|
| 137 |
+
binder_bb: [N_sample, N_binder_res, 4, 3] — backbone coords (N, CA, C, O)
|
| 138 |
+
binder_mask: [N_binder_res] — validity mask
|
| 139 |
+
rec_bb: [N_rec_res, 4, 3] — receptor backbone coords (from condition)
|
| 140 |
+
rec_mask: [N_rec_res] — receptor validity mask
|
| 141 |
+
binder_atom_indices: [N_binder, 4] - per-residue backbone indices into the flat atom array
|
| 142 |
+
"""
|
| 143 |
+
atom_to_token = input_feature_dict['atom_to_token_idx'] # [N_atom]
|
| 144 |
+
if atom_to_token.dim() > 1:
|
| 145 |
+
atom_to_token = atom_to_token.squeeze(0)
|
| 146 |
+
|
| 147 |
+
# Identify binder vs receptor tokens
|
| 148 |
+
# In PXDesign: design_token_mask=True for binder tokens
|
| 149 |
+
design_token_mask = input_feature_dict.get('design_token_mask', None)
|
| 150 |
+
if design_token_mask is not None:
|
| 151 |
+
if design_token_mask.dim() > 1:
|
| 152 |
+
design_token_mask = design_token_mask.squeeze(0)
|
| 153 |
+
binder_tokens = torch.where(design_token_mask)[0]
|
| 154 |
+
rec_tokens = torch.where(~design_token_mask)[0]
|
| 155 |
+
else:
|
| 156 |
+
# Fallback: use entity_id (binder is typically entity_id=2, the last entity)
|
| 157 |
+
entity_id = input_feature_dict['entity_id']
|
| 158 |
+
if entity_id.dim() > 1:
|
| 159 |
+
entity_id = entity_id.squeeze(0)
|
| 160 |
+
max_entity = entity_id.max()
|
| 161 |
+
binder_tokens = torch.where(entity_id == max_entity)[0]
|
| 162 |
+
rec_tokens = torch.where(entity_id != max_entity)[0]
|
| 163 |
+
|
| 164 |
+
# Map tokens to atoms
|
| 165 |
+
# For standard amino acids, atom order within each token is:
|
| 166 |
+
# N(0), CA(1), C(2), O(3), CB(4), ...
|
| 167 |
+
# We need atoms 0-3 (N, CA, C, O) per token
|
| 168 |
+
|
| 169 |
+
# Get atom indices for each binder token
|
| 170 |
+
n_binder_res = len(binder_tokens)
|
| 171 |
+
if n_binder_res == 0:
|
| 172 |
+
return None
|
| 173 |
+
|
| 174 |
+
# Find atoms belonging to each binder residue
|
| 175 |
+
binder_bb_list = []
|
| 176 |
+
binder_atom_idx_list = []
|
| 177 |
+
for tok_idx in binder_tokens:
|
| 178 |
+
atom_indices = torch.where(atom_to_token == tok_idx.item())[0]
|
| 179 |
+
if len(atom_indices) >= 4:
|
| 180 |
+
# First 4 atoms are N, CA, C, O for standard amino acids
|
| 181 |
+
bb_atoms = atom_indices[:4]
|
| 182 |
+
binder_bb_list.append(bb_atoms)
|
| 183 |
+
binder_atom_idx_list.append(bb_atoms)
|
| 184 |
+
|
| 185 |
+
if not binder_bb_list:
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
n_binder_res = len(binder_bb_list)
|
| 189 |
+
binder_bb_indices = torch.stack(binder_bb_list) # [N_binder, 4]
|
| 190 |
+
all_binder_atom_indices = torch.cat(binder_atom_idx_list) # [N_binder * 4]
|
| 191 |
+
|
| 192 |
+
# Extract binder backbone coords for all samples
|
| 193 |
+
# x_coords: [N_sample, N_atom, 3]
|
| 194 |
+
binder_bb = x_coords[:, binder_bb_indices, :] # [N_sample, N_binder, 4, 3]
|
| 195 |
+
binder_mask = torch.ones(n_binder_res, dtype=torch.bool, device=x_coords.device)
|
| 196 |
+
|
| 197 |
+
# Extract receptor backbone from x_coords or condition_coordinate.
|
| 198 |
+
# PXDesign normally stores condition_coordinate in label_dict rather than
|
| 199 |
+
# input_feature_dict, so it may be missing here. Receptor atoms are held at
|
| 200 |
+
# their reference positions during diffusion, so x_coords is a valid fallback.
|
| 201 |
+
# Try condition_coordinate first (if available), then fall back to x_coords.
|
| 202 |
+
cond_coords = input_feature_dict.get('condition_coordinate', None)
|
| 203 |
+
if cond_coords is None:
|
| 204 |
+
# Also try label_dict nesting
|
| 205 |
+
label_dict = input_feature_dict.get('label_dict', None)
|
| 206 |
+
if label_dict is not None:
|
| 207 |
+
cond_coords = label_dict.get('condition_coordinate', None)
|
| 208 |
+
|
| 209 |
+
rec_bb = None
|
| 210 |
+
rec_mask = None
|
| 211 |
+
|
| 212 |
+
# Get receptor backbone atoms
|
| 213 |
+
rec_bb_list = []
|
| 214 |
+
for tok_idx in rec_tokens:
|
| 215 |
+
atom_indices = torch.where(atom_to_token == tok_idx.item())[0]
|
| 216 |
+
if len(atom_indices) >= 4:
|
| 217 |
+
rec_bb_list.append(atom_indices[:4])
|
| 218 |
+
|
| 219 |
+
if rec_bb_list:
|
| 220 |
+
rec_bb_indices = torch.stack(rec_bb_list) # [N_rec, 4]
|
| 221 |
+
|
| 222 |
+
if cond_coords is not None:
|
| 223 |
+
if cond_coords.dim() > 2:
|
| 224 |
+
cond_coords = cond_coords.squeeze(0)
|
| 225 |
+
rec_bb = cond_coords[rec_bb_indices, :] # [N_rec, 4, 3]
|
| 226 |
+
else:
|
| 227 |
+
# Fallback: extract receptor coords from x_coords (sample 0)
|
| 228 |
+
# Receptor atoms are conditioned and constant across samples
|
| 229 |
+
rec_bb = x_coords[0, rec_bb_indices, :].detach() # [N_rec, 4, 3]
|
| 230 |
+
|
| 231 |
+
rec_mask = torch.ones(len(rec_bb_list), dtype=torch.bool,
|
| 232 |
+
device=x_coords.device)
|
| 233 |
+
|
| 234 |
+
return {
|
| 235 |
+
'binder_bb': binder_bb, # [N_sample, N_binder, 4, 3]
|
| 236 |
+
'binder_mask': binder_mask, # [N_binder]
|
| 237 |
+
'rec_bb': rec_bb, # [N_rec, 4, 3] or None
|
| 238 |
+
'rec_mask': rec_mask, # [N_rec] or None
|
| 239 |
+
'binder_atom_indices': binder_bb_indices, # [N_binder, 4]
|
| 240 |
+
'all_binder_atom_indices': all_binder_atom_indices, # [N_binder * 4]
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
def align_and_score(self, binder_bb, rec_bb, rec_mask, receptor_label):
|
| 244 |
+
"""
|
| 245 |
+
Align binder to a reference receptor frame and score with Q_theta.
|
| 246 |
+
|
| 247 |
+
Uses the receptor chain from the design to compute Kabsch alignment
|
| 248 |
+
to the reference receptor, then transforms the binder accordingly.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
binder_bb: [N_binder, 4, 3] — binder backbone coords (requires_grad)
|
| 252 |
+
rec_bb: [N_rec, 4, 3] — receptor backbone coords
|
| 253 |
+
rec_mask: [N_rec] bool
|
| 254 |
+
receptor_label: 'holo' or 'apo'
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
score: scalar tensor, differentiable w.r.t. binder_bb
|
| 258 |
+
"""
|
| 259 |
+
if receptor_label == 'holo':
|
| 260 |
+
ref_ca = self.ref_holo_ca
|
| 261 |
+
else:
|
| 262 |
+
ref_ca = self.ref_apo_ca
|
| 263 |
+
|
| 264 |
+
# Get CA atoms from receptor
|
| 265 |
+
rec_ca = rec_bb[:, 1, :] # [N_rec, 3]
|
| 266 |
+
|
| 267 |
+
# Use overlapping residues for alignment (take min length)
|
| 268 |
+
n_align = min(len(rec_ca), len(ref_ca))
|
| 269 |
+
if n_align < 5:
|
| 270 |
+
return torch.zeros(1, device=binder_bb.device, dtype=binder_bb.dtype,
|
| 271 |
+
requires_grad=True).squeeze()
|
| 272 |
+
|
| 273 |
+
mobile_ca = rec_ca[:n_align].detach()
|
| 274 |
+
target_ca = ref_ca[:n_align].detach()
|
| 275 |
+
|
| 276 |
+
# Compute Kabsch alignment (detached — no gradient through rotation)
|
| 277 |
+
R, t = differentiable_kabsch(mobile_ca, target_ca)
|
| 278 |
+
R = R.detach()
|
| 279 |
+
t = t.detach()
|
| 280 |
+
|
| 281 |
+
# Apply transform to binder (gradient flows through binder_bb)
|
| 282 |
+
binder_flat = binder_bb.reshape(-1, 3) # [N_binder*4, 3]
|
| 283 |
+
aligned = binder_flat @ R.T + t # [N_binder*4, 3]
|
| 284 |
+
aligned_bb = aligned.reshape(-1, 4, 3) # [N_binder, 4, 3]
|
| 285 |
+
|
| 286 |
+
# Score with Q_theta
|
| 287 |
+
binder_mask = torch.ones(aligned_bb.shape[0], dtype=torch.bool,
|
| 288 |
+
device=binder_bb.device)
|
| 289 |
+
score = self.dq.score(aligned_bb, binder_mask, receptor_label=receptor_label,
|
| 290 |
+
cutoff=self.cutoff)
|
| 291 |
+
return score
|
| 292 |
+
|
| 293 |
+
def compute_guidance_gradient(self, x_denoised, input_feature_dict, t_hat=None,
|
| 294 |
+
sample_idx=0):
|
| 295 |
+
"""
|
| 296 |
+
Compute Q_theta selectivity gradient for guidance.
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
x_denoised: [N_sample, N_atom, 3] — denoised coordinates from diffusion net
|
| 300 |
+
input_feature_dict: PXDesign input features dict
|
| 301 |
+
t_hat: current noise level (for logging/scaling)
|
| 302 |
+
sample_idx: which sample to compute gradient for (or -1 for all)
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
gradient: [N_sample, N_atom, 3] — gradient to add to x_denoised
|
| 306 |
+
(non-zero only at binder backbone atom positions)
|
| 307 |
+
margin: float - selectivity margin (q_holo - q_apo), averaged over the processed samples
|
| 308 |
+
"""
|
| 309 |
+
self._lazy_init()
|
| 310 |
+
|
| 311 |
+
extraction = self.extract_binder_backbone(x_denoised.detach(), input_feature_dict)
|
| 312 |
+
if extraction is None:
|
| 313 |
+
return torch.zeros_like(x_denoised), 0.0
|
| 314 |
+
|
| 315 |
+
binder_bb = extraction['binder_bb'] # [N_sample, N_binder, 4, 3]
|
| 316 |
+
binder_mask = extraction['binder_mask'] # [N_binder]
|
| 317 |
+
rec_bb = extraction['rec_bb'] # [N_rec, 4, 3]
|
| 318 |
+
rec_mask = extraction['rec_mask'] # [N_rec]
|
| 319 |
+
binder_atom_indices = extraction['binder_atom_indices'] # [N_binder, 4]
|
| 320 |
+
|
| 321 |
+
if rec_bb is None:
|
| 322 |
+
return torch.zeros_like(x_denoised), 0.0
|
| 323 |
+
|
| 324 |
+
N_sample = x_denoised.shape[0]
|
| 325 |
+
gradient = torch.zeros_like(x_denoised)
|
| 326 |
+
margins = []
|
| 327 |
+
|
| 328 |
+
# Ensure receptor is float32 for Q_theta scoring
|
| 329 |
+
if rec_bb is not None:
|
| 330 |
+
rec_bb = rec_bb.float()
|
| 331 |
+
|
| 332 |
+
# Process each sample
|
| 333 |
+
indices = range(N_sample) if sample_idx == -1 else [sample_idx]
|
| 334 |
+
for si in indices:
|
| 335 |
+
# Make binder coords differentiable, cast to float32 for Q_theta
|
| 336 |
+
binder_si = binder_bb[si].clone().float().requires_grad_(True) # [N_binder, 4, 3]
|
| 337 |
+
|
| 338 |
+
try:
|
| 339 |
+
with torch.enable_grad():
|
| 340 |
+
q_holo = self.align_and_score(binder_si, rec_bb, rec_mask, 'holo')
|
| 341 |
+
q_apo = self.align_and_score(binder_si, rec_bb, rec_mask, 'apo')
|
| 342 |
+
margin = q_holo - q_apo
|
| 343 |
+
margin.backward()
|
| 344 |
+
|
| 345 |
+
if binder_si.grad is not None and not torch.isnan(binder_si.grad).any():
|
| 346 |
+
# Map gradient back to full atom array
|
| 347 |
+
grad_bb = binder_si.grad # [N_binder, 4, 3]
|
| 348 |
+
for ri in range(len(binder_atom_indices)):
|
| 349 |
+
for ai in range(4):
|
| 350 |
+
atom_idx = binder_atom_indices[ri, ai]
|
| 351 |
+
gradient[si, atom_idx] = grad_bb[ri, ai]
|
| 352 |
+
margins.append(margin.item())
|
| 353 |
+
else:
|
| 354 |
+
margins.append(0.0)
|
| 355 |
+
except Exception as e:
|
| 356 |
+
logger.debug(f"Gradient computation failed for sample {si}: {e}")
|
| 357 |
+
margins.append(0.0)
|
| 358 |
+
|
| 359 |
+
avg_margin = np.mean(margins) if margins else 0.0
|
| 360 |
+
return gradient, avg_margin
|
| 361 |
+
|
| 362 |
+
def score_design(self, pdb_path, rec_chain='A', binder_chain='B'):
|
| 363 |
+
"""
|
| 364 |
+
Score a single PXDesign output PDB/CIF (post-hoc, no gradient).
|
| 365 |
+
|
| 366 |
+
Handles PXDesign CIF files which use chain IDs like 'A0'/'B0' and
|
| 367 |
+
non-standard residue name 'xpb' for designed binder residues.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
dict with q_holo, q_apo, margin, or None on failure
|
| 371 |
+
"""
|
| 372 |
+
self._lazy_init()
|
| 373 |
+
|
| 374 |
+
from utils.pdb_utils import (
|
| 375 |
+
load_structure, get_residues, get_backbone_coords,
|
| 376 |
+
get_aa_indices, align_structures
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
model = load_structure(pdb_path)
|
| 381 |
+
chains = {c.get_id(): c for c in model.get_chains()}
|
| 382 |
+
|
| 383 |
+
if len(chains) < 2:
|
| 384 |
+
return None
|
| 385 |
+
|
| 386 |
+
chain_ids = sorted(chains.keys())
|
| 387 |
+
|
| 388 |
+
# Identify receptor and binder
|
| 389 |
+
# PXDesign CIF uses chain IDs like 'A0', 'B0' instead of 'A', 'B'
|
| 390 |
+
rc, bc = None, None
|
| 391 |
+
if rec_chain in chains and binder_chain in chains:
|
| 392 |
+
rc, bc = rec_chain, binder_chain
|
| 393 |
+
else:
|
| 394 |
+
# Match by residue count: receptor matches reference length,
|
| 395 |
+
# binder is the other chain
|
| 396 |
+
ref_model = load_structure(self.ref_holo)
|
| 397 |
+
ref_res = get_residues(ref_model[self.ref_chain])
|
| 398 |
+
ref_len = len(ref_res)
|
| 399 |
+
for cid in chain_ids:
|
| 400 |
+
# Try standard residues first, then all residues
|
| 401 |
+
cres = get_residues(chains[cid])
|
| 402 |
+
if not cres:
|
| 403 |
+
cres = get_residues(chains[cid], only_standard=False)
|
| 404 |
+
n_res = len(cres)
|
| 405 |
+
if n_res > 0 and abs(n_res - ref_len) < ref_len * 0.3:
|
| 406 |
+
rc = cid
|
| 407 |
+
elif n_res > 0:
|
| 408 |
+
bc = cid
|
| 409 |
+
if rc is None or bc is None:
|
| 410 |
+
rc, bc = chain_ids[0], chain_ids[1]
|
| 411 |
+
|
| 412 |
+
rec_res = get_residues(chains[rc])
|
| 413 |
+
if not rec_res:
|
| 414 |
+
rec_res = get_residues(chains[rc], only_standard=False)
|
| 415 |
+
|
| 416 |
+
# For binder: PXDesign uses 'xpb' residue names (non-standard)
|
| 417 |
+
binder_res = get_residues(chains[bc])
|
| 418 |
+
if not binder_res:
|
| 419 |
+
binder_res = get_residues(chains[bc], only_standard=False)
|
| 420 |
+
|
| 421 |
+
if not rec_res or not binder_res:
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
rec_coords, rec_mask = get_backbone_coords(rec_res)
|
| 425 |
+
binder_coords, binder_mask = get_backbone_coords(binder_res)
|
| 426 |
+
|
| 427 |
+
# Handle amino acid indices: use get_aa_indices for standard AAs,
|
| 428 |
+
# default to index 0 (ALA, see the fallback below) for non-standard (PXDesign 'xpb')
|
| 429 |
+
try:
|
| 430 |
+
aa_idx = get_aa_indices(binder_res)
|
| 431 |
+
except Exception:
|
| 432 |
+
aa_idx = np.zeros(len(binder_res), dtype=np.int64) # default to ALA
|
| 433 |
+
|
| 434 |
+
device = self.device
|
| 435 |
+
|
| 436 |
+
# Align to holo
|
| 437 |
+
rec_ca = rec_coords[:, 1, :]
|
| 438 |
+
ref_holo_ca_np = self.ref_holo_ca.cpu().numpy()
|
| 439 |
+
n_align = min(len(rec_ca), len(ref_holo_ca_np))
|
| 440 |
+
if n_align < 5:
|
| 441 |
+
return None
|
| 442 |
+
_, R_h = align_structures(rec_ca[:n_align], ref_holo_ca_np[:n_align])
|
| 443 |
+
center_h = rec_ca[:n_align].mean(0)
|
| 444 |
+
ref_center_h = ref_holo_ca_np[:n_align].mean(0)
|
| 445 |
+
aligned_holo = (binder_coords.reshape(-1, 3) - center_h) @ R_h.T + ref_center_h
|
| 446 |
+
aligned_holo = aligned_holo.reshape(-1, 4, 3)
|
| 447 |
+
|
| 448 |
+
# Align to apo
|
| 449 |
+
ref_apo_ca_np = self.ref_apo_ca.cpu().numpy()
|
| 450 |
+
n_align_a = min(len(rec_ca), len(ref_apo_ca_np))
|
| 451 |
+
_, R_a = align_structures(rec_ca[:n_align_a], ref_apo_ca_np[:n_align_a])
|
| 452 |
+
center_a = rec_ca[:n_align_a].mean(0)
|
| 453 |
+
ref_center_a = ref_apo_ca_np[:n_align_a].mean(0)
|
| 454 |
+
aligned_apo = (binder_coords.reshape(-1, 3) - center_a) @ R_a.T + ref_center_a
|
| 455 |
+
aligned_apo = aligned_apo.reshape(-1, 4, 3)
|
| 456 |
+
|
| 457 |
+
with torch.no_grad():
|
| 458 |
+
coords_h = torch.from_numpy(aligned_holo).float().to(device)
|
| 459 |
+
coords_a = torch.from_numpy(aligned_apo).float().to(device)
|
| 460 |
+
mask_t = torch.from_numpy(binder_mask).bool().to(device)
|
| 461 |
+
aa_t = torch.from_numpy(aa_idx).long().to(device)
|
| 462 |
+
|
| 463 |
+
q_holo = self.dq.score(coords_h, mask_t, binder_aa_idx=aa_t,
|
| 464 |
+
receptor_label='holo').item()
|
| 465 |
+
q_apo = self.dq.score(coords_a, mask_t, binder_aa_idx=aa_t,
|
| 466 |
+
receptor_label='apo').item()
|
| 467 |
+
|
| 468 |
+
return {
|
| 469 |
+
'q_holo': q_holo,
|
| 470 |
+
'q_apo': q_apo,
|
| 471 |
+
'margin': q_holo - q_apo,
|
| 472 |
+
'n_res': len(binder_res),
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
except Exception as e:
|
| 476 |
+
logger.warning(f"Error scoring {pdb_path}: {e}")
|
| 477 |
+
return None
|
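The chain-matching fallback in `score_design` above can be sketched in isolation. This is a minimal standalone illustration, not the repository's code: `pick_chains` and its `tol` parameter are hypothetical names, and chain lengths are passed as a plain dict instead of being read from a structure file. It mirrors the logic shown in the diff: a chain whose residue count is within 30% of the reference length is taken as the receptor, any other non-empty chain as the binder, and the first two sorted chain IDs are used when either role is left unassigned.

```python
def pick_chains(chain_lengths, ref_len, tol=0.3):
    """Return (receptor_id, binder_id) from a {chain_id: n_residues} mapping.

    Hypothetical helper mirroring the residue-count heuristic. Later chains
    overwrite earlier matches, as in a simple loop over sorted chain IDs.
    """
    chain_ids = sorted(chain_lengths)
    if len(chain_ids) < 2:
        return None
    rec, binder = None, None
    for cid in chain_ids:
        n_res = chain_lengths[cid]
        if n_res > 0 and abs(n_res - ref_len) < ref_len * tol:
            rec = cid        # length close to the reference: receptor
        elif n_res > 0:
            binder = cid     # any other non-empty chain: binder
    if rec is None or binder is None:
        # Either role unassigned: fall back to the first two sorted IDs.
        rec, binder = chain_ids[0], chain_ids[1]
    return rec, binder


# PXDesign-style chain IDs; 148 is an illustrative reference length.
print(pick_chains({"A0": 148, "B0": 60}, ref_len=148))
# A binder close in length to the receptor matches the receptor test too,
# so the answer then comes from the sorted-ID fallback.
print(pick_chains({"A0": 148, "B0": 140}, ref_len=148))
```

When the binder and receptor have similar lengths, both chains pass the receptor test and the result depends only on chain-ID ordering, so passing explicit `rec_chain` and `binder_chain` values to `score_design` is likely the safer choice in that situation.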
code/scripts/pxdesign_guidance/smc_pxdesign.py
ADDED
|
@@ -0,0 +1,262 @@
|
| 1 |
+
"""
|
| 2 |
+
PXDesign + SMC Reranking.
|
| 3 |
+
|
| 4 |
+
Post-hoc Sequential Monte Carlo: generate multiple batches of vanilla PXDesign
|
| 5 |
+
binders, score with Q_theta, and rank by selectivity margin. No modification
|
| 6 |
+
to the PXDesign diffusion process — pure generate-score-rank pipeline.
|
| 7 |
+
|
| 8 |
+
This is the simplest Q_theta integration strategy: generate a large pool of
|
| 9 |
+
candidates and select the best ones by selectivity score.
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python code/scripts/pxdesign_guidance/smc_pxdesign.py \
|
| 13 |
+
--input experiments/pxdesign_cam/output/cam_binder.json \
|
| 14 |
+
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
|
| 15 |
+
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
|
| 16 |
+
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
|
| 17 |
+
--n_particles 16 --n_rounds 4 \
|
| 18 |
+
--gpu 0
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import shutil
|
| 27 |
+
import subprocess
|
| 28 |
+
from glob import glob
|
| 29 |
+
|
| 30 |
+
import numpy as np
|
| 31 |
+
import torch
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 37 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
|
| 38 |
+
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
|
| 39 |
+
|
| 40 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 41 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def run_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
|
| 45 |
+
"""Run vanilla PXDesign via CLI subprocess."""
|
| 46 |
+
pxdesign_python = 'python'
|
| 47 |
+
|
| 48 |
+
# Use pxdesign CLI
|
| 49 |
+
cmd = [
|
| 50 |
+
pxdesign_python, '-m', 'pxdesign.runner.cli', 'infer',
|
| 51 |
+
'-o', outdir,
|
| 52 |
+
'-i', input_json,
|
| 53 |
+
'--dtype', 'bf16',
|
| 54 |
+
'--N_sample', str(n_sample),
|
| 55 |
+
'--N_step', str(n_step),
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
env = os.environ.copy()
|
| 59 |
+
# CUDA_VISIBLE_DEVICES is inherited from the parent process; the `gpu` argument is not used here
|
| 60 |
+
|
| 61 |
+
logger.info(f"Running PXDesign: {n_sample} samples -> {outdir}")
|
| 62 |
+
result = subprocess.run(cmd, capture_output=True, text=True, env=env,
|
| 63 |
+
timeout=7200)
|
| 64 |
+
|
| 65 |
+
if result.returncode != 0:
|
| 66 |
+
# CLI entry point failed: fall back to the pxdesign.runner.inference module
|
| 67 |
+
cmd_alt = [
|
| 68 |
+
pxdesign_python, '-m', 'pxdesign.runner.inference',
|
| 69 |
+
'--dump_dir', outdir,
|
| 70 |
+
'--input', input_json,
|
| 71 |
+
'--dtype', 'bf16',
|
| 72 |
+
'--N_sample', str(n_sample),
|
| 73 |
+
'--N_step', str(n_step),
|
| 74 |
+
]
|
| 75 |
+
result = subprocess.run(cmd_alt, capture_output=True, text=True, env=env,
|
| 76 |
+
timeout=7200)
|
| 77 |
+
if result.returncode != 0:
|
| 78 |
+
logger.error(f"PXDesign failed:\nstdout: {result.stdout[-1000:]}\nstderr: {result.stderr[-1000:]}")
|
| 79 |
+
return False
|
| 80 |
+
return True
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def collect_pdbs(outdir):
|
| 84 |
+
"""Collect PDB/CIF files from PXDesign output."""
|
| 85 |
+
pdbs = []
|
| 86 |
+
for ext in ('*.pdb', '*.cif'):
|
| 87 |
+
pdbs.extend(glob(os.path.join(outdir, '**/' + ext), recursive=True))
|
| 88 |
+
pdbs = sorted(pdbs)
|
| 89 |
+
filtered = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
|
| 90 |
+
or 'design' in os.path.basename(p).lower()
|
| 91 |
+
or 'rank' in os.path.basename(p).lower()]
|
| 92 |
+
return filtered if filtered else pdbs
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def smc_particle_filter(args):
|
| 96 |
+
"""Run SMC reranking with PXDesign."""
|
| 97 |
+
os.chdir(_ALLO_ROOT)
|
| 98 |
+
|
| 99 |
+
from qtheta_pxdesign import QThetaPXDesignGuidance
|
| 100 |
+
|
| 101 |
+
outdir = args.outdir
|
| 102 |
+
os.makedirs(outdir, exist_ok=True)
|
| 103 |
+
|
| 104 |
+
# Initialize scorer
|
| 105 |
+
guidance = QThetaPXDesignGuidance(
|
| 106 |
+
checkpoint=args.qtheta_checkpoint,
|
| 107 |
+
ref_holo=args.ref_holo,
|
| 108 |
+
ref_apo=args.ref_apo,
|
| 109 |
+
ref_chain=args.ref_chain,
|
| 110 |
+
device=f'cuda:{args.gpu}',
|
| 111 |
+
)
|
| 112 |
+
guidance._lazy_init()
|
| 113 |
+
|
| 114 |
+
all_designs = []
|
| 115 |
+
round_summaries = []
|
| 116 |
+
|
| 117 |
+
for round_idx in range(args.n_rounds):
|
| 118 |
+
round_dir = os.path.join(outdir, f'round_{round_idx}')
|
| 119 |
+
os.makedirs(round_dir, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
logger.info(f"\n{'='*60}")
|
| 122 |
+
logger.info(f"SMC Round {round_idx + 1}/{args.n_rounds}")
|
| 123 |
+
logger.info(f"{'='*60}")
|
| 124 |
+
|
| 125 |
+
# Generate particles via vanilla PXDesign
|
| 126 |
+
gen_dir = os.path.join(round_dir, 'generated')
|
| 127 |
+
success = run_pxdesign_batch(
|
| 128 |
+
input_json=args.input,
|
| 129 |
+
outdir=gen_dir,
|
| 130 |
+
n_sample=args.n_particles,
|
| 131 |
+
n_step=args.N_step,
|
| 132 |
+
gpu=args.gpu,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
if not success:
|
| 136 |
+
# If subprocess fails, try using existing PXDesign outputs
|
| 137 |
+
logger.warning(f"Round {round_idx} generation failed. "
|
| 138 |
+
f"Checking for existing outputs...")
|
| 139 |
+
pdbs = collect_pdbs(args.designs_dir) if hasattr(args, 'designs_dir') else []
|
| 140 |
+
if not pdbs:
|
| 141 |
+
continue
|
| 142 |
+
else:
|
| 143 |
+
pdbs = collect_pdbs(gen_dir)
|
| 144 |
+
|
| 145 |
+
if not pdbs:
|
| 146 |
+
logger.warning(f"No PDBs found in round {round_idx}")
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
# Score all particles
|
| 150 |
+
logger.info(f"Scoring {len(pdbs)} particles...")
|
| 151 |
+
round_results = []
|
| 152 |
+
for pdb_path in pdbs:
|
| 153 |
+
result = guidance.score_design(pdb_path)
|
| 154 |
+
if result is not None:
|
| 155 |
+
result['pdb_path'] = pdb_path
|
| 156 |
+
result['design_id'] = os.path.splitext(os.path.basename(pdb_path))[0]
|
| 157 |
+
result['round'] = round_idx
|
| 158 |
+
round_results.append(result)
|
| 159 |
+
|
| 160 |
+
if not round_results:
|
| 161 |
+
continue
|
| 162 |
+
|
| 163 |
+
margins = np.array([r['margin'] for r in round_results])
|
| 164 |
+
|
| 165 |
+
round_summary = {
|
| 166 |
+
'round': round_idx,
|
| 167 |
+
'n_particles': len(round_results),
|
| 168 |
+
'margin_mean': float(margins.mean()),
|
| 169 |
+
'margin_std': float(margins.std()),
|
| 170 |
+
'margin_max': float(margins.max()),
|
| 171 |
+
'frac_positive': float((margins > 0).mean()),
|
| 172 |
+
}
|
| 173 |
+
round_summaries.append(round_summary)
|
| 174 |
+
|
| 175 |
+
logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
|
| 176 |
+
f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}")
|
| 177 |
+
|
| 178 |
+
all_designs.extend(round_results)
|
| 179 |
+
|
| 180 |
+
# Final ranking and summary
|
| 181 |
+
if all_designs:
|
| 182 |
+
all_designs.sort(key=lambda x: x['margin'], reverse=True)
|
| 183 |
+
all_margins = np.array([d['margin'] for d in all_designs])
|
| 184 |
+
holo_scores = np.array([d['q_holo'] for d in all_designs])
|
| 185 |
+
|
| 186 |
+
# Best-of-K
|
| 187 |
+
bok = {}
|
| 188 |
+
for K in [1, 2, 5, 10]:
|
| 189 |
+
n_trials = 2000
|
| 190 |
+
n_avail = len(all_margins)
|
| 191 |
+
successes = sum(
|
| 192 |
+
1 for _ in range(n_trials)
|
| 193 |
+
if all_margins[np.random.choice(n_avail, min(K, n_avail), replace=False)].max() > 0
|
| 194 |
+
)
|
| 195 |
+
bok[K] = successes / n_trials
|
| 196 |
+
|
| 197 |
+
summary = {
|
| 198 |
+
'method': 'PXDesign + SMC',
|
| 199 |
+
'n_rounds': args.n_rounds,
|
| 200 |
+
'n_particles_per_round': args.n_particles,
|
| 201 |
+
'total_designs': len(all_designs),
|
| 202 |
+
'margin_mean': float(all_margins.mean()),
|
| 203 |
+
'margin_std': float(all_margins.std()),
|
| 204 |
+
'margin_max': float(all_margins.max()),
|
| 205 |
+
'frac_positive': float((all_margins > 0).mean()),
|
| 206 |
+
'q_holo_mean': float(holo_scores.mean()),
|
| 207 |
+
'q_apo_mean': float(np.mean([d['q_apo'] for d in all_designs])),
|
| 208 |
+
'best_of_k': {str(k): v for k, v in bok.items()},
|
| 209 |
+
'round_summaries': round_summaries,
|
| 210 |
+
'top5': all_designs[:5],
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
with open(os.path.join(outdir, 'smc_scores.json'), 'w') as f:
|
| 214 |
+
json.dump(all_designs, f, indent=2)
|
| 215 |
+
with open(os.path.join(outdir, 'smc_summary.json'), 'w') as f:
|
| 216 |
+
json.dump(summary, f, indent=2)
|
| 217 |
+
|
| 218 |
+
# Copy best designs
|
| 219 |
+
best_dir = os.path.join(outdir, 'best_designs')
|
| 220 |
+
os.makedirs(best_dir, exist_ok=True)
|
| 221 |
+
for i, d in enumerate(all_designs[:20]):
|
| 222 |
+
if os.path.exists(d['pdb_path']):
|
| 223 |
+
dest = os.path.join(best_dir, f'rank_{i:02d}_{os.path.basename(d["pdb_path"])}')
|
| 224 |
+
shutil.copy2(d['pdb_path'], dest)
|
| 225 |
+
|
| 226 |
+
logger.info(f"\n{'='*60}")
|
| 227 |
+
logger.info(f"PXDesign + SMC Results ({len(all_designs)} total designs)")
|
| 228 |
+
logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
|
| 229 |
+
logger.info(f" Max margin: {all_margins.max():.3f}")
|
| 230 |
+
logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
|
| 231 |
+
logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
|
| 232 |
+
logger.info(f" Best-of-K:")
|
| 233 |
+
for k, v in sorted(bok.items()):
|
| 234 |
+
logger.info(f" K={k:3d}: {v:.3f}")
|
| 235 |
+
logger.info(f"{'='*60}")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def main():
|
| 239 |
+
parser = argparse.ArgumentParser(description='PXDesign + SMC Reranking')
|
| 240 |
+
parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json',
|
| 241 |
+
help='PXDesign input JSON')
|
| 242 |
+
parser.add_argument('--designs_dir', default='experiments/pxdesign_cam/output/',
|
| 243 |
+
help='Existing PXDesign outputs (fallback if generation fails)')
|
| 244 |
+
parser.add_argument('--qtheta_checkpoint',
|
| 245 |
+
default='results/checkpoints_cam_v3/best_phase2.pt')
|
| 246 |
+
parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
|
| 247 |
+
parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
|
| 248 |
+
parser.add_argument('--ref_chain', default='A')
|
| 249 |
+
parser.add_argument('--n_particles', type=int, default=16,
|
| 250 |
+
help='Particles per round')
|
| 251 |
+
parser.add_argument('--n_rounds', type=int, default=4,
|
| 252 |
+
help='Number of SMC rounds')
|
| 253 |
+
parser.add_argument('--N_step', type=int, default=400)
|
| 254 |
+
parser.add_argument('--gpu', type=int, default=0)
|
| 255 |
+
parser.add_argument('--outdir', default='results/pxdesign_smc')
|
| 256 |
+
args = parser.parse_args()
|
| 257 |
+
|
| 258 |
+
smc_particle_filter(args)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
if __name__ == '__main__':
|
| 262 |
+
main()
|
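The Best-of-K figure in `smc_particle_filter` above is estimated by drawing K designs without replacement 2000 times and checking whether any margin is positive. As a cross-check, the same probability has a closed form: one minus the chance that all K draws come from the non-positive designs. The sketch below assumes nothing beyond the standard library; `best_of_k_exact` is a hypothetical name that does not appear in the repository, and the margins are made-up values.

```python
from math import comb


def best_of_k_exact(margins, k):
    """Probability that at least one of k designs, drawn without replacement,
    has a strictly positive margin (complement of a hypergeometric term)."""
    n = len(margins)
    if n == 0:
        return 0.0
    k = min(k, n)  # same clamp as the sampling loop in the script
    n_neg = sum(1 for m in margins if m <= 0)
    # comb(n_neg, k) is 0 when k > n_neg, which makes the result exactly 1.0
    return 1.0 - comb(n_neg, k) / comb(n, k)


margins = [0.4, -0.1, -0.3, -0.7]  # illustrative values, not real results
for k in (1, 2, 4):
    print(k, best_of_k_exact(margins, k))
```

The closed form removes the Monte Carlo noise and the dependence on the NumPy random state, which may make summaries easier to compare across runs. The sampled estimate in the script should converge to the same value as the number of trials grows.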
code/scripts/pxdesign_guidance/tds_pxdesign.py
ADDED
|
@@ -0,0 +1,323 @@
|
| 1 |
+
"""
|
| 2 |
+
PXDesign + Twisted Diffusion Sampling (TDS).
|
| 3 |
+
|
| 4 |
+
Multi-round particle filtering with guided PXDesign:
|
| 5 |
+
Round r:
|
| 6 |
+
1. Generate N particles via PXDesign with Q_theta classifier guidance
|
| 7 |
+
2. Score each particle with Q_theta selectivity margin
|
| 8 |
+
3. Compute importance weights w_i ~ exp(margin_i / temperature)
|
| 9 |
+
4. Resample particles (keep best, discard worst)
|
| 10 |
+
5. Add perturbation noise for diversity
|
| 11 |
+
|
| 12 |
+
This combines in-process guidance (the "twisted proposal") with post-hoc
|
| 13 |
+
importance-weighted resampling for highest-quality designs.
|
| 14 |
+
|
| 15 |
+
Usage:
|
| 16 |
+
python code/scripts/pxdesign_guidance/tds_pxdesign.py \
|
| 17 |
+
--input experiments/pxdesign_cam/output/cam_binder.json \
|
| 18 |
+
--qtheta_checkpoint results/checkpoints_cam_v3/best_phase2.pt \
|
| 19 |
+
--ref_holo data/pdbs/cam_holo/3CLN.pdb \
|
| 20 |
+
--ref_apo data/pdbs/cam_apo/1CFD.pdb \
|
| 21 |
+
--n_particles 16 --n_rounds 4 \
|
| 22 |
+
--guidance_scale 0.5 \
|
| 23 |
+
--gpu 0
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import logging
|
| 31 |
+
import shutil
|
| 32 |
+
import subprocess
|
| 33 |
+
from glob import glob
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 42 |
+
_ALLO_CODE_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, '..', '..'))
|
| 43 |
+
_ALLO_ROOT = os.path.abspath(os.path.join(_ALLO_CODE_DIR, '..'))
|
| 44 |
+
|
| 45 |
+
if _ALLO_CODE_DIR not in sys.path:
|
| 46 |
+
sys.path.insert(0, _ALLO_CODE_DIR)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def compute_ess(log_weights):
|
| 50 |
+
"""Compute effective sample size, 1 / sum(w_i^2), from unnormalized log-weights."""
|
| 51 |
+
log_weights = log_weights - log_weights.max()
|
| 52 |
+
weights = np.exp(log_weights)
|
| 53 |
+
weights = weights / weights.sum()
|
| 54 |
+
return 1.0 / (weights ** 2).sum()
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_guided_pxdesign_batch(input_json, outdir, n_sample, n_step,
|
| 58 |
+
gpu, guidance_args):
|
| 59 |
+
"""Run guided PXDesign as a subprocess."""
|
| 60 |
+
pxdesign_python = 'python'
|
| 61 |
+
cmd = [
|
| 62 |
+
pxdesign_python,
|
| 63 |
+
os.path.join(_SCRIPT_DIR, 'guided_pxdesign.py'),
|
| 64 |
+
'--input', input_json,
|
| 65 |
+
'--qtheta_checkpoint', guidance_args['checkpoint'],
|
| 66 |
+
'--ref_holo', guidance_args['ref_holo'],
|
| 67 |
+
'--ref_apo', guidance_args['ref_apo'],
|
| 68 |
+
'--ref_chain', guidance_args['ref_chain'],
|
| 69 |
+
'--guidance_scale', str(guidance_args['guidance_scale']),
|
| 70 |
+
'--guidance_start', str(guidance_args.get('guidance_start', 0.8)),
|
| 71 |
+
'--guidance_end', str(guidance_args.get('guidance_end', 0.1)),
|
| 72 |
+
'--N_sample', str(n_sample),
|
| 73 |
+
'--N_step', str(n_step),
|
| 74 |
+
'--gpu', str(gpu),
|
| 75 |
+
'--outdir', outdir,
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
env = os.environ.copy()
|
| 79 |
+
# Inherit CUDA_VISIBLE_DEVICES from parent
|
| 80 |
+
|
| 81 |
+
logger.info(f"Running guided PXDesign: {n_sample} samples -> {outdir}")
|
| 82 |
+
result = subprocess.run(cmd, capture_output=True, text=True, env=env,
|
| 83 |
+
timeout=7200)
|
| 84 |
+
|
| 85 |
+
if result.returncode != 0:
|
| 86 |
+
logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
|
| 87 |
+
return False
|
| 88 |
+
return True
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def run_vanilla_pxdesign_batch(input_json, outdir, n_sample, n_step, gpu):
|
| 92 |
+
"""Run vanilla PXDesign (no guidance) as a subprocess."""
|
| 93 |
+
pxdesign_env = 'python'
|
| 94 |
+
|
| 95 |
+
cmd = [
|
| 96 |
+
pxdesign_env, '-m', 'pxdesign.runner.inference',
|
| 97 |
+
'--dump_dir', outdir,
|
| 98 |
+
'--input', input_json,
|
| 99 |
+
'--dtype', 'bf16',
|
| 100 |
+
'--N_sample', str(n_sample),
|
| 101 |
+
'--N_step', str(n_step),
|
| 102 |
+
]
|
| 103 |
+
|
| 104 |
+
env = os.environ.copy()
|
| 105 |
+
# Inherit CUDA_VISIBLE_DEVICES from parent
|
| 106 |
+
|
| 107 |
+
logger.info(f"Running vanilla PXDesign: {n_sample} samples -> {outdir}")
|
| 108 |
+
result = subprocess.run(cmd, capture_output=True, text=True, env=env,
|
| 109 |
+
timeout=7200)
|
| 110 |
+
|
| 111 |
+
if result.returncode != 0:
|
| 112 |
+
logger.error(f"PXDesign failed:\n{result.stderr[-2000:]}")
|
| 113 |
+
return False
|
| 114 |
+
return True
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def collect_pdbs(outdir):
|
| 118 |
+
"""Collect PDB/CIF paths from PXDesign output directory."""
|
| 119 |
+
pdbs = []
|
| 120 |
+
for ext in ('*.pdb', '*.cif'):
|
| 121 |
+
pdbs.extend(glob(os.path.join(outdir, '**/' + ext), recursive=True))
|
| 122 |
+
pdbs = sorted(pdbs)
|
| 123 |
+
filtered = [p for p in pdbs if 'sample' in os.path.basename(p).lower()
|
| 124 |
+
or 'design' in os.path.basename(p).lower()
|
| 125 |
+
or 'rank' in os.path.basename(p).lower()]
|
| 126 |
+
return filtered if filtered else pdbs
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def tds_particle_filter(args):
|
| 130 |
+
"""Run TDS particle filtering with PXDesign."""
|
| 131 |
+
from qtheta_pxdesign import QThetaPXDesignGuidance
|
| 132 |
+
|
| 133 |
+
outdir = os.path.join(_ALLO_ROOT, args.outdir)
|
| 134 |
+
os.makedirs(outdir, exist_ok=True)
|
| 135 |
+
|
| 136 |
+
# Initialize scorer
|
| 137 |
+
guidance = QThetaPXDesignGuidance(
|
| 138 |
+
checkpoint=os.path.join(_ALLO_ROOT, args.qtheta_checkpoint),
|
| 139 |
+
ref_holo=os.path.join(_ALLO_ROOT, args.ref_holo),
|
| 140 |
+
ref_apo=os.path.join(_ALLO_ROOT, args.ref_apo),
|
| 141 |
+
ref_chain=args.ref_chain,
|
| 142 |
+
device=f'cuda:{args.gpu}',
|
| 143 |
+
)
|
| 144 |
+
guidance._lazy_init()
|
| 145 |
+
|
| 146 |
+
guidance_args = {
|
| 147 |
+
'checkpoint': args.qtheta_checkpoint,
|
| 148 |
+
'ref_holo': args.ref_holo,
|
| 149 |
+
'ref_apo': args.ref_apo,
|
| 150 |
+
'ref_chain': args.ref_chain,
|
| 151 |
+
'guidance_scale': args.guidance_scale,
|
| 152 |
+
'guidance_start': args.guidance_start,
|
| 153 |
+
'guidance_end': args.guidance_end,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
all_designs = []
|
| 157 |
+
round_summaries = []
|
| 158 |
+
|
| 159 |
+
for round_idx in range(args.n_rounds):
|
| 160 |
+
round_dir = os.path.join(outdir, f'round_{round_idx}')
|
| 161 |
+
os.makedirs(round_dir, exist_ok=True)
|
| 162 |
+
|
| 163 |
+
logger.info(f"\n{'='*60}")
|
| 164 |
+
logger.info(f"TDS Round {round_idx + 1}/{args.n_rounds}")
|
| 165 |
+
logger.info(f"{'='*60}")
|
| 166 |
+
|
| 167 |
+
# Generate particles via guided PXDesign
|
| 168 |
+
gen_dir = os.path.join(round_dir, 'generated')
|
| 169 |
+
success = run_guided_pxdesign_batch(
|
| 170 |
+
input_json=os.path.join(_ALLO_ROOT, args.input),
|
| 171 |
+
outdir=gen_dir,
|
| 172 |
+
n_sample=args.n_particles,
|
| 173 |
+
n_step=args.N_step,
|
| 174 |
+
gpu=args.gpu,
|
| 175 |
+
guidance_args=guidance_args,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
if not success:
|
| 179 |
+
logger.warning(f"Round {round_idx} generation failed, skipping")
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
# Collect and score particles
|
| 183 |
+
pdbs = collect_pdbs(gen_dir)
|
| 184 |
+
if not pdbs:
|
| 185 |
+
logger.warning(f"No PDBs found in round {round_idx}")
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
logger.info(f"Scoring {len(pdbs)} particles...")
|
| 189 |
+
round_results = []
|
| 190 |
+
for pdb_path in pdbs:
|
| 191 |
+
result = guidance.score_design(pdb_path)
|
| 192 |
+
if result is not None:
|
| 193 |
+
result['pdb_path'] = pdb_path
|
| 194 |
+
result['design_id'] = os.path.basename(pdb_path).replace('.pdb', '').replace('.cif', '')
|
| 195 |
+
result['round'] = round_idx
|
| 196 |
+
round_results.append(result)
|
| 197 |
+
|
| 198 |
+
if not round_results:
|
| 199 |
+
logger.warning(f"No scorable designs in round {round_idx}")
|
| 200 |
+
continue
|
| 201 |
+
|
| 202 |
+
margins = np.array([r['margin'] for r in round_results])
|
| 203 |
+
|
| 204 |
+
# Compute importance weights
|
| 205 |
+
log_weights = margins / args.temperature
|
| 206 |
+
ess = compute_ess(log_weights)
|
| 207 |
+
|
| 208 |
+
round_summary = {
|
| 209 |
+
'round': round_idx,
|
| 210 |
+
'n_particles': len(round_results),
|
| 211 |
+
'margin_mean': float(margins.mean()),
|
| 212 |
+
'margin_std': float(margins.std()),
|
| 213 |
+
'margin_max': float(margins.max()),
|
| 214 |
+
'frac_positive': float((margins > 0).mean()),
|
| 215 |
+
'ess': float(ess),
|
| 216 |
+
}
|
| 217 |
+
round_summaries.append(round_summary)
|
| 218 |
+
|
| 219 |
+
logger.info(f"Round {round_idx}: margin={margins.mean():.3f}±{margins.std():.3f}, "
|
| 220 |
+
f"max={margins.max():.3f}, S>0={round_summary['frac_positive']:.1%}, "
|
| 221 |
+
f"ESS={ess:.1f}/{len(round_results)}")
|
| 222 |
+
|
| 223 |
+
# Add to design pool
|
| 224 |
+
all_designs.extend(round_results)
|
| 225 |
+
|
| 226 |
+
# No explicit resampling between rounds, since with PXDesign
|
| 227 |
+
# we can't easily perturb and re-denoise existing particles.
|
| 228 |
+
if round_idx < args.n_rounds - 1:
|
| 229 |
+
# Designs are not carried over between rounds.
|
| 230 |
+
# For PXDesign, each round generates fresh samples with guidance
|
| 231 |
+
# Resampling influence is through the guidance strength
|
| 232 |
+
# Increase guidance scale for later rounds
|
| 233 |
+
guidance_args['guidance_scale'] = args.guidance_scale * (1.0 + 0.2 * (round_idx + 1))
|
| 234 |
+
logger.info(f"Increasing guidance scale to {guidance_args['guidance_scale']:.2f} "
|
| 235 |
+
f"for next round")
|
| 236 |
+
|
| 237 |
+
# Final summary
|
| 238 |
+
if all_designs:
|
| 239 |
+
all_designs.sort(key=lambda x: x['margin'], reverse=True)
|
| 240 |
+
all_margins = np.array([d['margin'] for d in all_designs])
|
| 241 |
+
holo_scores = np.array([d['q_holo'] for d in all_designs])
|
| 242 |
+
|
| 243 |
+
# Best-of-K
|
| 244 |
+
bok = {}
|
| 245 |
+
for K in [1, 2, 5, 10]:
|
| 246 |
+
n_trials = 2000
|
| 247 |
+
n_avail = len(all_margins)
|
| 248 |
+
successes = sum(
|
| 249 |
+
1 for _ in range(n_trials)
|
| 250 |
+
if all_margins[np.random.choice(n_avail, min(K, n_avail), replace=False)].max() > 0
|
| 251 |
+
)
|
| 252 |
+
bok[K] = successes / n_trials
|
| 253 |
+
|
| 254 |
+
summary = {
|
| 255 |
+
'method': 'PXDesign + TDS',
|
| 256 |
+
'n_rounds': args.n_rounds,
|
| 257 |
+
'n_particles_per_round': args.n_particles,
|
| 258 |
+
'total_designs': len(all_designs),
|
| 259 |
+
'guidance_scale': args.guidance_scale,
|
| 260 |
+
'temperature': args.temperature,
|
| 261 |
+
'margin_mean': float(all_margins.mean()),
|
| 262 |
+
'margin_std': float(all_margins.std()),
|
| 263 |
+
'margin_max': float(all_margins.max()),
|
| 264 |
+
'frac_positive': float((all_margins > 0).mean()),
|
| 265 |
+
'q_holo_mean': float(holo_scores.mean()),
|
| 266 |
+
'best_of_k': {str(k): v for k, v in bok.items()},
|
| 267 |
+
'round_summaries': round_summaries,
|
| 268 |
+
'top5': all_designs[:5],
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
with open(os.path.join(outdir, 'tds_scores.json'), 'w') as f:
|
| 272 |
+
json.dump(all_designs, f, indent=2)
|
| 273 |
+
with open(os.path.join(outdir, 'tds_summary.json'), 'w') as f:
|
| 274 |
+
json.dump(summary, f, indent=2)
|
| 275 |
+
|
| 276 |
+
# Copy best designs to top-level
|
| 277 |
+
best_dir = os.path.join(outdir, 'best_designs')
|
| 278 |
+
os.makedirs(best_dir, exist_ok=True)
|
| 279 |
+
for i, d in enumerate(all_designs[:20]):
|
| 280 |
+
if os.path.exists(d['pdb_path']):
|
| 281 |
+
dest = os.path.join(best_dir, f'rank_{i:02d}_{d["design_id"]}' + os.path.splitext(d['pdb_path'])[1])
|
| 282 |
+
shutil.copy2(d['pdb_path'], dest)
|
| 283 |
+
|
| 284 |
+
logger.info(f"\n{'='*60}")
|
| 285 |
+
logger.info(f"PXDesign + TDS Results ({len(all_designs)} total designs)")
|
| 286 |
+
logger.info(f" Margin: {all_margins.mean():.3f} ± {all_margins.std():.3f}")
|
| 287 |
+
logger.info(f" Max margin: {all_margins.max():.3f}")
|
| 288 |
+
logger.info(f" Fraction S > 0: {(all_margins > 0).mean():.1%}")
|
| 289 |
+
logger.info(f" Q(holo) mean: {holo_scores.mean():.3f}")
|
| 290 |
+
logger.info(" Best-of-K:")
|
| 291 |
+
for k, v in sorted(bok.items()):
|
| 292 |
+
logger.info(f" K={k:3d}: {v:.3f}")
|
| 293 |
+
logger.info(f"{'='*60}")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def main():
|
| 297 |
+
parser = argparse.ArgumentParser(description='PXDesign + TDS')
|
| 298 |
+
parser.add_argument('--input', default='experiments/pxdesign_cam/output/cam_binder.json')
|
| 299 |
+
parser.add_argument('--qtheta_checkpoint',
|
| 300 |
+
default='results/checkpoints_cam_v3/best_phase2.pt')
|
| 301 |
+
parser.add_argument('--ref_holo', default='data/pdbs/cam_holo/3CLN.pdb')
|
| 302 |
+
parser.add_argument('--ref_apo', default='data/pdbs/cam_apo/1CFD.pdb')
|
| 303 |
+
parser.add_argument('--ref_chain', default='A')
|
| 304 |
+
parser.add_argument('--n_particles', type=int, default=16,
|
| 305 |
+
help='Particles per round')
|
| 306 |
+
parser.add_argument('--n_rounds', type=int, default=4,
|
| 307 |
+
help='Number of TDS rounds')
|
| 308 |
+
parser.add_argument('--guidance_scale', type=float, default=0.5,
|
| 309 |
+
help='Initial guidance scale')
|
| 310 |
+
parser.add_argument('--guidance_start', type=float, default=0.8)
|
| 311 |
+
parser.add_argument('--guidance_end', type=float, default=0.1)
|
| 312 |
+
parser.add_argument('--temperature', type=float, default=0.5,
|
| 313 |
+
help='Temperature for importance weights')
|
| 314 |
+
parser.add_argument('--N_step', type=int, default=400)
|
| 315 |
+
parser.add_argument('--gpu', type=int, default=0)
|
| 316 |
+
parser.add_argument('--outdir', default='results/pxdesign_tds')
|
| 317 |
+
args = parser.parse_args()
|
| 318 |
+
|
| 319 |
+
tds_particle_filter(args)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
if __name__ == '__main__':
|
| 323 |
+
main()
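The importance-weight and ESS steps in `tds_pxdesign.py` can be sketched in isolation. The block below is a minimal, standard-library-only illustration, not the repository's implementation: `normalized_weights` and `effective_sample_size` mirror what `compute_ess` does with `margins / temperature`, while `systematic_resample` is a hypothetical helper showing what an explicit resampling step could look like (the script itself only reports ESS and does not resample).

```python
import math


def normalized_weights(margins, temperature):
    """Softmax of margin / temperature, computed with the max-shift trick."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    logs = [m / temperature for m in margins]
    peak = max(logs)
    unnormalized = [math.exp(value - peak) for value in logs]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2) for weights that already sum to one."""
    return 1.0 / sum(w * w for w in weights)


def systematic_resample(weights, offset):
    """Hypothetical resampling step (not in the script).

    `offset` is a single draw in [0, 1); particle i is selected roughly
    n * weights[i] times, where n is the number of particles.
    """
    n = len(weights)
    indices = []
    cumulative = weights[0]
    j = 0
    for i in range(n):
        position = (offset + i) / n
        # Advance to the first particle whose cumulative weight covers position.
        while position > cumulative and j < n - 1:
            j += 1
            cumulative += weights[j]
        indices.append(j)
    return indices
```

With equal margins every particle gets weight 1/N and the ESS equals N; when one margin dominates at a low temperature, the ESS approaches 1, which is the signal that the temperature may be too low for the particle count.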
|
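The best-of-K figures in the final summary are estimated with 2000 random draws of K designs without replacement. Under the same definition (success means at least one drawn design has a positive margin), the probability also has a closed form via the hypergeometric distribution. The sketch below is an illustrative alternative, not part of the script; `best_of_k_success` is a hypothetical name.

```python
from math import comb


def best_of_k_success(n_total, n_positive, k):
    """P(at least one positive-margin design among k drawn without replacement).

    k is capped at n_total, matching min(K, n_avail) in the Monte Carlo version.
    """
    if n_total <= 0:
        raise ValueError("n_total must be positive")
    if not 0 <= n_positive <= n_total:
        raise ValueError("n_positive must lie in [0, n_total]")
    k = min(k, n_total)
    n_negative = n_total - n_positive
    # comb returns 0 when k exceeds n_negative, so the result is then exactly 1.
    return 1.0 - comb(n_negative, k) / comb(n_total, k)
```

This removes the sampling noise of the Monte Carlo estimate and makes runs comparable without fixing a random seed.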
code/scripts/rescore.py
ADDED
|
@@ -0,0 +1,178 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Re-score binder PDB designs with a Q_theta checkpoint.
|
| 4 |
+
|
| 5 |
+
Walks a fixed set of per-method design directories, aligns every design PDB onto the holo reference,
|
| 6 |
+
runs each through DifferentiableQTheta, and writes per-design
|
| 7 |
+
S = Q_theta(holo) - Q_theta(apo) plus the raw holo/apo scores to JSON.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python code/scripts/rescore.py \\
|
| 11 |
+
--checkpoint checkpoints/Q_theta_phase2.pt \\
|
| 12 |
+
--gpu 0
|
| 13 |
+
"""
|
| 14 |
+
import os, sys, json, argparse, glob, logging
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
BASE = str(Path(__file__).resolve().parent.parent.parent)
|
| 23 |
+
sys.path.insert(0, os.path.join(BASE, 'code'))
|
| 24 |
+
sys.path.insert(0, BASE)
|
| 25 |
+
|
| 26 |
+
from models.differentiable_features import DifferentiableQTheta
|
| 27 |
+
from utils.pdb_utils import load_structure, get_residues, get_backbone_coords, get_aa_indices, align_structures
|
| 28 |
+
|
| 29 |
+
HOLO_PDB = os.path.join(BASE, 'data/pdbs/cam_holo/3CLN.pdb')
|
| 30 |
+
APO_PDB = os.path.join(BASE, 'data/pdbs/cam_apo/1CFD.pdb')
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def score_pdb_list(dq, pdb_list, ref_resnums, ref_coords, device):
|
| 34 |
+
"""Score a list of design PDB files."""
|
| 35 |
+
results = []
|
| 36 |
+
for pdb_path in pdb_list:
|
| 37 |
+
name = os.path.basename(pdb_path).replace(".pdb", "")
|
| 38 |
+
try:
|
| 39 |
+
design_model = load_structure(pdb_path)
|
| 40 |
+
chains = [c.id for c in design_model.get_chains()]
|
| 41 |
+
rec_chain = 'A' if 'A' in chains else chains[0]
|
| 42 |
+
binder_chain = 'B' if 'B' in chains else [c for c in chains if c != rec_chain][0]
|
| 43 |
+
|
| 44 |
+
rec_res = get_residues(design_model[rec_chain])
|
| 45 |
+
binder_res = get_residues(design_model[binder_chain])
|
| 46 |
+
rec_coords_d, _ = get_backbone_coords(rec_res)
|
| 47 |
+
binder_coords, binder_mask = get_backbone_coords(binder_res)
|
| 48 |
+
binder_aa_idx = get_aa_indices(binder_res)
|
| 49 |
+
|
| 50 |
+
design_resnums = {r.get_id()[1]: i for i, r in enumerate(rec_res)}
|
| 51 |
+
common = sorted(set(design_resnums.keys()) & set(ref_resnums.keys()))
|
| 52 |
+
if len(common) < 10:
|
| 53 |
+
logger.warning(f" Skip {name}: <10 common residues")
|
| 54 |
+
continue
|
| 55 |
+
|
| 56 |
+
d_ca = rec_coords_d[[design_resnums[r] for r in common], 1]
|
| 57 |
+
r_ca = ref_coords[[ref_resnums[r] for r in common], 1]
|
| 58 |
+
mobile_center = d_ca.mean(0)
|
| 59 |
+
ref_center = r_ca.mean(0)
|
| 60 |
+
_, R = align_structures(d_ca, r_ca)
|
| 61 |
+
|
| 62 |
+
flat = binder_coords.reshape(-1, 3) - mobile_center
|
| 63 |
+
aligned_binder = (flat @ R.T + ref_center).reshape(-1, 4, 3)
|
| 64 |
+
|
| 65 |
+
coords_t = torch.from_numpy(aligned_binder).float().to(device)
|
| 66 |
+
mask_t = torch.from_numpy(binder_mask).bool().to(device)
|
| 67 |
+
aa_t = torch.from_numpy(binder_aa_idx).long().to(device)
|
| 68 |
+
|
| 69 |
+
with torch.no_grad():
|
| 70 |
+
q_holo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
|
| 71 |
+
receptor_label='holo').item()
|
| 72 |
+
q_apo = dq.score(coords_t, mask_t, binder_aa_idx=aa_t,
|
| 73 |
+
receptor_label='apo').item()
|
| 74 |
+
S = q_holo - q_apo
|
| 75 |
+
results.append({"design": name, "Q_holo": q_holo, "Q_apo": q_apo, "S": S})
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.warning(f" Skip {name}: {e}")
|
| 78 |
+
return results
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def summarize(results, label):
|
| 82 |
+
if not results:
|
| 83 |
+
return {}
|
| 84 |
+
S = [r["S"] for r in results]
|
| 85 |
+
return {
|
| 86 |
+
"method": label, "n": len(S),
|
| 87 |
+
"S_mean": float(np.mean(S)), "S_std": float(np.std(S)),
|
| 88 |
+
"S_pos_pct": float(np.mean([s > 0 for s in S]) * 100),
|
| 89 |
+
"Q_holo_mean": float(np.mean([r["Q_holo"] for r in results])),
|
| 90 |
+
"Q_apo_mean": float(np.mean([r["Q_apo"] for r in results])),
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def main():
|
| 95 |
+
parser = argparse.ArgumentParser()
|
| 96 |
+
parser.add_argument("--gpu", type=int, default=0)
|
| 97 |
+
parser.add_argument("--checkpoint", default="checkpoints/Q_theta_phase2.pt")
|
| 98 |
+
args = parser.parse_args()
|
| 99 |
+
|
| 100 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
|
| 101 |
+
device = "cuda:0"
|
| 102 |
+
|
| 103 |
+
logger.info(f"Loading Q_theta from {args.checkpoint}")
|
| 104 |
+
dq = DifferentiableQTheta(checkpoint_path=args.checkpoint, device=device,
|
| 105 |
+
esm_dir=os.path.join(BASE, "data/esm2_embeddings"))
|
| 106 |
+
dq.load_receptor(HOLO_PDB, chain='A', label='holo', esm_target='cam')
|
| 107 |
+
dq.load_receptor(APO_PDB, chain='A', label='apo', esm_target='cam')
|
| 108 |
+
|
| 109 |
+
ref_model = load_structure(HOLO_PDB)
|
| 110 |
+
ref_res = get_residues(ref_model['A'])
|
| 111 |
+
ref_coords, _ = get_backbone_coords(ref_res)
|
| 112 |
+
ref_resnums = {r.get_id()[1]: i for i, r in enumerate(ref_res)}
|
| 113 |
+
|
| 114 |
+
output_dir = os.path.join(BASE, "results/v2_strict_holdout/scoring")
|
| 115 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 116 |
+
|
| 117 |
+
# Define design directories
|
| 118 |
+
design_sets = {
|
| 119 |
+
"vanilla": os.path.join(BASE, "results/independent_validation/vanilla/holo_pdbs"),
|
| 120 |
+
"langevin": os.path.join(BASE, "results/langevin_refinement/refined_pdbs"),
|
| 121 |
+
"classifier": os.path.join(BASE, "results/guided_diffusion/guided"),
|
| 122 |
+
"smc_r3": os.path.join(BASE, "results/smc_guidance/cam/round_3"),
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# Also check for TDS and PXDesign
|
| 126 |
+
tds_dirs = glob.glob(os.path.join(BASE, "results/tds_guidance/cam/designs"))
|
| 127 |
+
if tds_dirs:
|
| 128 |
+
design_sets["tds"] = tds_dirs[0]
|
| 129 |
+
|
| 130 |
+
# PXDesign directories
|
| 131 |
+
for px_method in ["pxdesign_scoring", "pxdesign_classifier", "pxdesign_tds",
|
| 132 |
+
"pxdesign_smc", "pxdesign_langevin"]:
|
| 133 |
+
px_dir = os.path.join(BASE, f"results_familysplit/design_bd30/{px_method}")
|
| 134 |
+
if not os.path.exists(px_dir):
|
| 135 |
+
px_dir = os.path.join(BASE, f"results/{px_method}")
|
| 136 |
+
if os.path.exists(px_dir):
|
| 137 |
+
pdbs = glob.glob(os.path.join(px_dir, "*.pdb"))
|
| 138 |
+
if pdbs:
|
| 139 |
+
design_sets[px_method] = px_dir
|
| 140 |
+
|
| 141 |
+
all_results = {}
|
| 142 |
+
summaries = []
|
| 143 |
+
|
| 144 |
+
for method, pdb_dir in design_sets.items():
|
| 145 |
+
if not os.path.exists(pdb_dir):
|
| 146 |
+
logger.warning(f" {method}: directory not found ({pdb_dir})")
|
| 147 |
+
continue
|
| 148 |
+
pdbs = sorted(glob.glob(os.path.join(pdb_dir, "*.pdb")))
|
| 149 |
+
if not pdbs:
|
| 150 |
+
logger.warning(f" {method}: no PDB files")
|
| 151 |
+
continue
|
| 152 |
+
|
| 153 |
+
logger.info(f"\n=== {method} ({len(pdbs)} designs) ===")
|
| 154 |
+
results = score_pdb_list(dq, pdbs, ref_resnums, ref_coords, device)
|
| 155 |
+
s = summarize(results, method)
|
| 156 |
+
if s:
|
| 157 |
+
summaries.append(s)
|
| 158 |
+
logger.info(f" {method}: n={s['n']}, S̄={s['S_mean']:.3f}±{s['S_std']:.3f}, "
|
| 159 |
+
f"S>0={s['S_pos_pct']:.0f}%, Q+={s['Q_holo_mean']:.3f}, Q-={s['Q_apo_mean']:.3f}")
|
| 160 |
+
all_results[method] = {"results": results, "summary": s}
|
| 161 |
+
|
| 162 |
+
# Save
|
| 163 |
+
with open(os.path.join(output_dir, "rescore_v2_all.json"), "w") as f:
|
| 164 |
+
json.dump(all_results, f, indent=2)
|
| 165 |
+
|
| 166 |
+
# Print summary table
|
| 167 |
+
print("\n" + "=" * 70)
|
| 168 |
+
print("V2 RESCORING SUMMARY (strict holdout, CaM OOD)")
|
| 169 |
+
print("=" * 70)
|
| 170 |
+
print(f"{'Method':20s} {'n':>4s} {'S̄':>8s} {'±σ':>6s} {'S>0%':>6s} {'Q+':>6s} {'Q-':>6s}")
|
| 171 |
+
print("-" * 70)
|
| 172 |
+
for s in sorted(summaries, key=lambda x: x['S_mean'], reverse=True):
|
| 173 |
+
print(f"{s['method']:20s} {s['n']:4d} {s['S_mean']:8.3f} {s['S_std']:6.3f} "
|
| 174 |
+
f"{s['S_pos_pct']:5.1f}% {s['Q_holo_mean']:6.3f} {s['Q_apo_mean']:6.3f}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
main()
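The per-method summary printed by `rescore.py` reduces each design to the margin S = Q_holo - Q_apo and then aggregates. The following standard-library sketch reproduces that aggregation on plain `(q_holo, q_apo)` pairs so the reported columns are easy to check by hand. The function name `summarize_margins` and the input layout are assumptions made for illustration; the population standard deviation is used because `np.std` defaults to it.

```python
from statistics import fmean, pstdev


def summarize_margins(scores):
    """Aggregate (q_holo, q_apo) pairs into mean, std, and percent S > 0."""
    if not scores:
        return {}
    margins = [q_holo - q_apo for q_holo, q_apo in scores]
    return {
        "n": len(margins),
        "S_mean": fmean(margins),
        "S_std": pstdev(margins),  # population std, like np.std's default
        "S_pos_pct": 100.0 * sum(m > 0 for m in margins) / len(margins),
        "Q_holo_mean": fmean(q for q, _ in scores),
        "Q_apo_mean": fmean(q for _, q in scores),
    }
```

For example, the pairs `(0.8, 0.2)` and `(0.4, 0.6)` give margins of about 0.6 and -0.2, so half of the designs count as holo-selective.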
|
code/trainers/__init__.py
ADDED
|
File without changes
|
code/trainers/trainer.py
ADDED
|
@@ -0,0 +1,674 @@
|
| 1 |
+
"""
|
| 2 |
+
Trainer for the Q_theta state-selectivity scorer.
|
| 3 |
+
|
| 4 |
+
Implements two-phase training:
|
| 5 |
+
Phase 1: DockQ regression (learn complex quality from all data)
|
| 6 |
+
Phase 2: Selectivity fine-tuning (learn to rank X+ > X- for the same binder)
|
| 7 |
+
|
| 8 |
+
Integrates with Weights & Biases for experiment tracking.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import time
|
| 13 |
+
import logging
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.optim import AdamW
|
| 18 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
|
| 19 |
+
from scipy.stats import spearmanr
|
| 20 |
+
from sklearn.metrics import roc_auc_score
|
| 21 |
+
|
| 22 |
+
import wandb
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AverageMeter:
|
| 28 |
+
def __init__(self):
|
| 29 |
+
self.reset()
|
| 30 |
+
|
| 31 |
+
def reset(self):
|
| 32 |
+
self.val = 0.0
|
| 33 |
+
self.avg = 0.0
|
| 34 |
+
self.sum = 0.0
|
| 35 |
+
self.count = 0
|
| 36 |
+
|
| 37 |
+
def update(self, val, n=1):
|
| 38 |
+
self.val = val
|
| 39 |
+
self.sum += val * n
|
| 40 |
+
self.count += n
|
| 41 |
+
self.avg = self.sum / self.count
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class AlloDesignerTrainer:
|
| 45 |
+
"""
|
| 46 |
+
Two-phase trainer for Q_theta.
|
| 47 |
+
|
| 48 |
+
Phase 1 (DockQ regression):
|
| 49 |
+
- Minimizes MSE(Q_theta(X, Y), DockQ_label) on all complex types
|
| 50 |
+
- Learns general complex quality
|
| 51 |
+
|
| 52 |
+
Phase 2 (Selectivity fine-tuning):
|
| 53 |
+
- Minimizes selectivity margin loss on paired (pos, neg) data
|
| 54 |
+
- Learns to rank Q(X+, Y) > Q(X-, Y)
|
| 55 |
+
- Combined: L = L_regression + lambda_rank * L_selectivity
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(self, model, config, device='cuda'):
|
| 59 |
+
self.model = model.to(device)
|
| 60 |
+
self.config = config
|
| 61 |
+
self.device = device
|
| 62 |
+
self.use_sam = config.get('optimizer', 'adamw') == 'sam'
|
| 63 |
+
|
| 64 |
+
# Optimizer
|
| 65 |
+
if self.use_sam:
|
| 66 |
+
from utils.sam import SAM
|
| 67 |
+
self.optimizer = SAM(
|
| 68 |
+
model.parameters(),
|
| 69 |
+
base_optimizer=AdamW,
|
| 70 |
+
rho=0.05,
|
| 71 |
+
lr=config.get('lr', 1e-4),
|
| 72 |
+
weight_decay=config.get('weight_decay', 1e-4),
|
| 73 |
+
betas=(0.9, 0.999),
|
| 74 |
+
)
|
| 75 |
+
# SAM wraps AdamW; scheduler goes on base_optimizer
|
| 76 |
+
sched_optimizer = self.optimizer.base_optimizer
|
| 77 |
+
else:
|
| 78 |
+
self.optimizer = AdamW(
|
| 79 |
+
model.parameters(),
|
| 80 |
+
lr=config.get('lr', 1e-4),
|
| 81 |
+
weight_decay=config.get('weight_decay', 1e-4),
|
| 82 |
+
betas=(0.9, 0.999),
|
| 83 |
+
)
|
| 84 |
+
sched_optimizer = self.optimizer
|
| 85 |
+
|
| 86 |
+
# Learning rate scheduler (warmup + cosine)
|
| 87 |
+
n_warmup = config.get('warmup_steps', 100)
|
| 88 |
+
n_total = config.get('max_steps', 5000)
|
| 89 |
+
|
| 90 |
+
warmup_sched = LinearLR(sched_optimizer, start_factor=0.01, end_factor=1.0, total_iters=n_warmup)
|
| 91 |
+
cosine_sched = CosineAnnealingLR(sched_optimizer, T_max=n_total - n_warmup, eta_min=1e-6)
|
| 92 |
+
self.scheduler = SequentialLR(sched_optimizer, [warmup_sched, cosine_sched], milestones=[n_warmup])
|
| 93 |
+
|
| 94 |
+
self.global_step = 0
|
| 95 |
+
self.best_val_metric = -float('inf')
|
| 96 |
+
self.checkpoint_dir = config.get('checkpoint_dir', 'results/checkpoints')
|
| 97 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 98 |
+
|
| 99 |
+
# ------------------------------------------------------------------ #
|
| 100 |
+
# Phase 1: DockQ regression
|
| 101 |
+
# ------------------------------------------------------------------ #
|
| 102 |
+
|
| 103 |
+
def train_step_phase1(self, batch):
|
| 104 |
+
"""Single training step for Phase 1 (DockQ regression)."""
|
| 105 |
+
self.model.train()
|
| 106 |
+
node_feats = batch['node_feats'].to(self.device) # [B, N, node_dim]
|
| 107 |
+
edge_feats = batch['edge_feats'].to(self.device) # [B, N, N, edge_dim]
|
| 108 |
+
node_mask = batch['node_mask'].to(self.device) # [B, N]
|
| 109 |
+
labels = batch['label'].to(self.device) # [B]
|
| 110 |
+
esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
|
| 111 |
+
|
| 112 |
+
self.optimizer.zero_grad()
|
| 113 |
+
|
| 114 |
+
scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats) # [B]
|
| 115 |
+
loss = nn.functional.mse_loss(scores, labels)
|
| 116 |
+
|
| 117 |
+
loss.backward()
|
| 118 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 119 |
+
|
| 120 |
+
if self.use_sam:
|
| 121 |
+
self.optimizer.first_step()
|
| 122 |
+
# Second forward-backward pass
|
| 123 |
+
scores2 = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)
|
| 124 |
+
loss2 = nn.functional.mse_loss(scores2, labels)
|
| 125 |
+
self.optimizer.zero_grad()
|
| 126 |
+
loss2.backward()
|
| 127 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 128 |
+
self.optimizer.second_step()
|
| 129 |
+
else:
|
| 130 |
+
self.optimizer.step()
|
| 131 |
+
|
| 132 |
+
self.scheduler.step()
|
| 133 |
+
self.global_step += 1
|
| 134 |
+
|
| 135 |
+
return {'loss': loss.item(), 'scores': scores.detach(), 'labels': labels}
|
| 136 |
+
|
| 137 |
+
def run_phase1(self, train_loader, val_loader, n_epochs: int = 30, run_name: str = 'phase1'):
|
| 138 |
+
"""Phase 1 training loop."""
|
| 139 |
+
logger.info(f"Starting Phase 1 (DockQ regression) for {n_epochs} epochs")
|
| 140 |
+
wandb.define_metric('phase1/step')
|
| 141 |
+
wandb.define_metric('phase1/*', step_metric='phase1/step')
|
| 142 |
+
|
| 143 |
+
for epoch in range(n_epochs):
|
| 144 |
+
# Train
|
| 145 |
+
train_meter = AverageMeter()
|
| 146 |
+
all_scores, all_labels = [], []
|
| 147 |
+
|
| 148 |
+
for batch in train_loader:
|
| 149 |
+
result = self.train_step_phase1(batch)
|
| 150 |
+
train_meter.update(result['loss'], n=len(result['scores']))
|
| 151 |
+
all_scores.append(result['scores'].cpu().numpy())
|
| 152 |
+
all_labels.append(result['labels'].cpu().numpy())
|
| 153 |
+
|
| 154 |
+
if self.global_step % 50 == 0:
|
| 155 |
+
wandb.log({
|
| 156 |
+
'phase1/train_loss': result['loss'],
|
| 157 |
+
'phase1/lr': self.optimizer.param_groups[0]['lr'],
|
| 158 |
+
'phase1/step': self.global_step,
|
| 159 |
+
})
|
| 160 |
+
|
| 161 |
+
# Compute Spearman corr on training data
|
| 162 |
+
all_scores = np.concatenate(all_scores)
|
| 163 |
+
all_labels = np.concatenate(all_labels)
|
| 164 |
+
train_spearman = spearmanr(all_scores, all_labels).correlation
|
| 165 |
+
|
| 166 |
+
# Validate
|
| 167 |
+
val_metrics = self.evaluate_phase1(val_loader)
|
| 168 |
+
|
| 169 |
+
logger.info(
|
| 170 |
+
f"Phase1 Epoch {epoch+1}/{n_epochs} | "
|
| 171 |
+
f"Train Loss: {train_meter.avg:.4f} | "
|
| 172 |
+
f"Train Spearman: {train_spearman:.3f} | "
|
| 173 |
+
f"Val Loss: {val_metrics['val_loss']:.4f} | "
|
| 174 |
+
f"Val Spearman: {val_metrics['val_spearman']:.3f} | "
|
| 175 |
+
f"Val AUC: {val_metrics.get('val_auc', 0):.3f}"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
wandb.log({
|
| 179 |
+
'phase1/epoch': epoch + 1,
|
| 180 |
+
'phase1/train_loss_epoch': train_meter.avg,
|
| 181 |
+
'phase1/train_spearman': train_spearman,
|
| 182 |
+
**{f'phase1/{k}': v for k, v in val_metrics.items()},
|
| 183 |
+
})
|
| 184 |
+
|
| 185 |
+
# Checkpoint best model
|
| 186 |
+
if val_metrics['val_spearman'] > self.best_val_metric:
|
| 187 |
+
self.best_val_metric = val_metrics['val_spearman']
|
| 188 |
+
self.save_checkpoint('best_phase1.pt', extra={'epoch': epoch, 'phase': 1})
|
| 189 |
+
logger.info(f" -> New best Phase 1 model (val_spearman={self.best_val_metric:.3f})")
|
| 190 |
+
|
| 191 |
+
logger.info("Phase 1 training complete.")
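The loop above weights each batch loss by its batch size through `train_meter.update(result['loss'], n=...)`. `AverageMeter` is imported from elsewhere in the repository and is not shown here, so the following is only a guess at its behaviour: a running weighted mean exposing `update` and `avg`.

```python
class AverageMeter:
    """Running weighted mean (assumed behaviour; the real class is defined elsewhere)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # Weight each value by the number of samples it summarizes.
        self.sum += float(value) * n
        self.count += n

    @property
    def avg(self):
        # Guard against division by zero before the first update.
        return self.sum / self.count if self.count else 0.0


meter = AverageMeter()
meter.update(1.0, n=2)  # batch of 2 samples with mean loss 1.0
meter.update(4.0, n=1)  # batch of 1 sample with loss 4.0
epoch_loss = meter.avg
```

Weighting by `n` keeps the epoch average correct when the last batch is smaller than the others.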
|
| 192 |
+
|
| 193 |
+
@torch.no_grad()
|
| 194 |
+
def evaluate_phase1(self, loader):
|
| 195 |
+
"""Evaluate Phase 1 model on val/test set."""
|
| 196 |
+
self.model.eval()
|
| 197 |
+
all_scores, all_labels = [], []
|
| 198 |
+
total_loss = 0.0
|
| 199 |
+
n_batches = 0
|
| 200 |
+
|
| 201 |
+
for batch in loader:
|
| 202 |
+
node_feats = batch['node_feats'].to(self.device)
|
| 203 |
+
edge_feats = batch['edge_feats'].to(self.device)
|
| 204 |
+
node_mask = batch['node_mask'].to(self.device)
|
| 205 |
+
labels = batch['label'].to(self.device)
|
| 206 |
+
esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
|
| 207 |
+
|
| 208 |
+
scores = self.model(node_feats, edge_feats, node_mask, esm_feats=esm_feats)
|
| 209 |
+
loss = nn.functional.mse_loss(scores, labels)
|
| 210 |
+
|
| 211 |
+
total_loss += loss.item()
|
| 212 |
+
n_batches += 1
|
| 213 |
+
all_scores.append(scores.cpu().numpy())
|
| 214 |
+
all_labels.append(labels.cpu().numpy())
|
| 215 |
+
|
| 216 |
+
all_scores = np.concatenate(all_scores)
|
| 217 |
+
all_labels = np.concatenate(all_labels)
|
| 218 |
+
|
| 219 |
+
spearman = spearmanr(all_scores, all_labels).correlation
|
| 220 |
+
if np.isnan(spearman):
|
| 221 |
+
spearman = 0.0
|
| 222 |
+
|
| 223 |
+
metrics = {
|
| 224 |
+
'val_loss': total_loss / max(n_batches, 1),
|
| 225 |
+
'val_spearman': float(spearman),
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
# AUC for binary quality (label > 0.5 = positive)
|
| 229 |
+
binary_labels = (all_labels > 0.5).astype(int)
|
| 230 |
+
if binary_labels.sum() > 0 and binary_labels.sum() < len(binary_labels):
|
| 231 |
+
try:
|
| 232 |
+
metrics['val_auc'] = roc_auc_score(binary_labels, all_scores)
|
| 233 |
+
except Exception:
|
| 234 |
+
pass
|
| 235 |
+
|
| 236 |
+
return metrics
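`evaluate_phase1` binarizes the DockQ-style labels at 0.5 and only reports AUC when both classes are present. The quantity itself can be sketched without scikit-learn as the fraction of (positive, negative) pairs that the scores order correctly, with ties counted as half. `binary_auc` is an illustrative helper, not part of the repository:

```python
def binary_auc(scores, labels, threshold=0.5):
    """Pairwise (Mann-Whitney) AUC; returns None when only one class is present."""
    positives = [s for s, y in zip(scores, labels) if y > threshold]
    negatives = [s for s, y in zip(scores, labels) if y <= threshold]
    if not positives or not negatives:
        return None  # AUC is undefined for a single class
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


auc = binary_auc([0.1, 0.4, 0.35, 0.8], [0.2, 0.3, 0.7, 0.9])
single_class = binary_auc([0.1, 0.4], [0.9, 0.8])
```

The quadratic loop is fine for a validation set of a few thousand samples; a rank-based formula would be preferable for larger sets.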
|
| 237 |
+
|
| 238 |
+
# ------------------------------------------------------------------ #
|
| 239 |
+
# Phase 2: Selectivity fine-tuning
|
| 240 |
+
# ------------------------------------------------------------------ #
|
| 241 |
+
|
| 242 |
+
def train_step_phase2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
|
| 243 |
+
lambda_ddg: float = 0.1):
|
| 244 |
+
"""Single training step for Phase 2 (selectivity margin + ddG auxiliary)."""
|
| 245 |
+
self.model.train()
|
| 246 |
+
|
| 247 |
+
pos = batch['pos']
|
| 248 |
+
neg = batch['neg']
|
| 249 |
+
|
| 250 |
+
pos_node = pos['node_feats'].to(self.device)
|
| 251 |
+
pos_edge = pos['edge_feats'].to(self.device)
|
| 252 |
+
pos_mask = pos['node_mask'].to(self.device)
|
| 253 |
+
pos_label = pos['label'].to(self.device)
|
| 254 |
+
pos_ce = pos.get('contact_energy', None)
|
| 255 |
+
if pos_ce is not None:
|
| 256 |
+
pos_ce = pos_ce.to(self.device)
|
| 257 |
+
|
| 258 |
+
neg_node = neg['node_feats'].to(self.device)
|
| 259 |
+
neg_edge = neg['edge_feats'].to(self.device)
|
| 260 |
+
neg_mask = neg['node_mask'].to(self.device)
|
| 261 |
+
pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
|
| 262 |
+
neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
|
| 263 |
+
|
| 264 |
+
self.optimizer.zero_grad()
|
| 265 |
+
|
| 266 |
+
pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm) # [B]
|
| 267 |
+
neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm) # [B]
|
| 268 |
+
|
| 269 |
+
# Regression loss on positive examples
|
| 270 |
+
loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
|
| 271 |
+
|
| 272 |
+
# Selectivity margin loss: pos_score - neg_score > margin
|
| 273 |
+
loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
|
| 274 |
+
|
| 275 |
+
# InfoNCE-style selectivity loss
|
| 276 |
+
eps = 1e-6
|
| 277 |
+
pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
|
| 278 |
+
neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
|
| 279 |
+
log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
|
| 280 |
+
infonce_loss = -(pos_logit - log_denom).mean()
|
| 281 |
+
|
| 282 |
+
# ddG auxiliary loss: MSE against contact-energy proxy (physics-informed soft label)
|
| 283 |
+
loss_ddg = torch.tensor(0.0, device=self.device)
|
| 284 |
+
if pos_ce is not None and pos_ce.shape[0] > 0:
|
| 285 |
+
# pos_ce is a contact-energy-based ddG proxy in [0, 1]
|
| 286 |
+
# Align positive score toward the contact energy signal
|
| 287 |
+
loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)
|
| 288 |
+
|
| 289 |
+
loss = loss_reg + lambda_rank * (loss_margin + infonce_loss) + lambda_ddg * loss_ddg
|
| 290 |
+
|
| 291 |
+
loss.backward()
|
| 292 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 293 |
+
|
| 294 |
+
if self.use_sam:
|
| 295 |
+
self.optimizer.first_step()
|
| 296 |
+
# Second forward-backward for SAM
|
| 297 |
+
pos_scores2 = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm)
|
| 298 |
+
neg_scores2 = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm)
|
| 299 |
+
loss_reg2 = nn.functional.mse_loss(pos_scores2, pos_label)
|
| 300 |
+
loss_margin2 = nn.functional.relu(margin - (pos_scores2 - neg_scores2)).mean()
|
| 301 |
+
eps2 = 1e-6
|
| 302 |
+
pl2 = torch.log(pos_scores2.clamp(eps2, 1-eps2) / (1-pos_scores2).clamp(eps2))
|
| 303 |
+
nl2 = torch.log(neg_scores2.clamp(eps2, 1-eps2) / (1-neg_scores2).clamp(eps2))
|
| 304 |
+
ld2 = torch.stack([pl2, nl2], dim=-1).logsumexp(dim=-1)
|
| 305 |
+
infonce2 = -(pl2 - ld2).mean()
|
| 306 |
+
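loss2 = loss_reg2 + lambda_rank * (loss_margin2 + infonce2) + lambda_ddg * (nn.functional.mse_loss(pos_scores2, pos_ce) if pos_ce is not None and pos_ce.shape[0] > 0 else 0.0)  # keep the ddG term on the SAM pass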
|
| 307 |
+
self.optimizer.zero_grad()
|
| 308 |
+
loss2.backward()
|
| 309 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 310 |
+
self.optimizer.second_step()
|
| 311 |
+
else:
|
| 312 |
+
self.optimizer.step()
|
| 313 |
+
|
| 314 |
+
self.scheduler.step()
|
| 315 |
+
self.global_step += 1
|
| 316 |
+
|
| 317 |
+
selectivity_gap = (pos_scores - neg_scores).mean().item()
|
| 318 |
+
|
| 319 |
+
return {
|
| 320 |
+
'loss': loss.item(),
|
| 321 |
+
'loss_reg': loss_reg.item(),
|
| 322 |
+
'loss_margin': loss_margin.item(),
|
| 323 |
+
'loss_infonce': infonce_loss.item(),
|
| 324 |
+
'loss_ddg': loss_ddg.item(),
|
| 325 |
+
'selectivity_gap': selectivity_gap,
|
| 326 |
+
'pos_scores': pos_scores.detach(),
|
| 327 |
+
'neg_scores': neg_scores.detach(),
|
| 328 |
+
}
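The two ranking terms in `train_step_phase2` can be checked by hand on a single (pos, neg) pair. With only two candidates, the InfoNCE term reduces to a softplus of the logit difference, `log(1 + exp(neg_logit - pos_logit))`. The sketch below uses plain floats instead of tensors, and the helper names are illustrative:

```python
import math


def logit(p, eps=1e-6):
    """Same clamping as the training step: log(clamp(p) / clamp(1 - p))."""
    num = min(max(p, eps), 1.0 - eps)
    den = max(1.0 - p, eps)
    return math.log(num / den)


def margin_loss(pos, neg, margin=0.2):
    """Hinge on the score gap: zero once pos exceeds neg by at least `margin`."""
    return max(0.0, margin - (pos - neg))


def pair_infonce(pos, neg):
    """-(pos_logit - logsumexp([pos_logit, neg_logit])) for one pair."""
    pos_logit, neg_logit = logit(pos), logit(neg)
    log_denom = math.log(math.exp(pos_logit) + math.exp(neg_logit))
    return -(pos_logit - log_denom)


def pair_infonce_softplus(pos, neg):
    """Equivalent closed form for two candidates."""
    return math.log1p(math.exp(logit(neg) - logit(pos)))


good = (0.8, 0.3)  # positive scored well above the negative
bad = (0.3, 0.8)   # ordering violated
```

The hinge stops contributing once the gap reaches `margin`, while the InfoNCE term keeps a small gradient for every pair, so the two terms are complementary rather than redundant.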
|
| 329 |
+
|
| 330 |
+
def train_step_phase2_v2(self, batch, lambda_rank: float = 1.0, margin: float = 0.2,
|
| 331 |
+
lambda_ddg: float = 0.0, lambda_path: float = 0.5):
|
| 332 |
+
"""Phase 2 training step with multi-negative + path monotonicity."""
|
| 333 |
+
self.model.train()
|
| 334 |
+
|
| 335 |
+
pos = batch['pos']
|
| 336 |
+
neg = batch['neg']
|
| 337 |
+
|
| 338 |
+
pos_node = pos['node_feats'].to(self.device)
|
| 339 |
+
pos_edge = pos['edge_feats'].to(self.device)
|
| 340 |
+
pos_mask = pos['node_mask'].to(self.device)
|
| 341 |
+
pos_label = pos['label'].to(self.device)
|
| 342 |
+
pos_ce = pos.get('contact_energy', None)
|
| 343 |
+
if pos_ce is not None:
|
| 344 |
+
pos_ce = pos_ce.to(self.device)
|
| 345 |
+
|
| 346 |
+
neg_node = neg['node_feats'].to(self.device)
|
| 347 |
+
neg_edge = neg['edge_feats'].to(self.device)
|
| 348 |
+
neg_mask = neg['node_mask'].to(self.device)
|
| 349 |
+
pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
|
| 350 |
+
neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
|
| 351 |
+
|
| 352 |
+
self.optimizer.zero_grad()
|
| 353 |
+
|
| 354 |
+
pos_scores = self.model(pos_node, pos_edge, pos_mask, esm_feats=pos_esm)
|
| 355 |
+
neg_scores = self.model(neg_node, neg_edge, neg_mask, esm_feats=neg_esm)
|
| 356 |
+
|
| 357 |
+
# Score path frames if present
|
| 358 |
+
path_scores = []
|
| 359 |
+
path_taus = batch.get('path_taus', [])
|
| 360 |
+
for path_frame in batch.get('path', []):
|
| 361 |
+
p_node = path_frame['node_feats'].to(self.device)
|
| 362 |
+
p_edge = path_frame['edge_feats'].to(self.device)
|
| 363 |
+
p_mask = path_frame['node_mask'].to(self.device)
|
| 364 |
+
p_score = self.model(p_node, p_edge, p_mask, esm_feats=path_frame['esm_feats'].to(self.device) if 'esm_feats' in path_frame else None)
|
| 365 |
+
path_scores.append(p_score)
|
| 366 |
+
|
| 367 |
+
# Regression loss on positive examples
|
| 368 |
+
loss_reg = nn.functional.mse_loss(pos_scores, pos_label)
|
| 369 |
+
|
| 370 |
+
# Selectivity margin loss
|
| 371 |
+
loss_margin = nn.functional.relu(margin - (pos_scores - neg_scores)).mean()
|
| 372 |
+
|
| 373 |
+
# InfoNCE-style selectivity loss
|
| 374 |
+
eps = 1e-6
|
| 375 |
+
pos_logit = torch.log(pos_scores.clamp(eps, 1 - eps) / (1 - pos_scores).clamp(eps))
|
| 376 |
+
neg_logit = torch.log(neg_scores.clamp(eps, 1 - eps) / (1 - neg_scores).clamp(eps))
|
| 377 |
+
log_denom = torch.stack([pos_logit, neg_logit], dim=-1).logsumexp(dim=-1)
|
| 378 |
+
infonce_loss = -(pos_logit - log_denom).mean()
|
| 379 |
+
|
| 380 |
+
# ddG auxiliary loss
|
| 381 |
+
loss_ddg = torch.tensor(0.0, device=self.device)
|
| 382 |
+
if pos_ce is not None and pos_ce.shape[0] > 0 and lambda_ddg > 0:
|
| 383 |
+
loss_ddg = nn.functional.mse_loss(pos_scores, pos_ce)
|
| 384 |
+
|
| 385 |
+
# Path monotonicity loss
|
| 386 |
+
loss_path = torch.tensor(0.0, device=self.device)
|
| 387 |
+
if path_scores and lambda_path > 0:
|
| 388 |
+
small_margin = 0.05
|
| 389 |
+
for i in range(len(path_scores) - 1):
|
| 390 |
+
loss_path = loss_path + nn.functional.relu(
|
| 391 |
+
path_scores[i] - path_scores[i + 1] + small_margin
|
| 392 |
+
).mean()
|
| 393 |
+
# Last path frame < positive score
|
| 394 |
+
loss_path = loss_path + nn.functional.relu(
|
| 395 |
+
path_scores[-1] - pos_scores + margin
|
| 396 |
+
).mean()
|
| 397 |
+
# First path frame > negative score
|
| 398 |
+
loss_path = loss_path + nn.functional.relu(
|
| 399 |
+
neg_scores - path_scores[0] + small_margin
|
| 400 |
+
).mean()
|
| 401 |
+
|
| 402 |
+
loss = (loss_reg + lambda_rank * (loss_margin + infonce_loss)
|
| 403 |
+
+ lambda_ddg * loss_ddg + lambda_path * loss_path)
|
| 404 |
+
|
| 405 |
+
loss.backward()
|
| 406 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
| 407 |
+
self.optimizer.step()
|
| 408 |
+
self.scheduler.step()
|
| 409 |
+
self.global_step += 1
|
| 410 |
+
|
| 411 |
+
selectivity_gap = (pos_scores - neg_scores).mean().item()
|
| 412 |
+
|
| 413 |
+
return {
|
| 414 |
+
'loss': loss.item(),
|
| 415 |
+
'loss_reg': loss_reg.item(),
|
| 416 |
+
'loss_margin': loss_margin.item(),
|
| 417 |
+
'loss_infonce': infonce_loss.item(),
|
| 418 |
+
'loss_ddg': loss_ddg.item(),
|
| 419 |
+
'loss_path': loss_path.item(),
|
| 420 |
+
'selectivity_gap': selectivity_gap,
|
| 421 |
+
'pos_scores': pos_scores.detach(),
|
| 422 |
+
'neg_scores': neg_scores.detach(),
|
| 423 |
+
}
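The path term in `train_step_phase2_v2` asks for a chain of inequalities: negative < first frame < ... < last frame < positive, with `small_margin` between consecutive frames and `margin` before the positive. A scalar sketch of the same hinge chain for one sample (the function name and example scores are illustrative):

```python
def relu(x):
    return max(0.0, x)


def path_monotonicity_loss(path_scores, pos_score, neg_score,
                           margin=0.2, small_margin=0.05):
    """Sum of hinge penalties for every violated ordering along the path."""
    if not path_scores:
        return 0.0
    loss = 0.0
    # Each frame should score at least `small_margin` below the next one.
    for earlier, later in zip(path_scores, path_scores[1:]):
        loss += relu(earlier - later + small_margin)
    # The last frame should stay at least `margin` below the positive.
    loss += relu(path_scores[-1] - pos_score + margin)
    # The first frame should sit at least `small_margin` above the negative.
    loss += relu(neg_score - path_scores[0] + small_margin)
    return loss


ordered = path_monotonicity_loss([0.3, 0.4, 0.5], pos_score=0.8, neg_score=0.2)
reversed_path = path_monotonicity_loss([0.5, 0.4, 0.3], pos_score=0.8, neg_score=0.2)
```

A correctly ordered path costs nothing, and each inverted pair adds its violation plus the margin, so the penalty grows with how far the ordering is off.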
|
| 424 |
+
|
| 425 |
+
def run_phase2_path(self, train_loader, val_loader, n_epochs: int = 20,
|
| 426 |
+
lambda_rank: float = 1.0, margin: float = 0.2,
|
| 427 |
+
lambda_ddg: float = 0.0, lambda_path: float = 0.5):
|
| 428 |
+
"""Phase 2 with path-aware training loop."""
|
| 429 |
+
logger.info(f"Starting Phase 2 (path-aware) for {n_epochs} epochs "
|
| 430 |
+
f"[lambda_rank={lambda_rank}, lambda_path={lambda_path}]")
|
| 431 |
+
self.best_val_metric = -float('inf')
|
| 432 |
+
|
| 433 |
+
for epoch in range(n_epochs):
|
| 434 |
+
loss_meter = AverageMeter()
|
| 435 |
+
gap_meter = AverageMeter()
|
| 436 |
+
path_meter = AverageMeter()
|
| 437 |
+
|
| 438 |
+
for batch in train_loader:
|
| 439 |
+
result = self.train_step_phase2_v2(
|
| 440 |
+
batch, lambda_rank, margin, lambda_ddg, lambda_path)
|
| 441 |
+
B = len(result['pos_scores'])
|
| 442 |
+
loss_meter.update(result['loss'], B)
|
| 443 |
+
gap_meter.update(result['selectivity_gap'], B)
|
| 444 |
+
path_meter.update(result['loss_path'], B)
|
| 445 |
+
|
| 446 |
+
if self.global_step % 50 == 0:
|
| 447 |
+
wandb.log({
|
| 448 |
+
'phase2/train_loss': result['loss'],
|
| 449 |
+
'phase2/loss_margin': result['loss_margin'],
|
| 450 |
+
'phase2/loss_infonce': result['loss_infonce'],
|
| 451 |
+
'phase2/loss_path': result['loss_path'],
|
| 452 |
+
'phase2/selectivity_gap': result['selectivity_gap'],
|
| 453 |
+
'phase2/lr': self.optimizer.param_groups[0]['lr'],
|
| 454 |
+
'phase2/step': self.global_step,
|
| 455 |
+
})
|
| 456 |
+
|
| 457 |
+
val_metrics = self.evaluate_phase2(val_loader)
|
| 458 |
+
|
| 459 |
+
logger.info(
|
| 460 |
+
f"Phase2-Path Epoch {epoch+1}/{n_epochs} | "
|
| 461 |
+
f"Loss: {loss_meter.avg:.4f} | "
|
| 462 |
+
f"Gap: {gap_meter.avg:.3f} | "
|
| 463 |
+
f"Path: {path_meter.avg:.4f} | "
|
| 464 |
+
f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
|
| 465 |
+
f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
wandb.log({
|
| 469 |
+
'phase2/epoch': epoch + 1,
|
| 470 |
+
'phase2/train_loss_epoch': loss_meter.avg,
|
| 471 |
+
'phase2/train_gap_epoch': gap_meter.avg,
|
| 472 |
+
'phase2/train_path_loss_epoch': path_meter.avg,
|
| 473 |
+
**{f'phase2/{k}': v for k, v in val_metrics.items()},
|
| 474 |
+
})
|
| 475 |
+
|
| 476 |
+
if val_metrics['val_selectivity_gap'] > self.best_val_metric:
|
| 477 |
+
self.best_val_metric = val_metrics['val_selectivity_gap']
|
| 478 |
+
self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
|
| 479 |
+
logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
|
| 480 |
+
|
| 481 |
+
logger.info("Phase 2 (path-aware) training complete.")
|
| 482 |
+
|
| 483 |
+
def run_phase2(self, train_loader, val_loader, n_epochs: int = 20,
|
| 484 |
+
lambda_rank: float = 1.0, margin: float = 0.2,
|
| 485 |
+
lambda_ddg: float = 0.1):
|
| 486 |
+
"""Phase 2 training loop (selectivity fine-tuning + ddG auxiliary)."""
|
| 487 |
+
logger.info(f"Starting Phase 2 (selectivity fine-tuning) for {n_epochs} epochs "
|
| 488 |
+
f"[lambda_rank={lambda_rank}, lambda_ddg={lambda_ddg}]")
|
| 489 |
+
self.best_val_metric = -float('inf')
|
| 490 |
+
|
| 491 |
+
for epoch in range(n_epochs):
|
| 492 |
+
loss_meter = AverageMeter()
|
| 493 |
+
gap_meter = AverageMeter()
|
| 494 |
+
|
| 495 |
+
for batch in train_loader:
|
| 496 |
+
result = self.train_step_phase2(batch, lambda_rank, margin, lambda_ddg)
|
| 497 |
+
B = len(result['pos_scores'])
|
| 498 |
+
loss_meter.update(result['loss'], B)
|
| 499 |
+
gap_meter.update(result['selectivity_gap'], B)
|
| 500 |
+
|
| 501 |
+
if self.global_step % 50 == 0:
|
| 502 |
+
wandb.log({
|
| 503 |
+
'phase2/train_loss': result['loss'],
|
| 504 |
+
'phase2/loss_margin': result['loss_margin'],
|
| 505 |
+
'phase2/loss_infonce': result['loss_infonce'],
|
| 506 |
+
'phase2/loss_ddg': result['loss_ddg'],
|
| 507 |
+
'phase2/selectivity_gap': result['selectivity_gap'],
|
| 508 |
+
'phase2/lr': self.optimizer.param_groups[0]['lr'],
|
| 509 |
+
'phase2/step': self.global_step,
|
| 510 |
+
})
|
| 511 |
+
|
| 512 |
+
# Validate
|
| 513 |
+
val_metrics = self.evaluate_phase2(val_loader)
|
| 514 |
+
|
| 515 |
+
logger.info(
|
| 516 |
+
f"Phase2 Epoch {epoch+1}/{n_epochs} | "
|
| 517 |
+
f"Loss: {loss_meter.avg:.4f} | "
|
| 518 |
+
f"Gap: {gap_meter.avg:.3f} | "
|
| 519 |
+
f"Val Gap: {val_metrics['val_selectivity_gap']:.3f} | "
|
| 520 |
+
f"Val Acc: {val_metrics['val_ranking_acc']:.3f}"
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
wandb.log({
|
| 524 |
+
'phase2/epoch': epoch + 1,
|
| 525 |
+
'phase2/train_loss_epoch': loss_meter.avg,
|
| 526 |
+
'phase2/train_gap_epoch': gap_meter.avg,
|
| 527 |
+
**{f'phase2/{k}': v for k, v in val_metrics.items()},
|
| 528 |
+
})
|
| 529 |
+
|
| 530 |
+
# Checkpoint
|
| 531 |
+
if val_metrics['val_selectivity_gap'] > self.best_val_metric:
|
| 532 |
+
self.best_val_metric = val_metrics['val_selectivity_gap']
|
| 533 |
+
self.save_checkpoint('best_phase2.pt', extra={'epoch': epoch, 'phase': 2})
|
| 534 |
+
logger.info(f" -> New best Phase 2 model (val_gap={self.best_val_metric:.3f})")
|
| 535 |
+
|
| 536 |
+
logger.info("Phase 2 training complete.")
|
| 537 |
+
|
| 538 |
+
@torch.no_grad()
|
| 539 |
+
def evaluate_phase2(self, loader):
|
| 540 |
+
"""Evaluate selectivity on paired (pos, neg) val set."""
|
| 541 |
+
self.model.eval()
|
| 542 |
+
all_pos_scores, all_neg_scores = [], []
|
| 543 |
+
|
| 544 |
+
for batch in loader:
|
| 545 |
+
if 'pos' not in batch:
|
| 546 |
+
continue
|
| 547 |
+
pos = batch['pos']
|
| 548 |
+
neg = batch['neg']
|
| 549 |
+
|
| 550 |
+
pos_esm = pos['esm_feats'].to(self.device) if 'esm_feats' in pos else None
|
| 551 |
+
neg_esm = neg['esm_feats'].to(self.device) if 'esm_feats' in neg else None
|
| 552 |
+
pos_scores = self.model(
|
| 553 |
+
pos['node_feats'].to(self.device),
|
| 554 |
+
pos['edge_feats'].to(self.device),
|
| 555 |
+
pos['node_mask'].to(self.device),
|
| 556 |
+
esm_feats=pos_esm
|
| 557 |
+
)
|
| 558 |
+
neg_scores = self.model(
|
| 559 |
+
neg['node_feats'].to(self.device),
|
| 560 |
+
neg['edge_feats'].to(self.device),
|
| 561 |
+
neg['node_mask'].to(self.device),
|
| 562 |
+
esm_feats=neg_esm
|
| 563 |
+
)
|
| 564 |
+
all_pos_scores.append(pos_scores.cpu().numpy())
|
| 565 |
+
all_neg_scores.append(neg_scores.cpu().numpy())
|
| 566 |
+
|
| 567 |
+
if not all_pos_scores:
|
| 568 |
+
return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}
|
| 569 |
+
|
| 570 |
+
all_pos = np.concatenate(all_pos_scores)
|
| 571 |
+
all_neg = np.concatenate(all_neg_scores)
|
| 572 |
+
|
| 573 |
+
gap = float((all_pos - all_neg).mean())
|
| 574 |
+
acc = float((all_pos > all_neg).mean())
|
| 575 |
+
|
| 576 |
+
return {
|
| 577 |
+
'val_selectivity_gap': gap,
|
| 578 |
+
'val_ranking_acc': acc,
|
| 579 |
+
'val_pos_score_mean': float(all_pos.mean()),
|
| 580 |
+
'val_neg_score_mean': float(all_neg.mean()),
|
| 581 |
+
}
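The two headline metrics of `evaluate_phase2` are simple pairwise statistics: the mean score difference and the fraction of pairs in which the positive outranks the negative. A small sketch on plain lists, with made-up scores, shows how they behave:

```python
def selectivity_metrics(pos_scores, neg_scores):
    """Mean (pos - neg) gap and pairwise ranking accuracy for aligned pairs."""
    if not pos_scores:
        # Same fallback as the evaluation routine when no paired batch is seen.
        return {'val_selectivity_gap': 0.0, 'val_ranking_acc': 0.5}
    n = len(pos_scores)
    gap = sum(p - q for p, q in zip(pos_scores, neg_scores)) / n
    acc = sum(1 for p, q in zip(pos_scores, neg_scores) if p > q) / n
    return {'val_selectivity_gap': gap, 'val_ranking_acc': acc}


metrics = selectivity_metrics([0.9, 0.6, 0.4], [0.2, 0.7, 0.1])
empty = selectivity_metrics([], [])
```

The gap is sensitive to score calibration while the accuracy only depends on ordering, so the two can disagree; checkpoint selection in the loops above uses the gap.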
|
| 582 |
+
|
| 583 |
+
# ------------------------------------------------------------------ #
|
| 584 |
+
# Checkpointing
|
| 585 |
+
# ------------------------------------------------------------------ #
|
| 586 |
+
|
| 587 |
+
def save_checkpoint(self, filename: str, extra: dict = None):
|
| 588 |
+
path = os.path.join(self.checkpoint_dir, filename)
|
| 589 |
+
state = {
|
| 590 |
+
'model_state': self.model.state_dict(),
|
| 591 |
+
'optimizer_state': self.optimizer.state_dict(),
|
| 592 |
+
'global_step': self.global_step,
|
| 593 |
+
'config': self.config,
|
| 594 |
+
}
|
| 595 |
+
if extra:
|
| 596 |
+
state.update(extra)
|
| 597 |
+
torch.save(state, path)
|
| 598 |
+
logger.debug(f"Saved checkpoint: {path}")
|
| 599 |
+
|
| 600 |
+
def load_checkpoint(self, filename: str):
|
| 601 |
+
path = os.path.join(self.checkpoint_dir, filename)
|
| 602 |
+
if not os.path.exists(path):
|
| 603 |
+
logger.warning(f"Checkpoint not found: {path}")
|
| 604 |
+
return False
|
| 605 |
+
state = torch.load(path, map_location=self.device)
|
| 606 |
+
self.model.load_state_dict(state['model_state'])
|
| 607 |
+
self.optimizer.load_state_dict(state['optimizer_state'])
|
| 608 |
+
self.global_step = state.get('global_step', 0)
|
| 609 |
+
logger.info(f"Loaded checkpoint from {path} (step {self.global_step})")
|
| 610 |
+
return True
|
| 611 |
+
|
| 612 |
+
# ------------------------------------------------------------------ #
|
| 613 |
+
# Full evaluation (test set)
|
| 614 |
+
# ------------------------------------------------------------------ #
|
| 615 |
+
|
| 616 |
+
@torch.no_grad()
|
| 617 |
+
def evaluate_test(self, test_loader, phase: int = 2):
|
| 618 |
+
"""Full evaluation on test set with all metrics."""
|
| 619 |
+
self.model.eval()
|
| 620 |
+
all_scores, all_labels, all_types = [], [], []
|
| 621 |
+
|
| 622 |
+
for batch in test_loader:
|
| 623 |
+
if 'pos' in batch:
|
| 624 |
+
# Paired batch
|
| 625 |
+
for key in ['pos', 'neg']:
|
| 626 |
+
d = batch[key]
|
| 627 |
+
d_esm = d['esm_feats'].to(self.device) if 'esm_feats' in d else None
|
| 628 |
+
scores = self.model(
|
| 629 |
+
d['node_feats'].to(self.device),
|
| 630 |
+
d['edge_feats'].to(self.device),
|
| 631 |
+
d['node_mask'].to(self.device),
|
| 632 |
+
esm_feats=d_esm
|
| 633 |
+
)
|
| 634 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
| 635 |
+
all_labels.extend(d['label'].numpy().tolist())
|
| 636 |
+
all_types.extend(['pos' if key == 'pos' else 'neg'] * len(scores))
|
| 637 |
+
else:
|
| 638 |
+
esm_feats = batch['esm_feats'].to(self.device) if 'esm_feats' in batch else None
|
| 639 |
+
scores = self.model(
|
| 640 |
+
batch['node_feats'].to(self.device),
|
| 641 |
+
batch['edge_feats'].to(self.device),
|
| 642 |
+
batch['node_mask'].to(self.device),
|
| 643 |
+
esm_feats=esm_feats
|
| 644 |
+
)
|
| 645 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
| 646 |
+
all_labels.extend(batch['label'].numpy().tolist())
|
| 647 |
+
all_types.extend(batch['type'])
|
| 648 |
+
|
| 649 |
+
all_scores = np.array(all_scores)
|
| 650 |
+
all_labels = np.array(all_labels)
|
| 651 |
+
|
| 652 |
+
metrics = {}
|
| 653 |
+
|
| 654 |
+
# Spearman correlation (all samples)
|
| 655 |
+
metrics['test_spearman'] = float(np.nan_to_num(spearmanr(all_scores, all_labels).correlation))  # NaN (constant input) -> 0.0
|
| 656 |
+
|
| 657 |
+
# AUC (binary: label > 0.5 = positive quality)
|
| 658 |
+
binary = (all_labels > 0.5).astype(int)
|
| 659 |
+
if binary.sum() > 0 and binary.sum() < len(binary):
|
| 660 |
+
try:
|
| 661 |
+
metrics['test_auc'] = float(roc_auc_score(binary, all_scores))
|
| 662 |
+
except Exception:
|
| 663 |
+
pass
|
| 664 |
+
|
| 665 |
+
# Selectivity gap (pos vs neg_apo pairs)
|
| 666 |
+
pos_mask = np.array([t == 'pos' or t == 'positive' for t in all_types])
|
| 667 |
+
neg_mask = np.array([t == 'neg' or t == 'negative_apo' for t in all_types])
|
| 668 |
+
if pos_mask.sum() > 0 and neg_mask.sum() > 0:
|
| 669 |
+
metrics['test_selectivity_gap'] = float(all_scores[pos_mask].mean() - all_scores[neg_mask].mean())
|
| 670 |
+
|
| 671 |
+
logger.info(f"Test evaluation: {metrics}")
|
| 672 |
+
wandb.log({f'test/{k}': v for k, v in metrics.items()})
|
| 673 |
+
|
| 674 |
+
return metrics, all_scores, all_labels, all_types
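Spearman correlation is undefined when either input is constant, and SciPy reports NaN in that case. NaN is truthy in Python, so an `or`-style fallback does not replace it. An explicit check is needed, as in this small sketch (`safe_corr` is an illustrative helper):

```python
import math


def safe_corr(value, default=0.0):
    """Return `default` when a correlation is None or NaN, else the value as float."""
    if value is None or math.isnan(value):
        return default
    return float(value)


nan = float('nan')
naive = nan or 0          # NaN is truthy, so the fallback is never used
guarded = safe_corr(nan)  # explicit check maps NaN to the default
```

The validation routine earlier in this file uses the same idea with `np.isnan`.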
|
code/utils/__init__.py
ADDED
|
File without changes
|
code/utils/anm.py
ADDED
|
@@ -0,0 +1,208 @@
|
| 1 |
+
"""
|
| 2 |
+
Anisotropic Network Model (ANM) for conformational path interpolation.
|
| 3 |
+
|
| 4 |
+
From-scratch implementation using scipy eigendecomposition.
|
| 5 |
+
Projects the apo→holo displacement onto low-frequency normal modes
|
| 6 |
+
to create physically motivated interpolation paths.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.linalg import eigh
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_anm_modes(ca_coords, cutoff=15.0, n_modes=10):
|
| 14 |
+
"""
|
| 15 |
+
Build elastic network Hessian and compute normal modes via eigendecomposition.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
ca_coords: [N, 3] CA atom coordinates
|
| 19 |
+
cutoff: distance cutoff for spring connections (Angstroms)
|
| 20 |
+
n_modes: number of non-trivial modes to return
|
| 21 |
+
|
| 22 |
+
Returns:
|
| 23 |
+
eigenvalues: [n_modes] array of eigenvalues (force constants)
|
| 24 |
+
eigenvectors: [n_modes, N, 3] mode displacement vectors
|
| 25 |
+
"""
|
| 26 |
+
N = len(ca_coords)
|
| 27 |
+
if N < 4:
|
| 28 |
+
return np.zeros(n_modes), np.zeros((n_modes, N, 3))
|
| 29 |
+
|
| 30 |
+
# Build 3N x 3N Hessian with uniform spring constant (gamma=1)
|
| 31 |
+
H = np.zeros((3 * N, 3 * N), dtype=np.float64)
|
| 32 |
+
|
| 33 |
+
for i in range(N):
|
| 34 |
+
for j in range(i + 1, N):
|
| 35 |
+
diff = ca_coords[j] - ca_coords[i]
|
| 36 |
+
dist = np.linalg.norm(diff)
|
| 37 |
+
if dist > cutoff or dist < 1e-6:
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
# Outer product of unit displacement vector
|
| 41 |
+
unit = diff / dist
|
| 42 |
+
block = np.outer(unit, unit) # [3, 3]
|
| 43 |
+
|
| 44 |
+
# Off-diagonal: H[i,j] = -gamma * (r_ij ⊗ r_ij) / |r_ij|^2
|
| 45 |
+
# With uniform gamma=1 and unit vectors, this simplifies to:
|
| 46 |
+
ii, jj = 3 * i, 3 * j
|
| 47 |
+
H[ii:ii+3, jj:jj+3] = -block
|
| 48 |
+
H[jj:jj+3, ii:ii+3] = -block
|
| 49 |
+
|
| 50 |
+
# Diagonal: accumulate
|
| 51 |
+
H[ii:ii+3, ii:ii+3] += block
|
| 52 |
+
H[jj:jj+3, jj:jj+3] += block
|
| 53 |
+
|
| 54 |
+
# Eigendecompose — first 6 modes are trivial (3 translation + 3 rotation)
|
| 55 |
+
n_total = min(6 + n_modes, 3 * N)
|
| 56 |
+
eigenvalues, eigvecs = eigh(H, subset_by_index=[0, n_total - 1])
|
| 57 |
+
|
| 58 |
+
# Skip the 6 trivial zero-frequency modes
|
| 59 |
+
start = min(6, len(eigenvalues) - 1)
|
| 60 |
+
n_available = len(eigenvalues) - start
|
| 61 |
+
n_return = min(n_modes, n_available)
|
| 62 |
+
|
| 63 |
+
evals = eigenvalues[start:start + n_return]
|
| 64 |
+
evecs = eigvecs[:, start:start + n_return] # [3N, n_return]
|
| 65 |
+
|
| 66 |
+
# Reshape eigenvectors to [n_modes, N, 3]
|
| 67 |
+
mode_vectors = np.zeros((n_return, N, 3))
|
| 68 |
+
for k in range(n_return):
|
| 69 |
+
mode_vectors[k] = evecs[:, k].reshape(N, 3)
|
| 70 |
+
|
| 71 |
+
# Pad if fewer modes available than requested
|
| 72 |
+
if n_return < n_modes:
|
| 73 |
+
pad_evals = np.zeros(n_modes)
|
| 74 |
+
pad_evals[:n_return] = evals
|
| 75 |
+
pad_modes = np.zeros((n_modes, N, 3))
|
| 76 |
+
pad_modes[:n_return] = mode_vectors
|
| 77 |
+
return pad_evals, pad_modes
|
| 78 |
+
|
| 79 |
+
return evals, mode_vectors
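The Hessian built by the double loop in `compute_anm_modes` can also be assembled with array operations, which makes its structure easy to verify: off-diagonal 3x3 blocks are minus the outer product of the unit vector between connected residues, and each diagonal block is minus the sum of its row. The sketch below uses NumPy only (`anm_hessian` is an illustrative name) and checks the six rigid-body zero modes on random coordinates:

```python
import numpy as np


def anm_hessian(ca_coords, cutoff=15.0):
    """Vectorized ANM Hessian with uniform spring constant (gamma = 1)."""
    x = np.asarray(ca_coords, dtype=np.float64)
    n = len(x)
    diff = x[None, :, :] - x[:, None, :]            # diff[i, j] = x_j - x_i
    dist = np.linalg.norm(diff, axis=-1)
    connected = (dist <= cutoff) & (dist >= 1e-6)   # same pair filter as the loop
    safe_dist = np.where(connected, dist, 1.0)      # avoid division by zero
    unit = diff / safe_dist[..., None] * connected[..., None]
    blocks = -np.einsum('ijk,ijl->ijkl', unit, unit)  # off-diagonal 3x3 blocks
    idx = np.arange(n)
    blocks[idx, idx] = -blocks.sum(axis=1)            # diagonal = minus the row sum
    # Map blocks[i, j, k, l] to H[3*i + k, 3*j + l].
    return blocks.transpose(0, 2, 1, 3).reshape(3 * n, 3 * n)


rng = np.random.default_rng(0)
coords = rng.normal(scale=5.0, size=(12, 3))
H = anm_hessian(coords, cutoff=100.0)   # large cutoff: fully connected network
evals = np.linalg.eigvalsh(H)           # ascending eigenvalues
```

For a connected, non-degenerate network the first six eigenvalues are numerically zero (three translations, three rotations), which is why the function skips them before returning modes.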
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _kabsch_align(mobile_ca, ref_ca):
|
| 83 |
+
"""Kabsch alignment of mobile onto ref (CA atoms only)."""
|
| 84 |
+
t_mobile = mobile_ca.mean(axis=0)
|
| 85 |
+
t_ref = ref_ca.mean(axis=0)
|
| 86 |
+
|
| 87 |
+
m = mobile_ca - t_mobile
|
| 88 |
+
r = ref_ca - t_ref
|
| 89 |
+
|
| 90 |
+
H = m.T @ r
|
| 91 |
+
U, S, Vt = np.linalg.svd(H)
|
| 92 |
+
d = np.linalg.det(Vt.T @ U.T)
|
| 93 |
+
sign = np.array([1.0, 1.0, np.sign(d)])
|
| 94 |
+
R = Vt.T @ np.diag(sign) @ U.T
|
| 95 |
+
|
| 96 |
+
return R, t_mobile, t_ref
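A quick way to sanity-check the alignment convention is to rotate and translate a point cloud by a known transform and confirm that the returned `R`, `t_mobile`, `t_ref` undo it when applied as `(x - t_mobile) @ R.T + t_ref`. The function below restates the same steps under the name `kabsch` so the sketch is self-contained; the test transform is arbitrary:

```python
import numpy as np


def kabsch(mobile, ref):
    """Optimal rotation of `mobile` onto `ref` (both [N, 3]), plus the two centroids."""
    t_mobile, t_ref = mobile.mean(axis=0), ref.mean(axis=0)
    H = (mobile - t_mobile).T @ (ref - t_ref)       # 3x3 covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.linalg.det(Vt.T @ U.T)                   # -1 would mean a reflection
    R = Vt.T @ np.diag([1.0, 1.0, np.sign(d)]) @ U.T
    return R, t_mobile, t_ref


rng = np.random.default_rng(1)
mobile = rng.normal(size=(20, 3))
angle = 0.7
true_R = np.array([[np.cos(angle), -np.sin(angle), 0.0],
                   [np.sin(angle),  np.cos(angle), 0.0],
                   [0.0,            0.0,           1.0]])
ref = mobile @ true_R.T + np.array([3.0, -1.0, 2.0])

R, t_mobile, t_ref = kabsch(mobile, ref)
aligned = (mobile - t_mobile) @ R.T + t_ref
rmsd = float(np.sqrt(((aligned - ref) ** 2).sum(axis=1).mean()))
```

The sign correction on the last singular direction keeps `R` a proper rotation, which matters for chiral structures such as protein backbones.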
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _reconstruct_oxygen(coords):
|
| 100 |
+
"""Approximate the carbonyl O by extending the CA->C bond 1.24 A beyond C (N is unused; modifies coords in place)."""
|
| 101 |
+
C_pos = coords[:, 2, :]
|
| 102 |
+
CA_pos = coords[:, 1, :]
|
| 103 |
+
C_CA = C_pos - CA_pos
|
| 104 |
+
C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
|
| 105 |
+
C_CA_norm = np.maximum(C_CA_norm, 1e-8)
|
| 106 |
+
O_pos = C_pos + (C_CA / C_CA_norm) * 1.24
|
| 107 |
+
coords[:, 3, :] = O_pos
|
| 108 |
+
return coords
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def anm_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
|
| 112 |
+
n_frames=5, n_modes=10, cutoff=15.0):
|
| 113 |
+
"""
|
| 114 |
+
Interpolate backbone along dominant ANM modes from X0 toward X1.
|
| 115 |
+
|
| 116 |
+
Low-frequency modes capture global domain motions (e.g., CaM hinge bending),
|
| 117 |
+
creating physically informed paths where large-scale motions precede local
|
| 118 |
+
adjustments.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for apo state
|
| 122 |
+
coords_x1: [N1, 4, 3] backbone coords for holo state
|
| 123 |
+
mask_x0: [N0] bool
|
| 124 |
+
mask_x1: [N1] bool
|
| 125 |
+
n_frames: number of intermediate frames (excluding endpoints)
|
| 126 |
+
n_modes: number of ANM modes to use for projection
|
| 127 |
+
cutoff: ANM spring cutoff in Angstroms
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
path_frames: list of (coords_tau, mask_tau, tau) tuples
|
| 131 |
+
Same interface as interpolate_backbone_path
|
| 132 |
+
"""
|
| 133 |
+
n_common = min(len(coords_x0), len(coords_x1))
|
| 134 |
+
c0 = coords_x0[:n_common].copy()
|
| 135 |
+
c1 = coords_x1[:n_common].copy()
|
| 136 |
+
m0 = mask_x0[:n_common]
|
| 137 |
+
m1 = mask_x1[:n_common]
|
| 138 |
+
|
| 139 |
+
common_mask = m0 & m1
|
| 140 |
+
if common_mask.sum() < 5:
|
| 141 |
+
return []
|
| 142 |
+
|
| 143 |
+
# Kabsch-align X0 onto X1 using valid CA atoms
|
| 144 |
+
ca0 = c0[common_mask, 1, :]
|
| 145 |
+
ca1 = c1[common_mask, 1, :]
|
| 146 |
+
R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
|
| 147 |
+
|
| 148 |
+
# Apply alignment to all X0 backbone atoms
|
| 149 |
+
flat0 = c0.reshape(-1, 3)
|
| 150 |
+
aligned0 = (flat0 - t_mobile) @ R.T + t_ref
|
| 151 |
+
c0_aligned = aligned0.reshape(n_common, 4, 3)
|
| 152 |
+
|
| 153 |
+
# Compute apo→holo displacement (CA atoms, valid residues only)
|
| 154 |
+
ca0_aligned = c0_aligned[common_mask, 1, :] # [N_valid, 3]
|
| 155 |
+
ca1_valid = c1[common_mask, 1, :]
|
| 156 |
+
|
| 157 |
+
displacement = ca1_valid - ca0_aligned # [N_valid, 3]
|
| 158 |
+
|
| 159 |
+
# Compute ANM modes of the aligned apo structure
|
| 160 |
+
eigenvalues, mode_vectors = compute_anm_modes(
|
| 161 |
+
ca0_aligned, cutoff=cutoff, n_modes=n_modes
|
| 162 |
+
) # mode_vectors: [n_modes, N_valid, 3]
|
| 163 |
+
|
| 164 |
+
# Project displacement onto each mode
|
| 165 |
+
# d_k = sum_i mode_k[i] . displacement[i]
|
| 166 |
+
projections = np.zeros(n_modes)
|
| 167 |
+
for k in range(n_modes):
|
| 168 |
+
projections[k] = np.sum(mode_vectors[k] * displacement)
|
| 169 |
+
|
| 170 |
+
# Reconstruct mode-projected displacement: d_mode = sum_k d_k * mode_k
|
| 171 |
+
mode_displacement = np.zeros_like(displacement) # [N_valid, 3]
|
| 172 |
+
for k in range(n_modes):
|
| 173 |
+
mode_displacement += projections[k] * mode_vectors[k]
|
| 174 |
+
|
| 175 |
+
# Residual displacement not captured by modes
|
| 176 |
+
residual = displacement - mode_displacement
|
| 177 |
+
|
| 178 |
+
# Generate intermediate frames
|
| 179 |
+
taus = np.linspace(0, 1, n_frames + 2)[1:-1]
|
| 180 |
+
path_frames = []
|
| 181 |
+
|
| 182 |
+
for tau in taus:
|
| 183 |
+
# Apply mode-projected + residual displacement at each tau
|
| 184 |
+
# Mode-projected part and residual are both scaled by tau, so their sum is tau * displacement
|
| 185 |
+
ca_interp = ca0_aligned + tau * mode_displacement + tau * residual
|
| 186 |
+
|
| 187 |
+
# Build full backbone by interpolating all 4 atom types
|
| 188 |
+
coords_tau = (1.0 - tau) * c0_aligned + tau * c1
|
| 189 |
+
# Override CA positions with ANM-interpolated values
|
| 190 |
+
coords_tau[common_mask, 1, :] = ca_interp
|
| 191 |
+
|
| 192 |
+
# Adjust N, C positions relative to CA shift
|
| 193 |
+
# The N/CA/C triangle is preserved by blending the ANM CA shift
|
| 194 |
+
# with the linear interpolation of N and C
|
| 195 |
+
ca_shift = ca_interp - ((1.0 - tau) * ca0_aligned + tau * ca1_valid)
|
| 196 |
+
coords_tau[common_mask, 0, :] += ca_shift # N atoms
|
| 197 |
+
coords_tau[common_mask, 2, :] += ca_shift # C atoms
|
| 198 |
+
|
| 199 |
+
# Reconstruct O from N, CA, C
|
| 200 |
+
coords_tau = _reconstruct_oxygen(coords_tau)
|
| 201 |
+
|
| 202 |
+
path_frames.append((
|
| 203 |
+
coords_tau.astype(np.float32),
|
| 204 |
+
common_mask.copy(),
|
| 205 |
+
float(tau),
|
| 206 |
+
))
|
| 207 |
+
|
| 208 |
+
return path_frames
|
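The projection step used in the ANM path above can be looked at in isolation. The following is a minimal sketch, assuming the mode vectors are orthonormal once flattened to length 3N (as eigenvectors of a symmetric Hessian would be). The random orthonormal basis is a hypothetical stand-in and is not produced by `compute_anm_modes`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_res, n_modes = 6, 4

# Stand-in for ANM modes: orthonormal vectors obtained by QR (hypothetical data).
q, _ = np.linalg.qr(rng.standard_normal((3 * n_res, n_modes)))
mode_vectors = q.T.reshape(n_modes, n_res, 3)       # [n_modes, N, 3]
displacement = rng.standard_normal((n_res, 3))      # [N, 3]

# d_k = <mode_k, displacement>, summed over residues and xyz
projections = np.einsum('kij,ij->k', mode_vectors, displacement)

# Part of the displacement inside the span of the modes, and the remainder
mode_displacement = np.einsum('k,kij->ij', projections, mode_vectors)
residual = displacement - mode_displacement

# The remainder has no overlap with any mode
residual_overlap = np.einsum('kij,ij->k', mode_vectors, residual)
```

Since the frame loop multiplies both parts by the same `tau`, `tau * mode_displacement + tau * residual` reduces to `tau * displacement` for the CA atoms.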
code/utils/path_utils.py
ADDED
|
@@ -0,0 +1,448 @@
|
| 1 |
+
"""
|
| 2 |
+
Transition-path interpolation utilities for conformational induction.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- Kabsch-aligned backbone interpolation between two conformational states
|
| 6 |
+
- Gaussian Schrödinger Bridge (DSB) stochastic interpolation
|
| 7 |
+
- Precomputed frame loading (for AlphaFlow / AFsample2)
|
| 8 |
+
- Unified dispatcher: generate_path_frames()
|
| 9 |
+
- Per-residue displacement computation (for allosteric hinge weighting)
|
| 10 |
+
- Monotonically increasing path weight generation
|
| 11 |
+
|
| 12 |
+
Used by the path-aware training, guidance, and refinement modules.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import logging
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _kabsch_align(mobile_ca, ref_ca):
|
| 23 |
+
"""
|
| 24 |
+
Kabsch alignment of mobile onto ref (CA atoms only).
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
mobile_ca: [N, 3] array
|
| 28 |
+
ref_ca: [N, 3] array
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
R: [3, 3] rotation matrix
|
| 32 |
+
t_mobile: [3] mobile centroid
|
| 33 |
+
t_ref: [3] ref centroid
|
| 34 |
+
Such that: aligned = (mobile - t_mobile) @ R.T + t_ref
|
| 35 |
+
"""
|
| 36 |
+
t_mobile = mobile_ca.mean(axis=0)
|
| 37 |
+
t_ref = ref_ca.mean(axis=0)
|
| 38 |
+
|
| 39 |
+
m = mobile_ca - t_mobile
|
| 40 |
+
r = ref_ca - t_ref
|
| 41 |
+
|
| 42 |
+
H = m.T @ r
|
| 43 |
+
U, S, Vt = np.linalg.svd(H)
|
| 44 |
+
d = np.linalg.det(Vt.T @ U.T)
|
| 45 |
+
sign = np.array([1.0, 1.0, np.sign(d)])
|
| 46 |
+
R = Vt.T @ np.diag(sign) @ U.T
|
| 47 |
+
|
| 48 |
+
return R, t_mobile, t_ref
|
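As a sanity check of the convention `aligned = (mobile - t_mobile) @ R.T + t_ref`, the sketch below applies a known rigid transform to a point cloud and recovers it. The local `kabsch` helper mirrors the steps of `_kabsch_align` but is written here only for illustration; the test data are made up.

```python
import numpy as np

def kabsch(mobile, ref):
    """Return R, t_mobile, t_ref so that (mobile - t_mobile) @ R.T + t_ref fits ref."""
    t_mobile, t_ref = mobile.mean(axis=0), ref.mean(axis=0)
    H = (mobile - t_mobile).T @ (ref - t_ref)
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))    # -1 would indicate a reflection
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return R, t_mobile, t_ref

rng = np.random.default_rng(1)
ref = rng.standard_normal((10, 3))

# Known rigid transform: rotation about z by 40 degrees plus a translation
a = np.deg2rad(40.0)
R_true = np.array([[np.cos(a), -np.sin(a), 0.0],
                   [np.sin(a),  np.cos(a), 0.0],
                   [0.0,        0.0,       1.0]])
mobile = ref @ R_true.T + np.array([3.0, -2.0, 5.0])

R, t_mobile, t_ref = kabsch(mobile, ref)
aligned = (mobile - t_mobile) @ R.T + t_ref
rmsd = np.sqrt(((aligned - ref) ** 2).sum(axis=1).mean())
```

The recovered `R` is the inverse (transpose) of `R_true`, and the RMSD after alignment is zero up to floating-point error.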
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def interpolate_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1, n_frames=5):
|
| 52 |
+
"""
|
| 53 |
+
Generate intermediate backbone conformations along the X0 -> X1 path.
|
| 54 |
+
|
| 55 |
+
1. Find common valid residues between X0 and X1
|
| 56 |
+
2. Kabsch-align X0 onto X1 using CA atoms
|
| 57 |
+
3. Linearly interpolate backbone coords at n_frames equally-spaced tau values
|
| 58 |
+
4. Re-place O by extending the CA->C direction from C by the C=O bond length (1.24)
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
coords_x0: [N0, 4, 3] backbone coords (N, CA, C, O) for state 0
|
| 62 |
+
coords_x1: [N1, 4, 3] backbone coords for state 1
|
| 63 |
+
mask_x0: [N0] bool
|
| 64 |
+
mask_x1: [N1] bool
|
| 65 |
+
n_frames: number of intermediate frames (excluding endpoints)
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
path_frames: list of (coords_tau, mask_tau, tau) tuples
|
| 69 |
+
coords_tau: [N_common, 4, 3] interpolated backbone coords
|
| 70 |
+
mask_tau: [N_common] bool
|
| 71 |
+
tau: float in (0, 1) exclusive
|
| 72 |
+
"""
|
| 73 |
+
# Use common length
|
| 74 |
+
n_common = min(len(coords_x0), len(coords_x1))
|
| 75 |
+
c0 = coords_x0[:n_common].copy()
|
| 76 |
+
c1 = coords_x1[:n_common].copy()
|
| 77 |
+
m0 = mask_x0[:n_common]
|
| 78 |
+
m1 = mask_x1[:n_common]
|
| 79 |
+
|
| 80 |
+
# Valid in both states
|
| 81 |
+
common_mask = m0 & m1
|
| 82 |
+
if common_mask.sum() < 5:
|
| 83 |
+
return []
|
| 84 |
+
|
| 85 |
+
# Kabsch-align X0 onto X1 using valid CA atoms
|
| 86 |
+
ca0 = c0[common_mask, 1, :] # CA atoms
|
| 87 |
+
ca1 = c1[common_mask, 1, :]
|
| 88 |
+
|
| 89 |
+
R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
|
| 90 |
+
|
| 91 |
+
# Apply alignment to all X0 backbone atoms
|
| 92 |
+
n_res = n_common
|
| 93 |
+
flat0 = c0.reshape(-1, 3)
|
| 94 |
+
aligned0 = (flat0 - t_mobile) @ R.T + t_ref
|
| 95 |
+
c0_aligned = aligned0.reshape(n_res, 4, 3)
|
| 96 |
+
|
| 97 |
+
# Generate intermediate frames
|
| 98 |
+
taus = np.linspace(0, 1, n_frames + 2)[1:-1] # exclude endpoints
|
| 99 |
+
path_frames = []
|
| 100 |
+
|
| 101 |
+
for tau in taus:
|
| 102 |
+
# Linear interpolation: X_tau = (1 - tau) * X0_aligned + tau * X1
|
| 103 |
+
coords_tau = (1.0 - tau) * c0_aligned + tau * c1
|
| 104 |
+
|
| 105 |
+
# Approximate O: extend the CA->C direction from C by the C=O bond length
|
| 106 |
+
C_pos = coords_tau[:, 2, :] # C atoms
|
| 107 |
+
CA_pos = coords_tau[:, 1, :] # CA atoms
|
| 108 |
+
C_CA = C_pos - CA_pos
|
| 109 |
+
C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
|
| 110 |
+
C_CA_norm = np.maximum(C_CA_norm, 1e-8)
|
| 111 |
+
O_pos = C_pos + (C_CA / C_CA_norm) * 1.24 # ideal C=O bond length
|
| 112 |
+
coords_tau[:, 3, :] = O_pos
|
| 113 |
+
|
| 114 |
+
path_frames.append((
|
| 115 |
+
coords_tau.astype(np.float32),
|
| 116 |
+
common_mask.copy(),
|
| 117 |
+
float(tau),
|
| 118 |
+
))
|
| 119 |
+
|
| 120 |
+
return path_frames
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def compute_residue_displacements(coords_x0, coords_x1, mask_x0, mask_x1):
|
| 124 |
+
"""
|
| 125 |
+
Per-residue CA displacement between X0 and X1 after Kabsch alignment.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
coords_x0: [N0, 4, 3] backbone coords for state 0
|
| 129 |
+
coords_x1: [N1, 4, 3] backbone coords for state 1
|
| 130 |
+
mask_x0: [N0] bool
|
| 131 |
+
mask_x1: [N1] bool
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
displacements: [N_common] array of per-residue CA distances after alignment (0 where invalid)
|
| 135 |
+
common_mask: [N_common] bool — which residues are valid
|
| 136 |
+
"""
|
| 137 |
+
n_common = min(len(coords_x0), len(coords_x1))
|
| 138 |
+
c0 = coords_x0[:n_common]
|
| 139 |
+
c1 = coords_x1[:n_common]
|
| 140 |
+
m0 = mask_x0[:n_common]
|
| 141 |
+
m1 = mask_x1[:n_common]
|
| 142 |
+
common_mask = m0 & m1
|
| 143 |
+
|
| 144 |
+
if common_mask.sum() < 5:
|
| 145 |
+
return np.zeros(n_common, dtype=np.float32), common_mask
|
| 146 |
+
|
| 147 |
+
ca0 = c0[common_mask, 1, :]
|
| 148 |
+
ca1 = c1[common_mask, 1, :]
|
| 149 |
+
|
| 150 |
+
R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
|
| 151 |
+
|
| 152 |
+
# Align all CA of X0
|
| 153 |
+
all_ca0 = c0[:, 1, :]
|
| 154 |
+
aligned_ca0 = (all_ca0 - t_mobile) @ R.T + t_ref
|
| 155 |
+
|
| 156 |
+
# Per-residue displacement
|
| 157 |
+
all_ca1 = c1[:, 1, :]
|
| 158 |
+
displacements = np.linalg.norm(aligned_ca0 - all_ca1, axis=-1)
|
| 159 |
+
|
| 160 |
+
# Zero out invalid residues
|
| 161 |
+
displacements[~common_mask] = 0.0
|
| 162 |
+
|
| 163 |
+
return displacements.astype(np.float32), common_mask
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def generate_path_weights(n_frames, mode='linear'):
|
| 167 |
+
"""
|
| 168 |
+
Generate monotonically increasing weights for path frames.
|
| 169 |
+
|
| 170 |
+
The weights increase toward tau=1 (the goal state), so that
|
| 171 |
+
intermediate conformations closer to X1 are weighted more heavily.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
n_frames: number of intermediate frames
|
| 175 |
+
mode: weight schedule
|
| 176 |
+
'linear': w_tau = tau
|
| 177 |
+
'quadratic': w_tau = tau^2
|
| 178 |
+
'exponential': w_tau = (exp(tau) - 1) / (e - 1)
|
| 179 |
+
'uniform': w_tau = 1/n_frames (equal weighting)
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
weights: [n_frames] numpy array, normalized to sum to 1
|
| 183 |
+
"""
|
| 184 |
+
if n_frames == 0:
|
| 185 |
+
return np.array([], dtype=np.float32)
|
| 186 |
+
|
| 187 |
+
taus = np.linspace(0, 1, n_frames + 2)[1:-1] # same as interpolation
|
| 188 |
+
|
| 189 |
+
if mode == 'linear':
|
| 190 |
+
weights = taus.copy()
|
| 191 |
+
elif mode == 'quadratic':
|
| 192 |
+
weights = taus ** 2
|
| 193 |
+
elif mode == 'exponential':
|
| 194 |
+
weights = (np.exp(taus) - 1.0) / (np.e - 1.0)
|
| 195 |
+
elif mode == 'uniform':
|
| 196 |
+
weights = np.ones(n_frames, dtype=np.float32)
|
| 197 |
+
else:
|
| 198 |
+
raise ValueError(f"Unknown weight mode: {mode}")
|
| 199 |
+
|
| 200 |
+
# Normalize to sum to 1
|
| 201 |
+
total = weights.sum()
|
| 202 |
+
if total > 0:
|
| 203 |
+
weights = weights / total
|
| 204 |
+
|
| 205 |
+
return weights.astype(np.float32)
|
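A small worked case may make the weight schedules concrete. The sketch below recomputes the `linear` and `quadratic` schedules by hand for three intermediate frames; it does not import the module, and `n_frames = 3` is chosen only for illustration.

```python
import numpy as np

n_frames = 3
taus = np.linspace(0, 1, n_frames + 2)[1:-1]    # interior points: 0.25, 0.5, 0.75

# 'linear': w = tau, normalized to sum to 1
linear = taus / taus.sum()                      # 1/6, 1/3, 1/2

# 'quadratic': w = tau^2, normalized to sum to 1
quadratic = taus ** 2 / (taus ** 2).sum()       # 1/14, 4/14, 9/14
```

Both schedules increase toward the goal state, and the quadratic one puts more of the total weight on the last frame.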
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# ---------------------------------------------------------------------------
|
| 209 |
+
# Gaussian Schrödinger Bridge (AlignDSB) interpolation
|
| 210 |
+
# ---------------------------------------------------------------------------
|
| 211 |
+
|
| 212 |
+
def dsb_backbone_path(coords_x0, coords_x1, mask_x0, mask_x1,
|
| 213 |
+
n_frames=5, sigma=0.5, n_samples=20, seed=42):
|
| 214 |
+
"""
|
| 215 |
+
Gaussian Schrödinger Bridge with t*(1-t) variance schedule.
|
| 216 |
+
|
| 217 |
+
Analytic formula (no neural network):
|
| 218 |
+
X_t = (1-t) * X0_aligned + t * X1 + sqrt(t * (1-t)) * sigma * Z
|
| 219 |
+
|
| 220 |
+
Variance peaks at t=0.5 (maximum uncertainty mid-transition) and vanishes
|
| 221 |
+
at endpoints. sigma controls noise amplitude in Angstroms.
|
| 222 |
+
|
| 223 |
+
For each tau, samples n_samples noisy interpolations and selects the
|
| 224 |
+
median (by RMSD to the mean) for robustness.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
coords_x0: [N0, 4, 3] backbone coords for state 0
|
| 228 |
+
coords_x1: [N1, 4, 3] backbone coords for state 1
|
| 229 |
+
mask_x0: [N0] bool
|
| 230 |
+
mask_x1: [N1] bool
|
| 231 |
+
n_frames: number of intermediate frames
|
| 232 |
+
sigma: noise amplitude (Angstroms)
|
| 233 |
+
n_samples: number of samples per frame for median selection
|
| 234 |
+
seed: random seed
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
path_frames: list of (coords_tau, mask_tau, tau) tuples
|
| 238 |
+
"""
|
| 239 |
+
rng = np.random.RandomState(seed)
|
| 240 |
+
|
| 241 |
+
n_common = min(len(coords_x0), len(coords_x1))
|
| 242 |
+
c0 = coords_x0[:n_common].copy()
|
| 243 |
+
c1 = coords_x1[:n_common].copy()
|
| 244 |
+
m0 = mask_x0[:n_common]
|
| 245 |
+
m1 = mask_x1[:n_common]
|
| 246 |
+
|
| 247 |
+
common_mask = m0 & m1
|
| 248 |
+
if common_mask.sum() < 5:
|
| 249 |
+
return []
|
| 250 |
+
|
| 251 |
+
# Kabsch-align X0 onto X1
|
| 252 |
+
ca0 = c0[common_mask, 1, :]
|
| 253 |
+
ca1 = c1[common_mask, 1, :]
|
| 254 |
+
R, t_mobile, t_ref = _kabsch_align(ca0, ca1)
|
| 255 |
+
|
| 256 |
+
flat0 = c0.reshape(-1, 3)
|
| 257 |
+
aligned0 = (flat0 - t_mobile) @ R.T + t_ref
|
| 258 |
+
c0_aligned = aligned0.reshape(n_common, 4, 3)
|
| 259 |
+
|
| 260 |
+
taus = np.linspace(0, 1, n_frames + 2)[1:-1]
|
| 261 |
+
path_frames = []
|
| 262 |
+
|
| 263 |
+
for tau in taus:
|
| 264 |
+
noise_scale = np.sqrt(tau * (1.0 - tau)) * sigma
|
| 265 |
+
|
| 266 |
+
# Generate n_samples noisy interpolations
|
| 267 |
+
samples = []
|
| 268 |
+
for _ in range(n_samples):
|
| 269 |
+
Z = rng.randn(n_common, 4, 3).astype(np.float64)
|
| 270 |
+
X_t = (1.0 - tau) * c0_aligned + tau * c1 + noise_scale * Z
|
| 271 |
+
samples.append(X_t)
|
| 272 |
+
|
| 273 |
+
samples = np.array(samples) # [n_samples, N, 4, 3]
|
| 274 |
+
mean_sample = samples.mean(axis=0) # [N, 4, 3]
|
| 275 |
+
|
| 276 |
+
# Select median sample by RMSD to mean (CA atoms)
|
| 277 |
+
rmsds = []
|
| 278 |
+
for s in samples:
|
| 279 |
+
diff = s[common_mask, 1, :] - mean_sample[common_mask, 1, :]
|
| 280 |
+
rmsd = np.sqrt((diff ** 2).sum() / common_mask.sum())
|
| 281 |
+
rmsds.append(rmsd)
|
| 282 |
+
median_idx = np.argsort(rmsds)[len(rmsds) // 2]
|
| 283 |
+
coords_tau = samples[median_idx]
|
| 284 |
+
|
| 285 |
+
# Reconstruct O from N, CA, C
|
| 286 |
+
C_pos = coords_tau[:, 2, :]
|
| 287 |
+
CA_pos = coords_tau[:, 1, :]
|
| 288 |
+
C_CA = C_pos - CA_pos
|
| 289 |
+
C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
|
| 290 |
+
C_CA_norm = np.maximum(C_CA_norm, 1e-8)
|
| 291 |
+
coords_tau[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24
|
| 292 |
+
|
| 293 |
+
path_frames.append((
|
| 294 |
+
coords_tau.astype(np.float32),
|
| 295 |
+
common_mask.copy(),
|
| 296 |
+
float(tau),
|
| 297 |
+
))
|
| 298 |
+
|
| 299 |
+
return path_frames
|
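The `t * (1 - t)` variance schedule can be checked numerically. This is a minimal sketch using the same default `sigma = 0.5` and five interior frames; it only evaluates the noise amplitude and does not sample any structures.

```python
import numpy as np

sigma = 0.5                                     # noise amplitude, as in the default above
taus = np.linspace(0, 1, 5 + 2)[1:-1]           # five interior frames
noise_scale = np.sqrt(taus * (1.0 - taus)) * sigma

# The amplitude is largest at the midpoint, where it equals sigma / 2
peak_tau = taus[np.argmax(noise_scale)]
peak_scale = noise_scale.max()
```

The profile is symmetric about `tau = 0.5`, so frames equally far from either endpoint receive the same noise amplitude.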
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ---------------------------------------------------------------------------
|
| 303 |
+
# Precomputed frame loading (for AlphaFlow / AFsample2)
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
|
| 306 |
+
def load_precomputed_frames(target, method, precomputed_dir,
|
| 307 |
+
coords_x0, coords_x1, mask_x0, mask_x1,
|
| 308 |
+
n_frames=5):
|
| 309 |
+
"""
|
| 310 |
+
Load pre-generated frames from .npz and Kabsch-align to this complex's
|
| 311 |
+
receptor coordinate frame.
|
| 312 |
+
|
| 313 |
+
Expected file: {precomputed_dir}/{target}/{method}/frames.npz
|
| 314 |
+
with keys: 'frames' [n_frames, N_ref, 4, 3], 'taus' [n_frames],
|
| 315 |
+
'mask' [N_ref] bool
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
target: target name (e.g. 'cam')
|
| 319 |
+
method: method name ('alphaflow' or 'afsample2')
|
| 320 |
+
precomputed_dir: root directory for precomputed frames
|
| 321 |
+
coords_x0, coords_x1: apo/holo backbone coords for alignment
|
| 322 |
+
mask_x0, mask_x1: residue masks
|
| 323 |
+
n_frames: maximum number of frames to return (evenly subsampled if more are stored)
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
path_frames: list of (coords_tau, mask_tau, tau) tuples
|
| 327 |
+
"""
|
| 328 |
+
npz_path = os.path.join(precomputed_dir, target, method, 'frames.npz')
|
| 329 |
+
if not os.path.exists(npz_path):
|
| 330 |
+
logger.warning(f"Precomputed frames not found: {npz_path}, "
|
| 331 |
+
f"falling back to linear interpolation")
|
| 332 |
+
return interpolate_backbone_path(coords_x0, coords_x1,
|
| 333 |
+
mask_x0, mask_x1, n_frames)
|
| 334 |
+
|
| 335 |
+
data = np.load(npz_path)
|
| 336 |
+
pre_frames = data['frames'] # [K, N_ref, 4, 3]
|
| 337 |
+
pre_taus = data['taus'] # [K]
|
| 338 |
+
pre_mask = data['mask'] # [N_ref]
|
| 339 |
+
|
| 340 |
+
n_common = min(len(coords_x0), len(coords_x1), len(pre_mask))
|
| 341 |
+
m0 = mask_x0[:n_common]
|
| 342 |
+
m1 = mask_x1[:n_common]
|
| 343 |
+
pm = pre_mask[:n_common]
|
| 344 |
+
common_mask = m0 & m1 & pm
|
| 345 |
+
|
| 346 |
+
if common_mask.sum() < 5:
|
| 347 |
+
logger.warning(f"Too few common residues for {target}/{method}, "
|
| 348 |
+
f"falling back to linear")
|
| 349 |
+
return interpolate_backbone_path(coords_x0, coords_x1,
|
| 350 |
+
mask_x0, mask_x1, n_frames)
|
| 351 |
+
|
| 352 |
+
# Align precomputed frames to the holo receptor (X1) coordinate frame
|
| 353 |
+
# The precomputed frames were generated from the reference apo sequence
|
| 354 |
+
# and may be in a different coordinate frame
|
| 355 |
+
ref_ca = coords_x1[:n_common][common_mask, 1, :] # holo CA as reference
|
| 356 |
+
|
| 357 |
+
path_frames = []
|
| 358 |
+
K = min(len(pre_frames), n_frames)
|
| 359 |
+
|
| 360 |
+
# Select n_frames evenly spaced from available frames
|
| 361 |
+
if len(pre_frames) > n_frames:
|
| 362 |
+
indices = np.linspace(0, len(pre_frames) - 1, n_frames).astype(int)
|
| 363 |
+
else:
|
| 364 |
+
indices = np.arange(K)
|
| 365 |
+
|
| 366 |
+
for idx in indices:
|
| 367 |
+
frame = pre_frames[idx, :n_common].copy() # [N_common, 4, 3]
|
| 368 |
+
tau = float(pre_taus[idx])
|
| 369 |
+
|
| 370 |
+
# Kabsch-align frame CA to holo CA
|
| 371 |
+
frame_ca = frame[common_mask, 1, :]
|
| 372 |
+
R, t_frame, t_ref = _kabsch_align(frame_ca, ref_ca)
|
| 373 |
+
|
| 374 |
+
flat_frame = frame.reshape(-1, 3)
|
| 375 |
+
aligned = (flat_frame - t_frame) @ R.T + t_ref
|
| 376 |
+
frame_aligned = aligned.reshape(n_common, 4, 3)
|
| 377 |
+
|
| 378 |
+
# Reconstruct O
|
| 379 |
+
C_pos = frame_aligned[:, 2, :]
|
| 380 |
+
CA_pos = frame_aligned[:, 1, :]
|
| 381 |
+
C_CA = C_pos - CA_pos
|
| 382 |
+
C_CA_norm = np.linalg.norm(C_CA, axis=-1, keepdims=True)
|
| 383 |
+
C_CA_norm = np.maximum(C_CA_norm, 1e-8)
|
| 384 |
+
frame_aligned[:, 3, :] = C_pos + (C_CA / C_CA_norm) * 1.24
|
| 385 |
+
|
| 386 |
+
path_frames.append((
|
| 387 |
+
frame_aligned.astype(np.float32),
|
| 388 |
+
common_mask.copy(),
|
| 389 |
+
tau,
|
| 390 |
+
))
|
| 391 |
+
|
| 392 |
+
return path_frames
|
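The expected `frames.npz` layout can be produced with `np.savez`. The sketch below writes a dummy file under a temporary directory and reproduces the index subsampling used by the loader; the array sizes and the `cam` / `alphaflow` path components are placeholders chosen for illustration.

```python
import os
import tempfile
import numpy as np

n_saved, n_ref = 8, 12                           # hypothetical sizes
frames = np.zeros((n_saved, n_ref, 4, 3), dtype=np.float32)
taus = np.linspace(0, 1, n_saved + 2)[1:-1].astype(np.float32)
mask = np.ones(n_ref, dtype=bool)

# Layout: {precomputed_dir}/{target}/{method}/frames.npz
root = tempfile.mkdtemp()
out_dir = os.path.join(root, 'cam', 'alphaflow')
os.makedirs(out_dir)
npz_path = os.path.join(out_dir, 'frames.npz')
np.savez(npz_path, frames=frames, taus=taus, mask=mask)

data = np.load(npz_path)

# When more frames are stored than requested, pick evenly spaced indices
n_frames = 5
indices = np.linspace(0, n_saved - 1, n_frames).astype(int)
```

Note that the returned `tau` values come from the stored `taus` array at the selected indices, so they need not be equally spaced in the same way as the linear interpolation path.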
| 393 |
+
|
| 394 |
+
|
| 395 |
+
# ---------------------------------------------------------------------------
|
| 396 |
+
# Unified dispatcher
|
| 397 |
+
# ---------------------------------------------------------------------------
|
| 398 |
+
|
| 399 |
+
def generate_path_frames(coords_x0, coords_x1, mask_x0, mask_x1,
|
| 400 |
+
method='linear', n_frames=5,
|
| 401 |
+
precomputed_dir=None, target=None, **kwargs):
|
| 402 |
+
"""
|
| 403 |
+
Dispatch to method-specific frame generation.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
coords_x0, coords_x1: [N, 4, 3] backbone coords for apo/holo
|
| 407 |
+
mask_x0, mask_x1: [N] bool masks
|
| 408 |
+
method: one of 'linear', 'alphaflow', 'afsample2', 'dsb', 'anm'
|
| 409 |
+
n_frames: number of intermediate frames
|
| 410 |
+
precomputed_dir: directory for precomputed frames (alphaflow/afsample2)
|
| 411 |
+
target: target name (needed for precomputed methods)
|
| 412 |
+
**kwargs: method-specific parameters (sigma, n_modes, etc.)
|
| 413 |
+
|
| 414 |
+
Returns:
|
| 415 |
+
path_frames: list of (coords_tau, mask_tau, tau) tuples
|
| 416 |
+
"""
|
| 417 |
+
if method == 'linear':
|
| 418 |
+
return interpolate_backbone_path(
|
| 419 |
+
coords_x0, coords_x1, mask_x0, mask_x1, n_frames)
|
| 420 |
+
|
| 421 |
+
elif method in ('alphaflow', 'afsample2'):
|
| 422 |
+
if precomputed_dir is None:
|
| 423 |
+
raise ValueError(f"precomputed_dir required for method '{method}'")
|
| 424 |
+
if target is None:
|
| 425 |
+
raise ValueError(f"target name required for method '{method}'")
|
| 426 |
+
return load_precomputed_frames(
|
| 427 |
+
target, method, precomputed_dir,
|
| 428 |
+
coords_x0, coords_x1, mask_x0, mask_x1, n_frames)
|
| 429 |
+
|
| 430 |
+
elif method == 'dsb':
|
| 431 |
+
return dsb_backbone_path(
|
| 432 |
+
coords_x0, coords_x1, mask_x0, mask_x1,
|
| 433 |
+
n_frames=n_frames,
|
| 434 |
+
sigma=kwargs.get('sigma', 0.5),
|
| 435 |
+
n_samples=kwargs.get('n_samples', 20),
|
| 436 |
+
seed=kwargs.get('seed', 42))
|
| 437 |
+
|
| 438 |
+
elif method == 'anm':
|
| 439 |
+
from utils.anm import anm_backbone_path
|
| 440 |
+
return anm_backbone_path(
|
| 441 |
+
coords_x0, coords_x1, mask_x0, mask_x1,
|
| 442 |
+
n_frames=n_frames,
|
| 443 |
+
n_modes=kwargs.get('n_modes', 10),
|
| 444 |
+
cutoff=kwargs.get('cutoff', 15.0))
|
| 445 |
+
|
| 446 |
+
else:
|
| 447 |
+
raise ValueError(f"Unknown path method: '{method}'. "
|
| 448 |
+
f"Choose from: linear, alphaflow, afsample2, dsb, anm")
|
code/utils/pdb_utils.py
ADDED
|
@@ -0,0 +1,472 @@
|
| 1 |
+
"""
|
| 2 |
+
PDB parsing utilities for Allo-Designer.
|
| 3 |
+
Extracts backbone geometry, computes local frames, and identifies interface residues.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from Bio import PDB
|
| 8 |
+
from Bio.PDB import PDBParser, MMCIFParser, PDBIO
|
| 9 |
+
from Bio.PDB.Polypeptide import is_aa
|
| 10 |
+
import warnings
|
| 11 |
+
warnings.filterwarnings("ignore", category=PDB.PDBExceptions.PDBConstructionWarning)
|
| 12 |
+
|
| 13 |
+
AA3_TO_IDX = {
|
| 14 |
+
'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4,
|
| 15 |
+
'GLN': 5, 'GLU': 6, 'GLY': 7, 'HIS': 8, 'ILE': 9,
|
| 16 |
+
'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
|
| 17 |
+
'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19,
|
| 18 |
+
'UNK': 20,
|
| 19 |
+
}
|
| 20 |
+
NUM_AA = 21 # 20 standard + UNK
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_structure(pdb_path: str, model_id: int = 0):
|
| 24 |
+
"""Load a PDB/mmCIF file and return the model at index model_id (default: the first)."""
|
| 25 |
+
if pdb_path.endswith('.cif') or pdb_path.endswith('.mmcif'):
|
| 26 |
+
parser = MMCIFParser(QUIET=True)
|
| 27 |
+
else:
|
| 28 |
+
parser = PDBParser(QUIET=True)
|
| 29 |
+
struct = parser.get_structure("protein", pdb_path)
|
| 30 |
+
return list(struct.get_models())[model_id]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_residues(chain, only_standard: bool = True):
|
| 34 |
+
"""Return a list of standard amino acid residues from a chain."""
|
| 35 |
+
residues = []
|
| 36 |
+
for res in chain.get_residues():
|
| 37 |
+
if only_standard and not is_aa(res, standard=True):
|
| 38 |
+
continue
|
| 39 |
+
if res.get_id()[0] != ' ': # skip HETATM
|
| 40 |
+
continue
|
| 41 |
+
residues.append(res)
|
| 42 |
+
return residues
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_backbone_coords(residues):
|
| 46 |
+
"""
|
| 47 |
+
Extract backbone atom coordinates (N, CA, C, O) for each residue.
|
| 48 |
+
Returns: coords [N_res, 4, 3], mask [N_res] (True = N, CA and C present; a missing O does not clear the mask)
|
| 49 |
+
"""
|
| 50 |
+
N = len(residues)
|
| 51 |
+
coords = np.zeros((N, 4, 3), dtype=np.float32)
|
| 52 |
+
mask = np.zeros(N, dtype=bool)
|
| 53 |
+
|
| 54 |
+
for i, res in enumerate(residues):
|
| 55 |
+
try:
|
| 56 |
+
coords[i, 0] = res['N'].get_vector().get_array()
|
| 57 |
+
coords[i, 1] = res['CA'].get_vector().get_array()
|
| 58 |
+
coords[i, 2] = res['C'].get_vector().get_array()
|
| 59 |
+
if 'O' in res:
|
| 60 |
+
coords[i, 3] = res['O'].get_vector().get_array()
|
| 61 |
+
else:
|
| 62 |
+
# O missing: use the C position as a placeholder
|
| 63 |
+
coords[i, 3] = coords[i, 2]
|
| 64 |
+
mask[i] = True
|
| 65 |
+
except KeyError:
|
| 66 |
+
pass
|
| 67 |
+
return coords, mask
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def get_aa_indices(residues):
|
| 71 |
+
"""Return integer amino acid indices for each residue."""
|
| 72 |
+
return np.array([
|
| 73 |
+
AA3_TO_IDX.get(res.get_resname(), AA3_TO_IDX['UNK'])
|
| 74 |
+
for res in residues
|
| 75 |
+
], dtype=np.int64)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def compute_backbone_frames(coords, mask):
|
| 79 |
+
"""
|
| 80 |
+
Compute SE(3)-equivariant backbone frames from N, CA, C atoms.
|
| 81 |
+
Frame: z-axis = CA->C, y-axis = component of CA->N perpendicular to z, x-axis = y x z.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
origins: [N, 3] = CA positions
|
| 85 |
+
rotations: [N, 3, 3] = rotation matrices (columns are x, y, z axes)
|
| 86 |
+
"""
|
| 87 |
+
N_res = coords.shape[0]
|
| 88 |
+
origins = coords[:, 1, :] # CA positions [N, 3]
|
| 89 |
+
rotations = np.zeros((N_res, 3, 3), dtype=np.float32)
|
| 90 |
+
|
| 91 |
+
for i in range(N_res):
|
| 92 |
+
if not mask[i]:
|
| 93 |
+
rotations[i] = np.eye(3)
|
| 94 |
+
continue
|
| 95 |
+
ca = coords[i, 1]
|
| 96 |
+
n = coords[i, 0]
|
| 97 |
+
c = coords[i, 2]
|
| 98 |
+
|
| 99 |
+
# z-axis: CA -> C
|
| 100 |
+
z = c - ca
|
| 101 |
+
z_norm = np.linalg.norm(z)
|
| 102 |
+
if z_norm < 1e-6:
|
| 103 |
+
rotations[i] = np.eye(3)
|
| 104 |
+
continue
|
| 105 |
+
z = z / z_norm
|
| 106 |
+
|
| 107 |
+
# y-axis: CA -> N, orthogonalized
|
| 108 |
+
y = n - ca
|
| 109 |
+
y = y - np.dot(y, z) * z
|
| 110 |
+
y_norm = np.linalg.norm(y)
|
| 111 |
+
if y_norm < 1e-6:
|
| 112 |
+
rotations[i] = np.eye(3)
|
| 113 |
+
continue
|
| 114 |
+
y = y / y_norm
|
| 115 |
+
|
| 116 |
+
# x-axis: y cross z
|
| 117 |
+
x = np.cross(y, z)
|
| 118 |
+
|
| 119 |
+
rotations[i] = np.stack([x, y, z], axis=-1) # columns are axes
|
| 120 |
+
|
| 121 |
+
return origins, rotations
|
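The frame construction above is a Gram-Schmidt step followed by a cross product. The sketch below runs it for a single residue with made-up coordinates and confirms that the result is a proper rotation.

```python
import numpy as np

# One residue with roughly peptide-like geometry (hypothetical coordinates)
n = np.array([-0.5, 1.4, 0.0])
ca = np.array([0.0, 0.0, 0.0])
c = np.array([1.5, 0.0, 0.0])

z = (c - ca) / np.linalg.norm(c - ca)           # z-axis: CA -> C
y = (n - ca) - np.dot(n - ca, z) * z            # remove the component along z
y = y / np.linalg.norm(y)
x = np.cross(y, z)                              # completes a right-handed triad

rot = np.stack([x, y, z], axis=-1)              # columns are the x, y, z axes

# Coordinates of C in the local frame: it lies on the local z-axis
c_local = rot.T @ (c - ca)
```

Because the columns are orthonormal and `x = y × z`, the determinant is +1, so `rot.T` maps global offsets into the residue's local frame without reflection.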
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def compute_torsion_angles(coords, mask):
|
| 125 |
+
"""
|
| 126 |
+
Compute backbone torsion angles (phi, psi, omega) for each residue.
|
| 127 |
+
Returns sin/cos of each angle. [N, 6]
|
| 128 |
+
"""
|
| 129 |
+
N = len(coords)
|
| 130 |
+
angles = np.zeros((N, 6), dtype=np.float32)
|
| 131 |
+
|
| 132 |
+
def dihedral(p0, p1, p2, p3):
|
| 133 |
+
"""Signed dihedral angle in radians for four points, via atan2 on the plane normals."""
|
| 134 |
+
b1 = p1 - p0
|
| 135 |
+
b2 = p2 - p1
|
| 136 |
+
b3 = p3 - p2
|
| 137 |
+
n1 = np.cross(b1, b2)
|
| 138 |
+
n2 = np.cross(b2, b3)
|
| 139 |
+
n1_norm = np.linalg.norm(n1)
|
| 140 |
+
n2_norm = np.linalg.norm(n2)
|
| 141 |
+
if n1_norm < 1e-6 or n2_norm < 1e-6:
|
| 142 |
+
return 0.0
|
| 143 |
+
n1 = n1 / n1_norm
|
| 144 |
+
n2 = n2 / n2_norm
|
| 145 |
+
m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
|
| 146 |
+
cos_a = np.clip(np.dot(n1, n2), -1, 1)
|
| 147 |
+
sin_a = np.dot(m1, n2)
|
| 148 |
+
return np.arctan2(sin_a, cos_a)
|
| 149 |
+
|
| 150 |
+
for i in range(N):
|
| 151 |
+
if not mask[i]:
|
| 152 |
+
continue
|
| 153 |
+
ca_i = coords[i, 1]
|
| 154 |
+
n_i = coords[i, 0]
|
| 155 |
+
c_i = coords[i, 2]
|
| 156 |
+
|
| 157 |
+
# Phi: C_{i-1} - N_i - CA_i - C_i
|
| 158 |
+
if i > 0 and mask[i - 1]:
|
| 159 |
+
c_prev = coords[i - 1, 2]
|
| 160 |
+
phi = dihedral(c_prev, n_i, ca_i, c_i)
|
| 161 |
+
angles[i, 0] = np.sin(phi)
|
| 162 |
+
angles[i, 1] = np.cos(phi)
|
| 163 |
+
|
| 164 |
+
# Psi: N_i - CA_i - C_i - N_{i+1}
|
| 165 |
+
if i < N - 1 and mask[i + 1]:
|
| 166 |
+
n_next = coords[i + 1, 0]
|
| 167 |
+
psi = dihedral(n_i, ca_i, c_i, n_next)
|
| 168 |
+
angles[i, 2] = np.sin(psi)
|
| 169 |
+
angles[i, 3] = np.cos(psi)
|
| 170 |
+
|
| 171 |
+
# Omega: CA_{i-1} - C_{i-1} - N_i - CA_i
|
| 172 |
+
if i > 0 and mask[i - 1]:
|
| 173 |
+
ca_prev = coords[i - 1, 1]
|
| 174 |
+
c_prev = coords[i - 1, 2]
|
| 175 |
+
omega = dihedral(ca_prev, c_prev, n_i, ca_i)
|
| 176 |
+
angles[i, 4] = np.sin(omega)
|
| 177 |
+
angles[i, 5] = np.cos(omega)
|
| 178 |
+
|
| 179 |
+
return angles
|
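Returning sin/cos pairs rather than raw angles avoids the jump at ±180°, which matters for omega since trans peptide bonds sit near that value. A minimal sketch with two made-up angles on either side of the wrap-around:

```python
import numpy as np

omega_a = np.deg2rad(179.0)
omega_b = np.deg2rad(-179.0)                    # only 2 degrees away on the circle

# Raw angles look far apart because of the wrap-around at +/- pi
raw_gap = abs(omega_a - omega_b)

def encode(angle):
    """(sin, cos) encoding, as stored per torsion in the [N, 6] feature array."""
    return np.array([np.sin(angle), np.cos(angle)])

# In the encoded space the two angles are close: chord length 2 * sin(1 degree)
enc_gap = np.linalg.norm(encode(omega_a) - encode(omega_b))
```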
| 180 |
+
|
| 181 |
+
|
def get_interface_residues(rec_coords, binder_coords, rec_mask, binder_mask, cutoff: float = 8.0):
    """
    Find interface residues: receptor residues within cutoff of any binder Cα, and vice versa.
    Uses CA-CA distances.

    Returns:
        rec_interface: bool array [N_rec]
        binder_interface: bool array [N_binder]
    """
    rec_ca = rec_coords[:, 1, :]  # [N_rec, 3]
    binder_ca = binder_coords[:, 1, :]  # [N_binder, 3]

    # Pairwise CA-CA distances [N_rec, N_binder]
    diff = rec_ca[:, None, :] - binder_ca[None, :, :]  # [N_rec, N_binder, 3]
    dist = np.sqrt((diff ** 2).sum(axis=-1))  # [N_rec, N_binder]

    # Mask out residues without coordinates
    dist[~rec_mask, :] = np.inf
    dist[:, ~binder_mask] = np.inf

    rec_interface = (dist < cutoff).any(axis=1)
    binder_interface = (dist < cutoff).any(axis=0)

    return rec_interface, binder_interface


def align_structures(mobile_ca, ref_ca, mobile_coords=None):
    """
    Kabsch alignment: align mobile to ref using CA positions.
    Returns the aligned CA coords and the rotation matrix R, plus the
    aligned full-backbone coords when mobile_coords is given.
    """
    assert mobile_ca.shape == ref_ca.shape, "Must have same number of residues"

    # Center
    mobile_center = mobile_ca.mean(axis=0)
    ref_center = ref_ca.mean(axis=0)
    m = mobile_ca - mobile_center
    r = ref_ca - ref_center

    # SVD
    H = m.T @ r
    U, S, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1, 1, d])
    R = Vt.T @ D @ U.T  # rotation matrix

    mobile_ca_aligned = (m @ R.T) + ref_center

    if mobile_coords is not None:
        # Apply same rotation to full backbone
        N_res, N_atoms, _ = mobile_coords.shape
        flat = mobile_coords.reshape(-1, 3) - mobile_center
        aligned_flat = (flat @ R.T) + ref_center
        mobile_coords_aligned = aligned_flat.reshape(N_res, N_atoms, 3)
        return mobile_ca_aligned, R, mobile_coords_aligned

    return mobile_ca_aligned, R


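A quick way to sanity-check a Kabsch fit like the one above is to rotate and translate a random point cloud, then confirm that the fit maps it back. The sketch below is standalone: `kabsch` is a hypothetical helper that repeats the same SVD recipe instead of importing from this repository, and the test rotation and translation are arbitrary.

```python
import numpy as np


def kabsch(mobile, ref):
    """Align mobile onto ref; returns (aligned_mobile, R). Illustrative helper only."""
    mobile_center = mobile.mean(axis=0)
    ref_center = ref.mean(axis=0)
    m = mobile - mobile_center
    r = ref - ref_center
    U, _, Vt = np.linalg.svd(m.T @ r)
    # The sign term guards against returning a reflection instead of a rotation.
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return m @ R.T + ref_center, R


rng = np.random.default_rng(0)
ref = rng.normal(size=(20, 3))

# Build the mobile copy: rotate ref by 40 degrees about z, then translate it.
theta = np.deg2rad(40.0)
rot_z = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
mobile = ref @ rot_z.T + np.array([5.0, -2.0, 1.0])

aligned, R = kabsch(mobile, ref)
rmsd = float(np.sqrt(((aligned - ref) ** 2).sum(axis=-1).mean()))
det_R = float(np.linalg.det(R))
print(f"RMSD after alignment: {rmsd:.2e}, det(R) = {det_R:.3f}")
```

Because the mobile cloud is an exact rigid copy of the reference, the recovered `R` should be the inverse of the applied rotation and the residual RMSD should sit at floating-point noise.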
def compute_ca_rmsd(coords1, coords2, mask=None):
    """Compute CA-RMSD between two sets of backbone coordinates."""
    ca1 = coords1[:, 1, :]
    ca2 = coords2[:, 1, :]
    if mask is not None:
        ca1 = ca1[mask]
        ca2 = ca2[mask]
    diff = ca1 - ca2
    return np.sqrt((diff ** 2).sum(axis=-1).mean())


def compute_fraction_native_contacts(
    native_rec_ca, native_binder_ca,
    model_rec_ca=None, model_binder_ca=None,
    cutoff=8.0,
    # Legacy arguments: kept so older call sites still work; unused in this body
    mask=None, delta=1.0,
):
    """
    Compute fraction of native inter-chain contacts (fNAT).

    fNAT = |recovered inter-chain contacts| / |native inter-chain contacts|

    A native contact is a (receptor_i, binder_j) pair with CA-CA distance
    < cutoff in the native complex. A contact is "recovered" if the same
    pair is < cutoff in the model complex.

    Args:
        native_rec_ca: [N_rec, 3] receptor CA coords in native complex
        native_binder_ca: [N_bind, 3] binder CA coords in native complex
        model_rec_ca: [N_rec, 3] receptor CA in model (default: same as native)
        model_binder_ca: [N_bind, 3] binder CA in model (default: same as native)
        cutoff: contact distance threshold in Angstroms (default 8.0 for CA-CA)

    Returns:
        fNAT in [0, 1]. Returns 0.0 if no native contacts exist.
    """
    if model_rec_ca is None:
        model_rec_ca = native_rec_ca
    if model_binder_ca is None:
        model_binder_ca = native_binder_ca

    # Inter-chain distance matrices [N_rec, N_bind]
    native_dist = np.sqrt(
        ((native_rec_ca[:, None, :] - native_binder_ca[None, :, :]) ** 2).sum(-1)
    )
    model_dist = np.sqrt(
        ((model_rec_ca[:, None, :] - model_binder_ca[None, :, :]) ** 2).sum(-1)
    )

    native_contacts = native_dist < cutoff
    recovered = native_contacts & (model_dist < cutoff)

    n_native = native_contacts.sum()
    if n_native == 0:
        return 0.0
    return float(recovered.sum()) / float(n_native)


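The fNAT definition above is easy to verify by hand on a toy complex. The sketch below uses a hypothetical standalone helper, `fnat`, that follows the same contact rule (CA-CA distance below the cutoff in both native and model). The coordinates are invented for illustration and are not taken from the dataset.

```python
import numpy as np


def fnat(native_rec, native_bind, model_rec, model_bind, cutoff=8.0):
    """Fraction of native receptor-binder contacts kept in the model (illustrative)."""
    def pairwise(a, b):
        return np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))

    native_contacts = pairwise(native_rec, native_bind) < cutoff
    recovered = native_contacts & (pairwise(model_rec, model_bind) < cutoff)
    n_native = native_contacts.sum()
    return 0.0 if n_native == 0 else float(recovered.sum()) / float(n_native)


# Two receptor residues, 10 A apart along x.
rec = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]])

# Native: each binder residue sits 5 A above one receptor residue,
# giving two native contacts (the diagonal pairs are about 11.2 A apart).
native_bind = np.array([[0.0, 5.0, 0.0], [10.0, 5.0, 0.0]])

# Model: the second binder residue has drifted to 20 A, so one contact is lost.
model_bind = np.array([[0.0, 5.0, 0.0], [10.0, 20.0, 0.0]])

score = fnat(rec, native_bind, rec, model_bind)
print(score)
```

One of the two native contacts survives in the model, so the value comes out as one half. Passing the native coordinates as the model gives 1.0 by construction.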
def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """
    RBF encoding of distances using Gaussian basis functions.
    Returns: [*distances.shape, n_bins]
    """
    centers = np.linspace(d_min, d_max, n_bins)
    sigma = (d_max - d_min) / (n_bins - 1)
    encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
    return encoded.astype(np.float32)


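To make the encoding concrete: with the defaults, the 16 Gaussian centers are spaced 20/15 ≈ 1.33 Å apart and the width `sigma` equals that spacing. A distance that lands exactly on a center therefore activates its own bin at 1.0 and each neighbouring bin at exp(-0.5) ≈ 0.61. The sketch below repeats the function body so that it runs on its own; the input distances are arbitrary.

```python
import numpy as np


def rbf_encode(distances, d_min=0.0, d_max=20.0, n_bins=16):
    """Standalone copy of the encoder above, repeated so this sketch is runnable."""
    centers = np.linspace(d_min, d_max, n_bins)
    sigma = (d_max - d_min) / (n_bins - 1)
    encoded = np.exp(-((distances[..., None] - centers) ** 2) / (2 * sigma ** 2))
    return encoded.astype(np.float32)


# A 2x2 distance matrix becomes a 2x2x16 feature tensor.
d = np.array([[0.0, 4.0], [8.0, 20.0]])
feats = rbf_encode(d)

# Index of the most active bin for every entry.
peak_bins = feats.argmax(axis=-1)
print(feats.shape, peak_bins.tolist())
```

Distances beyond `d_max` are not clipped, so every bin's activation decays toward zero for far-apart pairs.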
# Candidate sidechain atoms for chi1 (first atom after CB)
_CHI1_ATOMS = ['CG', 'CG1', 'OG', 'OG1', 'SG']
# Candidate sidechain atoms for chi2 (second dihedral: CA-CB-XG-XD)
_CHI2_ATOMS = ['CD', 'CD1', 'SD', 'OD1', 'ND1', 'CE', 'NE', 'OE1']


def _dihedral_4pts(p0, p1, p2, p3):
    """Compute dihedral angle between four 3D points (radians)."""
    b1 = p1 - p0
    b2 = p2 - p1
    b3 = p3 - p2
    n1 = np.cross(b1, b2)
    n2 = np.cross(b2, b3)
    n1_norm = np.linalg.norm(n1)
    n2_norm = np.linalg.norm(n2)
    if n1_norm < 1e-6 or n2_norm < 1e-6:
        return 0.0
    n1 = n1 / n1_norm
    n2 = n2 / n2_norm
    m1 = np.cross(n1, b2 / (np.linalg.norm(b2) + 1e-8))
    return np.arctan2(np.dot(m1, n2), np.dot(n1, n2))


def compute_chi_angles(residues, mask):
    """
    Compute chi1 and chi2 sidechain torsion angles for each residue.

    Chi1: N - CA - CB - XG (first sidechain dihedral)
    Chi2: CA - CB - XG - XD (second sidechain dihedral)

    For residues lacking the atoms (Gly, or missing coordinates), returns zeros.

    Returns:
        chi_feats: [N, 4] (sin_chi1, cos_chi1, sin_chi2, cos_chi2)
    """
    N = len(residues)
    chi_feats = np.zeros((N, 4), dtype=np.float32)

    for i, res in enumerate(residues):
        if not mask[i]:
            continue
        atoms = {atom.get_name(): atom.get_vector().get_array() for atom in res.get_atoms()
                 if atom.get_name() in ('N', 'CA', 'CB') + tuple(_CHI1_ATOMS) + tuple(_CHI2_ATOMS)}

        n_pos = atoms.get('N')
        ca_pos = atoms.get('CA')
        cb_pos = atoms.get('CB')

        if n_pos is None or ca_pos is None or cb_pos is None:
            continue

        # Chi1: N - CA - CB - XG
        xg_pos = None
        for aname in _CHI1_ATOMS:
            if aname in atoms:
                xg_pos = atoms[aname]
                break

        if xg_pos is not None:
            chi1 = _dihedral_4pts(np.array(n_pos), np.array(ca_pos),
                                  np.array(cb_pos), np.array(xg_pos))
            chi_feats[i, 0] = np.sin(chi1)
            chi_feats[i, 1] = np.cos(chi1)

        # Chi2: CA - CB - XG - XD
        xd_pos = None
        for aname in _CHI2_ATOMS:
            if aname in atoms:
                xd_pos = atoms[aname]
                break

        # Chi2 also needs XG; skip it when no gamma atom was found above
        if xg_pos is not None and xd_pos is not None:
            chi2 = _dihedral_4pts(np.array(ca_pos), np.array(cb_pos),
                                  np.array(xg_pos), np.array(xd_pos))
            chi_feats[i, 2] = np.sin(chi2)
            chi_feats[i, 3] = np.cos(chi2)

    return chi_feats


def get_cb_positions(residues, coords, mask):
    """
    Return CB positions for each residue (CA position for Gly or missing CB).

    Returns:
        cb_pos: [N, 3]
    """
    N = len(residues)
    cb_pos = coords[:, 1, :].copy()  # default to CA

    for i, res in enumerate(residues):
        if not mask[i]:
            continue
        try:
            cb_pos[i] = res['CB'].get_vector().get_array()
        except KeyError:
            pass  # Gly or missing CB: keep CA

    return cb_pos.astype(np.float32)


# Simplified hydrophobicity groups for contact energy
_HYDROPHOBIC = {'ALA', 'VAL', 'ILE', 'LEU', 'MET', 'PHE', 'TRP', 'PRO', 'TYR'}
_POS_CHARGED = {'ARG', 'LYS', 'HIS'}
_NEG_CHARGED = {'ASP', 'GLU'}


def _residue_group(resname):
    if resname in _HYDROPHOBIC:
        return 'H'
    if resname in _POS_CHARGED:
        return '+'
    if resname in _NEG_CHARGED:
        return '-'
    return 'P'  # polar


def compute_contact_energy(rec_residues, binder_residues,
                           rec_cb, binder_cb,
                           rec_mask, binder_mask,
                           cutoff: float = 8.0):
    """
    Compute a simple CB-CB contact energy as a physics-based ddG proxy.

    Uses a 4-group hydrophobicity potential:
        HH: -1.0 (hydrophobic-hydrophobic, favorable)
        +-: -0.5 (opposite charges, favorable)
        H+/-: +0.3 (hydrophobic-charged, unfavorable)
        else: 0.0

    Returns a scalar in [0, 1] via sigmoid normalization.
    """
    n_rec = len(rec_residues)
    n_binder = len(binder_residues)

    # CB-CB distance matrix [n_rec, n_binder]
    diff = rec_cb[:, None, :] - binder_cb[None, :, :]  # [n_rec, n_binder, 3]
    dist = np.sqrt((diff ** 2).sum(axis=-1))  # [n_rec, n_binder]

    # Mask invalid residues
    dist[~rec_mask, :] = np.inf
    dist[:, ~binder_mask] = np.inf

    contact_mask = dist < cutoff

    energy = 0.0
    for i in range(n_rec):
        for j in range(n_binder):
            if not contact_mask[i, j]:
                continue
            gi = _residue_group(rec_residues[i].get_resname())
            gj = _residue_group(binder_residues[j].get_resname())
            if gi == 'H' and gj == 'H':
                energy -= 1.0
            elif (gi == '+' and gj == '-') or (gi == '-' and gj == '+'):
                energy -= 0.5
            elif (gi == 'H' and gj in ('+', '-')) or (gj == 'H' and gi in ('+', '-')):
                energy += 0.3

    # Normalize: sigmoid of (energy - 5) / 5, so 0 contacts gives a score of about 0.27;
    # more negative (more favorable) energies push the score toward 0
    score = 1.0 / (1.0 + np.exp(-(energy - 5.0) / 5.0))
    return float(score)
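The sigmoid at the end of `compute_contact_energy` is easier to reason about with a few numbers plugged in. The sketch below isolates just that last line in a hypothetical helper, `contact_score`, and evaluates it for three invented energies. As written, the score falls as the energy becomes more favorable (more negative). Whether downstream code expects that orientation is not visible in this diff.

```python
import math


def contact_score(energy):
    """Same normalization as the final line of compute_contact_energy (illustrative copy)."""
    return 1.0 / (1.0 + math.exp(-(energy - 5.0) / 5.0))


no_contacts = contact_score(0.0)     # sigmoid(-1.0)
ten_hh_pairs = contact_score(-10.0)  # ten hydrophobic pairs: sigmoid(-3.0)
clashing = contact_score(3.0)        # ten hydrophobic-charged pairs: sigmoid(-0.4)

for label, value in [("no contacts", no_contacts),
                     ("10 HH contacts", ten_hh_pairs),
                     ("10 H/charged contacts", clashing)]:
    print(f"{label:>22s}: {value:.3f}")
```

An interface with no contacts at all scores roughly 0.27, a strongly hydrophobic interface scores lower, and one dominated by hydrophobic-charged pairs scores higher.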
code/utils/sam.py
ADDED
@@ -0,0 +1,54 @@
"""
Sharpness-Aware Minimization (SAM) optimizer wrapper.
Seeks parameters in flatter minima for better OOD generalization.
Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021)
"""

import torch


class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self):
        # Ascent step: move each parameter to w + e(w), with e(w) = rho * grad / ||grad||
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        # Restore w, then let the base optimizer step with the gradients taken at w + e(w)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])
        self.base_optimizer.step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]['params'][0].device
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2).to(shared_device)
                for group in self.param_groups
                for p in group['params']
                if p.grad is not None
            ]),
            p=2,
        )
        return norm

    def step(self, closure=None):
        raise NotImplementedError("SAM requires manual first_step() and second_step() calls")

    def zero_grad(self):
        self.base_optimizer.zero_grad()
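Because `step()` raises, a training loop has to call `first_step()` and `second_step()` itself, with one forward/backward pass before each. The sketch below shows one way that loop could look on a toy regression problem. It carries a condensed copy of the wrapper so that it runs on its own. The model, data, learning rate and iteration count are arbitrary choices for illustration, not the settings used in `code/trainers/trainer.py`.

```python
import torch


class SAM(torch.optim.Optimizer):
    """Condensed copy of the wrapper above, repeated so this sketch is self-contained."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        super().__init__(params, dict(rho=rho, **kwargs))
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

    @torch.no_grad()
    def first_step(self):
        norms = [p.grad.norm(p=2) for g in self.param_groups
                 for p in g['params'] if p.grad is not None]
        grad_norm = torch.norm(torch.stack(norms), p=2)
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)  # climb to w + e(w)
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])  # back to w
        self.base_optimizer.step()

    def step(self, closure=None):
        raise NotImplementedError("use first_step() and second_step()")

    def zero_grad(self):
        self.base_optimizer.zero_grad()


torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # toy target: sum of the inputs
loss_fn = torch.nn.MSELoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1)

losses = []
for _ in range(50):
    # Pass 1: gradients at w, then perturb the weights.
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.first_step()
    optimizer.zero_grad()

    # Pass 2: gradients at w + e(w); second_step restores w and applies the update.
    loss_fn(model(x), y).backward()
    optimizer.second_step()
    optimizer.zero_grad()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

Each iteration costs two forward/backward passes, which is the usual price of SAM. Clearing the gradients between the two passes matters because `backward()` would otherwise accumulate the second gradient on top of the first.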
data/sample/README.md
ADDED
@@ -0,0 +1,49 @@
# Sample dataset (in-repo)

This directory ships a single pre-built target, **`cam` (Calmodulin)**, so users
can run a smoke test of the training and evaluation pipeline without first
downloading the full multi-target dataset (~10 GB on Zenodo) or rebuilding
from raw PDB files (~30 min per target).

## Contents

```
sample/
└── cam/
    ├── train.pkl   # 84 paired holo/apo complex graphs (~24 MB)
    ├── val.pkl     # 12 validation graphs (~1.3 MB)
    └── test.pkl    # 96 held-out evaluation graphs (~25 MB)
```

Each pickle is a list of dicts produced by `code/data/build_dataset.py`.
Splits follow the family-stratified scheme used in the paper
(equivalent to `data/processed_familysplit/cam/` train+val and
`data/processed_familysplit_v5/cam/test.pkl` in the source tree).

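A small script can confirm that a split loaded correctly before launching training. The sketch below assumes only what is stated above, namely that each file unpickles to a list of dicts. The field names inside each dict are not documented here, so the hypothetical helper `summarize_split` just lists whatever keys it finds. If the records contain tensors or custom classes, `torch` and the repository code may need to be importable first, and files tracked by Git LFS must be fetched (for example with `git lfs pull`) before they can be read.

```python
import pickle
from pathlib import Path


def summarize_split(path):
    """Load one split and report its size and the keys of its first record."""
    with open(path, "rb") as fh:
        records = pickle.load(fh)  # only unpickle files from a source you trust
    first_keys = sorted(records[0].keys()) if records else []
    return {"n_records": len(records), "first_keys": first_keys}


if __name__ == "__main__":
    split = Path("data/sample/cam/val.pkl")
    if split.exists():
        print(summarize_split(split))
    else:
        print(f"{split} not found; run this from the repository root")
```

The record counts should match the numbers in the tree above (84, 12 and 96).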
## Smoke test (1-epoch end-to-end)

```bash
# Train both phases for 1 epoch
python code/scripts/train.py \
    --target cam \
    --phase both \
    --data_dir data/sample \
    --checkpoint_dir checkpoints_smoke \
    --epochs 1 \
    --no_wandb

# Evaluate
python code/scripts/evaluate.py \
    --target cam \
    --checkpoint checkpoints_smoke/best_phase2.pt \
    --data_dir data/sample \
    --outdir eval_smoke
```

Expected runtime: ~1 minute on a single GPU.

## Want more data?

- All 12 paper targets, pre-built: see `data/DOWNLOAD.md` for the Zenodo link.
- Build from raw PDBs locally: `scripts/build_data.sh paper12`.
- Per-target PDB lists and chain mappings: `data/target_lists/*.txt` (68 targets).
data/sample/cam/test.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:38e36d092bbcf4222c762e351fe305e8627f47c78c4acda74170c650ac09e1e8
size 25608454
data/sample/esm2_embeddings/cam/1IWQ_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c52e5a467dc0e73ef7139475bdaafd05e6df8872c345e31fb3da1d35497c00a1
size 712791
data/sample/esm2_embeddings/cam/1IWQ_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7afc0e91e9095d3945adbf4fcf287fd167ac4fed51d10718db65ee80b89890e
size 93271
data/sample/esm2_embeddings/cam/1K93_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ef76f07da93f60f77ee00fbd719fd81621e50492d245bf62f1a820140e853d61
size 2484311
data/sample/esm2_embeddings/cam/1K93_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c9d65d07aaa8b1c219f5458d2a0502739f482b698c5579bbb4ebee681b5aecb
size 2392151
data/sample/esm2_embeddings/cam/1NWD_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4138c902d0ca0e40ff0c8b0522f71974ba15d3ec1a11db2bb00f9fe8227339f9
size 758871
data/sample/esm2_embeddings/cam/1NWD_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee429a3b3cbaedf0471f02f7d636e4737f98630f655402317dd3438f8c69c30d
size 144471
data/sample/esm2_embeddings/cam/1SY9_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8a7ca8538945c4b1e524ef3440c5aa7f73afd5d273034063518a0251e2a59f01
size 758871
data/sample/esm2_embeddings/cam/1SY9_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6258dc842103adc893aacee23c195e2b7da3038e4704e0c9166e1e5581ac784
size 98391
data/sample/esm2_embeddings/cam/2BBM_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0630d3cbc267244f3a40ee83155373675f654cfd720aeaf03271c6438cb68b1d
size 758871
data/sample/esm2_embeddings/cam/2BBM_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:485ff76a5d7ed5459035608a25322601d8fc3d1b73acd49a13d1dcb1fec22ac3
size 134231
data/sample/esm2_embeddings/cam/2HQW_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8b37dc4704f3c83e4c1d723da66d1873a8dbf27a6e9313450312b9e847f9f44
size 707671
data/sample/esm2_embeddings/cam/2HQW_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:564220203f56b067cda236c35cd153d3b20dcaf0fb2080fadb664eed8f413607
size 113751
data/sample/esm2_embeddings/cam/2O5G_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1187a8edd2d25b219c7ca4cc970a6a771fe8605ab2937163dc3ee595fad97bab
size 753751
data/sample/esm2_embeddings/cam/2O5G_B.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:604f44b0738f0c9d0f3e1410187a94d95dd7002cead28e23635ddd33bf419b60
size 98391
data/sample/esm2_embeddings/cam/3D33_A.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4c553c3e7dfe746babf92dfcae09dab8ecbcf1490cf475aead72fc5f2d30a43
size 472151