Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- LICENSE +9 -0
- README.md +91 -168
- __init__.py +12 -0
- configuration_esmc.py +89 -0
- configuration_esmc_sae.py +77 -0
- configuration_esmfold2.py +298 -0
- esmfold2_affine3d.py +561 -0
- esmfold2_aligner.py +102 -0
- esmfold2_atom_indexer.py +16 -0
- esmfold2_conformers.py +292 -0
- esmfold2_constants.py +563 -0
- esmfold2_constants_esm3.py +138 -0
- esmfold2_input_builder.py +255 -0
- esmfold2_metrics.py +374 -0
- esmfold2_misc.py +505 -0
- esmfold2_mmcif_parsing.py +470 -0
- esmfold2_molecular_complex.py +1226 -0
- esmfold2_msa.py +507 -0
- esmfold2_msa_filter_sequences.py +83 -0
- esmfold2_normalize_coordinates.py +80 -0
- esmfold2_output.py +225 -0
- esmfold2_paired_msa.py +246 -0
- esmfold2_parsing.py +113 -0
- esmfold2_predicted_aligned_error.py +105 -0
- esmfold2_prepare_input.py +1464 -0
- esmfold2_processor.py +356 -0
- esmfold2_protein_chain.py +1376 -0
- esmfold2_protein_complex.py +1241 -0
- esmfold2_protein_structure.py +307 -0
- esmfold2_residue_constants.py +1224 -0
- esmfold2_sequential_dataclass.py +158 -0
- esmfold2_system.py +46 -0
- esmfold2_types.py +34 -0
- esmfold2_utils_types.py +34 -0
- modeling_esmc.py +1667 -0
- modeling_esmc_sae.py +363 -0
- modeling_esmfold2.py +1288 -0
- modeling_esmfold2_common.py +0 -0
- protein_utils.py +488 -0
LICENSE
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**License (MIT)**
|
| 2 |
+
|
| 3 |
+
Copyright 2026 Chan Zuckerberg Biohub, Inc.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
| 6 |
+
|
| 7 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
| 8 |
+
|
| 9 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
CHANGED
|
@@ -1,199 +1,122 @@
|
|
| 1 |
---
|
| 2 |
library_name: transformers
|
| 3 |
-
tags:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
---
|
| 5 |
|
| 6 |
-
#
|
| 7 |
|
| 8 |
-
|
| 9 |
|
|
|
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
-
##
|
| 41 |
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
|
|
|
| 53 |
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
|
| 58 |
-
##
|
| 59 |
|
| 60 |
-
|
|
|
|
| 61 |
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
## Training Details
|
| 77 |
-
|
| 78 |
-
### Training Data
|
| 79 |
-
|
| 80 |
-
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
-
|
| 82 |
-
[More Information Needed]
|
| 83 |
-
|
| 84 |
-
### Training Procedure
|
| 85 |
-
|
| 86 |
-
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
-
|
| 88 |
-
#### Preprocessing [optional]
|
| 89 |
-
|
| 90 |
-
[More Information Needed]
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
#### Training Hyperparameters
|
| 94 |
-
|
| 95 |
-
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
-
|
| 97 |
-
#### Speeds, Sizes, Times [optional]
|
| 98 |
-
|
| 99 |
-
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
-
|
| 101 |
-
[More Information Needed]
|
| 102 |
-
|
| 103 |
-
## Evaluation
|
| 104 |
-
|
| 105 |
-
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
-
|
| 107 |
-
### Testing Data, Factors & Metrics
|
| 108 |
-
|
| 109 |
-
#### Testing Data
|
| 110 |
-
|
| 111 |
-
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
-
|
| 113 |
-
[More Information Needed]
|
| 114 |
-
|
| 115 |
-
#### Factors
|
| 116 |
-
|
| 117 |
-
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
-
|
| 119 |
-
[More Information Needed]
|
| 120 |
-
|
| 121 |
-
#### Metrics
|
| 122 |
-
|
| 123 |
-
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
-
|
| 125 |
-
[More Information Needed]
|
| 126 |
-
|
| 127 |
-
### Results
|
| 128 |
-
|
| 129 |
-
[More Information Needed]
|
| 130 |
-
|
| 131 |
-
#### Summary
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
## Model Examination [optional]
|
| 136 |
-
|
| 137 |
-
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
-
|
| 139 |
-
[More Information Needed]
|
| 140 |
-
|
| 141 |
-
## Environmental Impact
|
| 142 |
-
|
| 143 |
-
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
-
|
| 145 |
-
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
-
|
| 147 |
-
- **Hardware Type:** [More Information Needed]
|
| 148 |
-
- **Hours used:** [More Information Needed]
|
| 149 |
-
- **Cloud Provider:** [More Information Needed]
|
| 150 |
-
- **Compute Region:** [More Information Needed]
|
| 151 |
-
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
-
|
| 153 |
-
## Technical Specifications [optional]
|
| 154 |
-
|
| 155 |
-
### Model Architecture and Objective
|
| 156 |
-
|
| 157 |
-
[More Information Needed]
|
| 158 |
-
|
| 159 |
-
### Compute Infrastructure
|
| 160 |
-
|
| 161 |
-
[More Information Needed]
|
| 162 |
-
|
| 163 |
-
#### Hardware
|
| 164 |
-
|
| 165 |
-
[More Information Needed]
|
| 166 |
-
|
| 167 |
-
#### Software
|
| 168 |
-
|
| 169 |
-
[More Information Needed]
|
| 170 |
-
|
| 171 |
-
## Citation [optional]
|
| 172 |
-
|
| 173 |
-
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
-
|
| 175 |
-
**BibTeX:**
|
| 176 |
-
|
| 177 |
-
[More Information Needed]
|
| 178 |
-
|
| 179 |
-
**APA:**
|
| 180 |
-
|
| 181 |
-
[More Information Needed]
|
| 182 |
-
|
| 183 |
-
## Glossary [optional]
|
| 184 |
-
|
| 185 |
-
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
-
|
| 187 |
-
[More Information Needed]
|
| 188 |
-
|
| 189 |
-
## More Information [optional]
|
| 190 |
-
|
| 191 |
-
[More Information Needed]
|
| 192 |
-
|
| 193 |
-
## Model Card Authors [optional]
|
| 194 |
-
|
| 195 |
-
[More Information Needed]
|
| 196 |
-
|
| 197 |
-
## Model Card Contact
|
| 198 |
-
|
| 199 |
-
[More Information Needed]
|
|
|
|
| 1 |
---
|
| 2 |
library_name: transformers
|
| 3 |
+
tags:
|
| 4 |
+
- biology
|
| 5 |
+
- protein-structure
|
| 6 |
+
- esmfold2
|
| 7 |
+
- multimodal-protein-model
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# FastPLMs ESMFold2
|
| 11 |
|
| 12 |
+
FastPLMs ESMFold2 is a self-contained Hugging Face `AutoModel` wrapper for Biohub's ESMFold2 and ESMFold2-Fast structure predictors. It vendors the released Biohub ESMFold2 model code, ESMC backbone code, input builder, MSA helpers, and structure export utilities needed for remote-code loading.
|
| 13 |
|
| 14 |
+
## Load With AutoModel
|
| 15 |
|
| 16 |
+
```python
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import AutoModel
|
| 19 |
|
| 20 |
+
model = AutoModel.from_pretrained(
|
| 21 |
+
"Synthyra/ESMFold2-Fast",
|
| 22 |
+
trust_remote_code=True,
|
| 23 |
+
dtype=torch.bfloat16,
|
| 24 |
+
device_map="cuda",
|
| 25 |
+
).eval()
|
| 26 |
+
```
|
| 27 |
|
| 28 |
+
Use `Synthyra/ESMFold2` for the full model and `Synthyra/ESMFold2-Fast` for the faster release variant.
|
| 29 |
|
| 30 |
+
## Fold One Protein
|
| 31 |
|
| 32 |
+
```python
|
| 33 |
+
sequence = "MKTLLILAVVAAALA"
|
| 34 |
|
| 35 |
+
result = model.fold_protein(
|
| 36 |
+
sequence,
|
| 37 |
+
num_loops=3,
|
| 38 |
+
num_sampling_steps=50,
|
| 39 |
+
num_diffusion_samples=1,
|
| 40 |
+
seed=0,
|
| 41 |
+
)
|
| 42 |
|
| 43 |
+
print(float(result.plddt.mean()))
|
| 44 |
+
print(float(result.ptm))
|
| 45 |
+
```
|
| 46 |
|
| 47 |
+
## Save mmCIF or PDB
|
| 48 |
|
| 49 |
+
```python
|
| 50 |
+
model.save_as_cif(result, "prediction.cif")
|
| 51 |
+
model.save_as_pdb(result, "prediction.pdb")
|
| 52 |
|
| 53 |
+
cif_text = model.result_to_cif(result)
|
| 54 |
+
pdb_text = model.result_to_pdb(result)
|
| 55 |
+
```
|
| 56 |
|
| 57 |
+
`result_to_cif` preserves the full `MolecularComplex`. `result_to_pdb` converts through Biohub's protein-only `ProteinComplex` representation, so use mmCIF for complexes with ligands or nucleic acids.
|
| 58 |
|
| 59 |
+
## Fold Complexes
|
| 60 |
|
| 61 |
+
```python
|
| 62 |
+
types = model.input_types
|
| 63 |
|
| 64 |
+
complex_input = types.StructurePredictionInput(
|
| 65 |
+
sequences=[
|
| 66 |
+
types.ProteinInput(id="A", sequence="MKTLLILAVVAAALA"),
|
| 67 |
+
types.DNAInput(id="B", sequence="GATAGC"),
|
| 68 |
+
types.LigandInput(id="L", ccd=["SAH"]),
|
| 69 |
+
]
|
| 70 |
+
)
|
| 71 |
|
| 72 |
+
result = model.fold(
|
| 73 |
+
complex_input,
|
| 74 |
+
num_loops=3,
|
| 75 |
+
num_sampling_steps=50,
|
| 76 |
+
num_diffusion_samples=1,
|
| 77 |
+
seed=0,
|
| 78 |
+
)
|
| 79 |
|
| 80 |
+
model.save_as_cif(result, "complex_prediction.cif")
|
| 81 |
+
```
|
| 82 |
|
| 83 |
+
## Use MSAs
|
| 84 |
|
| 85 |
+
```python
|
| 86 |
+
types = model.input_types
|
| 87 |
|
| 88 |
+
msa = types.MSA.from_a3m("query.a3m", max_sequences=128)
|
| 89 |
+
input_with_msa = types.StructurePredictionInput(
|
| 90 |
+
sequences=[
|
| 91 |
+
types.ProteinInput(id="A", sequence=msa.query, msa=msa),
|
| 92 |
+
]
|
| 93 |
+
)
|
| 94 |
|
| 95 |
+
result = model.fold(input_with_msa, num_sampling_steps=50, seed=0)
|
| 96 |
+
```
|
| 97 |
|
| 98 |
+
## Raw Tensor Inference
|
| 99 |
|
| 100 |
+
```python
|
| 101 |
+
features, chain_infos = model.prepare_structure_input(complex_input, seed=0)
|
| 102 |
|
| 103 |
+
with torch.inference_mode():
|
| 104 |
+
output = model(
|
| 105 |
+
**features,
|
| 106 |
+
num_loops=3,
|
| 107 |
+
num_sampling_steps=50,
|
| 108 |
+
num_diffusion_samples=1,
|
| 109 |
+
)
|
| 110 |
|
| 111 |
+
decoded = model.input_builder.decode(output, features, chain_infos)
|
| 112 |
+
```
|
| 113 |
|
| 114 |
+
Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESMC backbone:
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
model = AutoModel.from_pretrained(
|
| 118 |
+
"Synthyra/ESMFold2-Fast",
|
| 119 |
+
trust_remote_code=True,
|
| 120 |
+
load_esmc=False,
|
| 121 |
+
).cuda().eval()
|
| 122 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
from .configuration_esmfold2 import ESMFold2Config
|
| 5 |
+
from .modeling_esmfold2 import ESMFold2Model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def ensure_vendored_esm() -> None:
|
| 9 |
+
sys.modules["esm"] = importlib.import_module(f"{__name__}.esm")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
__all__ = ["ESMFold2Config", "ESMFold2Model", "ensure_vendored_esm"]
|
configuration_esmc.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""ESMC model configuration."""
|
| 15 |
+
|
| 16 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ESMCConfig(PretrainedConfig):
|
| 20 |
+
"""
|
| 21 |
+
This is the configuration class to store the configuration of a [`ESMCModel`]. It is used to
|
| 22 |
+
instantiate an ESMC model according to the specified arguments, defining the model architecture.
|
| 23 |
+
|
| 24 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
|
| 25 |
+
outputs. Read the documentation from [`PretrainedConfig`] for more information.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
vocab_size (`int`, *optional*, defaults to 64):
|
| 29 |
+
Vocabulary size of the ESMC model. Defines the number of different amino acid tokens that
|
| 30 |
+
can be represented by the ``input_ids`` passed to [`ESMCModel`].
|
| 31 |
+
d_model (`int`, *optional*, defaults to 2560):
|
| 32 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 33 |
+
n_heads (`int`, *optional*, defaults to 40):
|
| 34 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 35 |
+
n_layers (`int`, *optional*, defaults to 80):
|
| 36 |
+
Number of hidden layers in the Transformer encoder.
|
| 37 |
+
pad_token_id (`int`, *optional*, defaults to 1):
|
| 38 |
+
Index of the padding token in the vocabulary (``"<pad>"``).
|
| 39 |
+
mask_token_id (`int`, *optional*, defaults to 32):
|
| 40 |
+
Index of the mask token in the vocabulary (``"<mask>"``), used for masked language modelling.
|
| 41 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 42 |
+
The standard deviation of the truncated normal initialiser for weight matrix initialisation.
|
| 43 |
+
classifier_dropout (`float`, *optional*, defaults to 0.1):
|
| 44 |
+
Dropout ratio for the classification head.
|
| 45 |
+
|
| 46 |
+
Examples:
|
| 47 |
+
|
| 48 |
+
```python
|
| 49 |
+
>>> from transformers import ESMCConfig, ESMCModel
|
| 50 |
+
|
| 51 |
+
>>> # Initializing an ESMC EvolutionaryScale/esmc-600m-2024-12 style configuration
|
| 52 |
+
>>> configuration = ESMCConfig()
|
| 53 |
+
|
| 54 |
+
>>> # Initializing a model (with random weights) from the EvolutionaryScale/esmc-600m-2024-12 style configuration
|
| 55 |
+
>>> model = ESMCModel(configuration)
|
| 56 |
+
|
| 57 |
+
>>> # Accessing the model configuration
|
| 58 |
+
>>> configuration = model.config
|
| 59 |
+
```
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
model_type = "esmc"
|
| 63 |
+
|
| 64 |
+
def __init__(
|
| 65 |
+
self,
|
| 66 |
+
vocab_size: int = 64,
|
| 67 |
+
d_model: int = 2560,
|
| 68 |
+
n_heads: int = 40,
|
| 69 |
+
n_layers: int = 80,
|
| 70 |
+
pad_token_id: int = 1,
|
| 71 |
+
mask_token_id: int = 32,
|
| 72 |
+
initializer_range: float = 0.02,
|
| 73 |
+
classifier_dropout: float = 0.1,
|
| 74 |
+
**kwargs,
|
| 75 |
+
):
|
| 76 |
+
super().__init__(
|
| 77 |
+
pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
self.vocab_size = vocab_size
|
| 81 |
+
self.d_model = d_model
|
| 82 |
+
self.n_heads = n_heads
|
| 83 |
+
self.n_layers = n_layers
|
| 84 |
+
self.initializer_range = initializer_range
|
| 85 |
+
self.classifier_dropout = classifier_dropout
|
| 86 |
+
self.tie_word_embeddings = False
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
__all__ = ["ESMCConfig"]
|
configuration_esmc_sae.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""ESMC sparse autoencoder (SAE) configuration."""
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
|
| 18 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class ESMCSAEParams:
|
| 23 |
+
"""Parameters for one backbone layer's SAE inside :class:`ESMCSAEModel`.
|
| 24 |
+
|
| 25 |
+
The SAE itself is an internal ``nn.Module``; this dataclass just bundles
|
| 26 |
+
the handful of fields needed to instantiate one.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
d_model: int = 2560
|
| 30 |
+
codebook_dim: int = 65536
|
| 31 |
+
k: int = 64
|
| 32 |
+
layer: int = 0
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ESMCSAEConfig(PretrainedConfig):
|
| 36 |
+
"""
|
| 37 |
+
Configuration class for [`ESMCSAEModel`] — a container that holds one
|
| 38 |
+
SAE per backbone layer for a fixed ``(model, codebook_dim, k)`` group.
|
| 39 |
+
|
| 40 |
+
All SAEs in a container share ``d_model``, ``codebook_dim``, and ``k``;
|
| 41 |
+
they differ only in the backbone layer they were trained on.
|
| 42 |
+
``available_layers`` lists the backbone-layer indices the repo ships;
|
| 43 |
+
each entry ``i`` is stored on disk as ``layer_{i}.safetensors`` (the
|
| 44 |
+
filename index *is* the backbone layer, so a single-layer repo for
|
| 45 |
+
layer 23 stores ``layer_23.safetensors``).
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
d_model (`int`, *optional*, defaults to 2560):
|
| 49 |
+
Dimensionality of the ESMC hidden states fed into the SAEs.
|
| 50 |
+
codebook_dim (`int`, *optional*, defaults to 65536):
|
| 51 |
+
Number of sparse features in each SAE's codebook.
|
| 52 |
+
k (`int`, *optional*, defaults to 64):
|
| 53 |
+
Top-k sparsity per SAE.
|
| 54 |
+
available_layers (`list[int]`, *optional*, defaults to ``[0]``):
|
| 55 |
+
Which backbone-layer indices the repo ships.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
model_type = "esmc_sae"
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
d_model: int = 2560,
|
| 63 |
+
codebook_dim: int = 65536,
|
| 64 |
+
k: int = 64,
|
| 65 |
+
available_layers: list[int] | None = None,
|
| 66 |
+
**kwargs,
|
| 67 |
+
):
|
| 68 |
+
super().__init__(**kwargs)
|
| 69 |
+
self.d_model = d_model
|
| 70 |
+
self.codebook_dim = codebook_dim
|
| 71 |
+
self.k = k
|
| 72 |
+
self.available_layers = (
|
| 73 |
+
list(available_layers) if available_layers is not None else [0]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
__all__ = ["ESMCSAEConfig", "ESMCSAEParams"]
|
configuration_esmfold2.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""ESMFold2 model configuration."""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import asdict, dataclass, field
|
| 19 |
+
|
| 20 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 21 |
+
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
# Nested dataclass configs
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
|
| 26 |
+
_DEFAULT_ESMC_HF_REPO = "biohub/ESMC-6B"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class MSAEncoderConfig:
|
| 31 |
+
"""Config for the optional MSA encoder module (Large MSA models only)."""
|
| 32 |
+
|
| 33 |
+
enabled: bool = False
|
| 34 |
+
d_msa: int = 128
|
| 35 |
+
d_hidden: int = 32
|
| 36 |
+
n_layers: int = 4
|
| 37 |
+
n_heads_msa: int = 8
|
| 38 |
+
msa_head_width: int = 32
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
|
| 42 |
+
class ParcaeConfig:
|
| 43 |
+
"""Release-only config for the parcae diffusion-loop scheduler."""
|
| 44 |
+
|
| 45 |
+
enabled: bool = True
|
| 46 |
+
poisson_mean: float = 3.0
|
| 47 |
+
min_steps: int = 1
|
| 48 |
+
max_steps: int | None = 6
|
| 49 |
+
coda_n_layers: int = 2
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class LMEncoderConfig:
|
| 54 |
+
"""Release-only config for the LM-side pair encoder."""
|
| 55 |
+
|
| 56 |
+
enabled: bool = True
|
| 57 |
+
n_layers: int = 4
|
| 58 |
+
lm_dropout: float = 0.25
|
| 59 |
+
per_loop_lm_dropout: bool = True
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
|
| 63 |
+
class AtomAttentionConfig:
|
| 64 |
+
"""Config for SWA atom encoder/decoder with 3D RoPE."""
|
| 65 |
+
|
| 66 |
+
d_atom: int = 128
|
| 67 |
+
d_token: int = 768
|
| 68 |
+
n_blocks: int = 3
|
| 69 |
+
n_heads: int = 4
|
| 70 |
+
swa_window_size: int = 128
|
| 71 |
+
expansion_ratio: int = 2
|
| 72 |
+
# 3D RoPE config
|
| 73 |
+
spatial_rope_base_frequency: float = 20.0
|
| 74 |
+
n_spatial_rope_pairs_per_axis: int = 2
|
| 75 |
+
n_uid_rope_pairs: int = 10
|
| 76 |
+
uid_rope_base_frequency: float = 10000.0
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
|
| 80 |
+
class FoldingTrunkConfig:
|
| 81 |
+
n_layers: int = 24
|
| 82 |
+
n_heads: int = 8
|
| 83 |
+
dropout: float = 0.0
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
|
| 87 |
+
class InputsEmbedderConfig:
|
| 88 |
+
d_inputs: int = 451
|
| 89 |
+
atom_encoder: AtomAttentionConfig = field(default_factory=AtomAttentionConfig)
|
| 90 |
+
|
| 91 |
+
def __post_init__(self):
|
| 92 |
+
if isinstance(self.atom_encoder, dict):
|
| 93 |
+
self.atom_encoder = AtomAttentionConfig(**self.atom_encoder)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@dataclass
|
| 97 |
+
class DiffusionModuleConfig:
|
| 98 |
+
"""Config for the DiffusionModule."""
|
| 99 |
+
|
| 100 |
+
sigma_data: float = 16.0
|
| 101 |
+
c_atom: int = 128
|
| 102 |
+
c_token: int = 768
|
| 103 |
+
c_z: int = 256
|
| 104 |
+
c_s_inputs: int = 451
|
| 105 |
+
fourier_dim: int = 256
|
| 106 |
+
relpos_r_max: int = 32
|
| 107 |
+
relpos_s_max: int = 2
|
| 108 |
+
atom_num_blocks: int = 3
|
| 109 |
+
atom_num_heads: int = 4
|
| 110 |
+
token_num_blocks: int = 12
|
| 111 |
+
token_num_heads: int = 16
|
| 112 |
+
transition_multiplier: int = 2
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@dataclass
|
| 116 |
+
class DiffusionStructureHeadConfig:
|
| 117 |
+
"""Config for the diffusion-based structure prediction head."""
|
| 118 |
+
|
| 119 |
+
diffusion_module: DiffusionModuleConfig = field(
|
| 120 |
+
default_factory=DiffusionModuleConfig
|
| 121 |
+
)
|
| 122 |
+
distogram_bins: int = 128
|
| 123 |
+
|
| 124 |
+
# Training noise: sigma ~ sigma_data * exp(mu + sigma * N(0,1))
|
| 125 |
+
train_noise_log_mean: float = -1.2
|
| 126 |
+
train_noise_log_std: float = 1.5
|
| 127 |
+
|
| 128 |
+
# Sampling defaults (ODE)
|
| 129 |
+
gamma_0: float = 0.605
|
| 130 |
+
gamma_min: float = 1.107
|
| 131 |
+
noise_scale: float = 0.0
|
| 132 |
+
step_scale: float = 1.0
|
| 133 |
+
|
| 134 |
+
# Inference schedule defaults
|
| 135 |
+
inference_s_max: float = 160.0
|
| 136 |
+
inference_s_min: float = 4e-4
|
| 137 |
+
inference_p: float = 8.0
|
| 138 |
+
inference_num_steps: int = 68
|
| 139 |
+
|
| 140 |
+
def __post_init__(self):
|
| 141 |
+
if isinstance(self.diffusion_module, dict):
|
| 142 |
+
self.diffusion_module = DiffusionModuleConfig(**self.diffusion_module)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
|
| 146 |
+
class ConfidenceHeadConfig:
|
| 147 |
+
enabled: bool = True
|
| 148 |
+
num_plddt_bins: int = 50
|
| 149 |
+
num_pde_bins: int = 64
|
| 150 |
+
num_pae_bins: int = 64
|
| 151 |
+
min_dist: float = 2.0
|
| 152 |
+
max_dist: float = 52.0
|
| 153 |
+
distogram_bins: int = 128
|
| 154 |
+
folding_trunk: FoldingTrunkConfig = field(
|
| 155 |
+
default_factory=lambda: FoldingTrunkConfig(n_layers=4)
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
if isinstance(self.folding_trunk, dict):
|
| 160 |
+
self.folding_trunk = FoldingTrunkConfig(**self.folding_trunk)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ---------------------------------------------------------------------------
|
| 164 |
+
# Top-level config
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ESMFold2Config(PretrainedConfig):
|
| 169 |
+
"""
|
| 170 |
+
Configuration for the ESMFold2 structure prediction model.
|
| 171 |
+
|
| 172 |
+
Uses SWA atom encoders with 3D RoPE, a diffusion transformer,
|
| 173 |
+
a folding trunk, and an ESMC 6B PLM backbone.
|
| 174 |
+
|
| 175 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control
|
| 176 |
+
the model outputs. Read the documentation from [`PretrainedConfig`] for more
|
| 177 |
+
information.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
d_single (`int`, defaults to 384):
|
| 181 |
+
Dimensionality of single (per-residue) representations.
|
| 182 |
+
d_pair (`int`, defaults to 256):
|
| 183 |
+
Dimensionality of pair (residue-residue) representations.
|
| 184 |
+
n_relative_residx_bins (`int`, defaults to 32):
|
| 185 |
+
Number of bins for relative residue index encoding.
|
| 186 |
+
n_relative_chain_bins (`int`, defaults to 2):
|
| 187 |
+
Number of bins for relative chain encoding.
|
| 188 |
+
num_loops (`int`, defaults to 10):
|
| 189 |
+
Number of trunk loops for iterative refinement.
|
| 190 |
+
num_diffusion_samples (`int`, defaults to 8):
|
| 191 |
+
Number of parallel structure predictions to generate.
|
| 192 |
+
lm_dropout (`float`, defaults to 0.0):
|
| 193 |
+
Dropout probability on LM pair embeddings. When > 0, dropout is
|
| 194 |
+
applied with ``training=True`` (including at inference) to match
|
| 195 |
+
the experimental training recipe used by binder design.
|
| 196 |
+
force_lm_dropout_during_inference (`bool`, defaults to False):
|
| 197 |
+
When True, apply ``lm_dropout`` even when ``model.eval()`` and
|
| 198 |
+
``lm_dropout`` > 0. Binder-design loads set this to True.
|
| 199 |
+
disable_msa_features (`bool`, defaults to False):
|
| 200 |
+
When True, zero out MSA-derived ``profile`` and ``deletion_mean``
|
| 201 |
+
before the inputs embedder (experimental medium/large checkpoints).
|
| 202 |
+
inputs (`InputsEmbedderConfig`):
|
| 203 |
+
Configuration for the inputs embedder module.
|
| 204 |
+
folding_trunk (`FoldingTrunkConfig`):
|
| 205 |
+
Configuration for the folding trunk.
|
| 206 |
+
structure_head (`DiffusionStructureHeadConfig`):
|
| 207 |
+
Configuration for the diffusion-based structure prediction head.
|
| 208 |
+
confidence_head (`ConfidenceHeadConfig`):
|
| 209 |
+
Configuration for the confidence prediction head.
|
| 210 |
+
|
| 211 |
+
Examples:
|
| 212 |
+
|
| 213 |
+
```python
|
| 214 |
+
>>> from transformers import ESMFold2Config, ESMFold2ExperimentalModel
|
| 215 |
+
|
| 216 |
+
>>> # Initializing an ESMFold2 configuration
|
| 217 |
+
>>> configuration = ESMFold2Config(type="experimental")
|
| 218 |
+
|
| 219 |
+
>>> # Initializing a model (with random weights) from the configuration
|
| 220 |
+
>>> model = ESMFold2ExperimentalModel(configuration)
|
| 221 |
+
|
| 222 |
+
>>> # Accessing the model configuration
|
| 223 |
+
>>> configuration = model.config
|
| 224 |
+
```
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
model_type = "esmfold2"
|
| 228 |
+
has_no_defaults_at_init = True
|
| 229 |
+
|
| 230 |
+
def __init__(self, **kwargs):
|
| 231 |
+
super().__init__(**kwargs)
|
| 232 |
+
|
| 233 |
+
self.type: str = kwargs.get("type", "release")
|
| 234 |
+
if self.type not in ("release", "experimental"):
|
| 235 |
+
raise ValueError(
|
| 236 |
+
f"ESMFold2Config.type must be 'release' or 'experimental', "
|
| 237 |
+
f"got {self.type!r}"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# Top-level scalar fields
|
| 241 |
+
self.d_single: int = kwargs.get("d_single", 384)
|
| 242 |
+
self.d_pair: int = kwargs.get("d_pair", 256)
|
| 243 |
+
self.n_relative_residx_bins: int = kwargs.get("n_relative_residx_bins", 32)
|
| 244 |
+
self.n_relative_chain_bins: int = kwargs.get("n_relative_chain_bins", 2)
|
| 245 |
+
self.num_loops: int = kwargs.get("num_loops", 10)
|
| 246 |
+
self.num_diffusion_samples: int = kwargs.get("num_diffusion_samples", 8)
|
| 247 |
+
# If True, ``profile`` / ``deletion_mean`` are zeroed before the inputs
|
| 248 |
+
# embedder.
|
| 249 |
+
self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
|
| 250 |
+
self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
|
| 251 |
+
self.force_lm_dropout_during_inference: bool = kwargs.get(
|
| 252 |
+
"force_lm_dropout_during_inference", False
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
|
| 256 |
+
self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
|
| 257 |
+
# Required, no default — every shipped HF export must name its ESMC backbone.
|
| 258 |
+
self.esmc_id: str = kwargs.get("esmc_id", _DEFAULT_ESMC_HF_REPO)
|
| 259 |
+
|
| 260 |
+
def _init_nested(cls, val):
|
| 261 |
+
if isinstance(val, cls):
|
| 262 |
+
return val
|
| 263 |
+
if isinstance(val, dict):
|
| 264 |
+
return cls(**val)
|
| 265 |
+
return cls()
|
| 266 |
+
|
| 267 |
+
self.inputs = _init_nested(InputsEmbedderConfig, kwargs.get("inputs"))
|
| 268 |
+
self.folding_trunk = _init_nested(
|
| 269 |
+
FoldingTrunkConfig, kwargs.get("folding_trunk")
|
| 270 |
+
)
|
| 271 |
+
self.structure_head = _init_nested(
|
| 272 |
+
DiffusionStructureHeadConfig, kwargs.get("structure_head")
|
| 273 |
+
)
|
| 274 |
+
self.confidence_head = _init_nested(
|
| 275 |
+
ConfidenceHeadConfig, kwargs.get("confidence_head")
|
| 276 |
+
)
|
| 277 |
+
self.msa_encoder = _init_nested(MSAEncoderConfig, kwargs.get("msa_encoder"))
|
| 278 |
+
# Release-only modules — ignored when ``type == "experimental"``.
|
| 279 |
+
self.parcae = _init_nested(ParcaeConfig, kwargs.get("parcae"))
|
| 280 |
+
self.lm_encoder = _init_nested(LMEncoderConfig, kwargs.get("lm_encoder"))
|
| 281 |
+
# If True, MSA encoder output replaces the pair stream; if False, it is added.
|
| 282 |
+
self.msa_encoder_overwrite: bool = bool(
|
| 283 |
+
kwargs.get("msa_encoder_overwrite", True)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
def to_dict(self):
|
| 287 |
+
output = super().to_dict()
|
| 288 |
+
output["inputs"] = asdict(self.inputs)
|
| 289 |
+
output["folding_trunk"] = asdict(self.folding_trunk)
|
| 290 |
+
output["structure_head"] = asdict(self.structure_head)
|
| 291 |
+
output["confidence_head"] = asdict(self.confidence_head)
|
| 292 |
+
output["msa_encoder"] = asdict(self.msa_encoder)
|
| 293 |
+
output["parcae"] = asdict(self.parcae)
|
| 294 |
+
output["lm_encoder"] = asdict(self.lm_encoder)
|
| 295 |
+
return output
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
__all__ = ["ESMFold2Config", "MSAEncoderConfig", "ParcaeConfig", "LMEncoderConfig"]
|
esmfold2_affine3d.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import typing as T
|
| 4 |
+
from abc import ABC
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from typing_extensions import Self
|
| 10 |
+
|
| 11 |
+
from .esmfold2_misc import fp32_autocast_context
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Rotation(ABC):
|
| 15 |
+
@classmethod
|
| 16 |
+
def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
|
| 17 |
+
|
| 18 |
+
@classmethod
|
| 19 |
+
def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
|
| 20 |
+
|
| 21 |
+
def __getitem__(self, idx: T.Any) -> Self: ...
|
| 22 |
+
|
| 23 |
+
@property
|
| 24 |
+
def tensor(self) -> torch.Tensor:
|
| 25 |
+
# We claim that this should be zero-cost abstraction that returns the raw tensor backing this
|
| 26 |
+
# object. The raw tensor should always have exactly 1 more dim than self.shape, which should be
|
| 27 |
+
# implemented using reshaping
|
| 28 |
+
...
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def shape(self) -> torch.Size:
|
| 32 |
+
# The "shape" of the rotation, as if it was a torch.tensor object
|
| 33 |
+
# This means that 1x4 quaternions are treated as size (1,) for example
|
| 34 |
+
...
|
| 35 |
+
|
| 36 |
+
def as_matrix(self) -> RotationMatrix: ...
|
| 37 |
+
|
| 38 |
+
def as_quat(self, normalize: bool = False) -> RotationQuat: ...
|
| 39 |
+
|
| 40 |
+
def compose(self, other: Self) -> Self:
|
| 41 |
+
# To be safe, we force users to explicitly convert between rotation types.
|
| 42 |
+
...
|
| 43 |
+
|
| 44 |
+
def convert_compose(self, other: Self) -> Self:
|
| 45 |
+
# This function will automatically convert between types of rotations
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
def apply(self, p: torch.Tensor) -> torch.Tensor:
|
| 49 |
+
# rotates points by this rotation object
|
| 50 |
+
...
|
| 51 |
+
|
| 52 |
+
def invert(self) -> Self: ...
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def dtype(self) -> torch.dtype:
|
| 56 |
+
return self.tensor.dtype
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def device(self) -> torch.device:
|
| 60 |
+
return self.tensor.device
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def requires_grad(self) -> bool:
|
| 64 |
+
return self.tensor.requires_grad
|
| 65 |
+
|
| 66 |
+
@classmethod
|
| 67 |
+
def _from_tensor(cls, t: torch.Tensor) -> Self:
|
| 68 |
+
# This function exists to simplify the below functions, esp type signatures
|
| 69 |
+
# Its implementation is different from Affine3D.from_tensor and does not
|
| 70 |
+
# autodetect rotation types.
|
| 71 |
+
return cls(t) # type: ignore
|
| 72 |
+
|
| 73 |
+
def to(self, **kwargs) -> Self:
|
| 74 |
+
return self._from_tensor(self.tensor.to(**kwargs))
|
| 75 |
+
|
| 76 |
+
def detach(self, *args, **kwargs) -> Self:
|
| 77 |
+
return self._from_tensor(self.tensor.detach(**kwargs))
|
| 78 |
+
|
| 79 |
+
def tensor_apply(self, func) -> Self:
|
| 80 |
+
# Applys a function to the underlying tensor
|
| 81 |
+
return self._from_tensor(
|
| 82 |
+
torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1)
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class RotationMatrix(Rotation):
|
| 87 |
+
def __init__(self, rots: torch.Tensor):
|
| 88 |
+
if rots.shape[-1] == 9:
|
| 89 |
+
rots = rots.unflatten(-1, (3, 3))
|
| 90 |
+
assert rots.shape[-1] == 3
|
| 91 |
+
assert rots.shape[-2] == 3
|
| 92 |
+
# Force full precision
|
| 93 |
+
rots = rots.to(torch.float32)
|
| 94 |
+
self._rots = rots
|
| 95 |
+
|
| 96 |
+
@classmethod
|
| 97 |
+
def identity(cls, shape, **tensor_kwargs):
|
| 98 |
+
rots = torch.eye(3, **tensor_kwargs)
|
| 99 |
+
rots = rots.view(*[1 for _ in range(len(shape))], 3, 3)
|
| 100 |
+
rots = rots.expand(*shape, -1, -1)
|
| 101 |
+
return cls(rots)
|
| 102 |
+
|
| 103 |
+
@classmethod
|
| 104 |
+
def random(cls, shape, **tensor_kwargs):
|
| 105 |
+
return RotationQuat.random(shape, **tensor_kwargs).as_matrix()
|
| 106 |
+
|
| 107 |
+
def __getitem__(self, idx: T.Any) -> RotationMatrix:
|
| 108 |
+
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
| 109 |
+
return RotationMatrix(self._rots[indices + (slice(None), slice(None))])
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def shape(self) -> torch.Size:
|
| 113 |
+
return self._rots.shape[:-2]
|
| 114 |
+
|
| 115 |
+
def as_matrix(self) -> RotationMatrix:
|
| 116 |
+
return self
|
| 117 |
+
|
| 118 |
+
def as_quat(self, normalize: bool = False) -> RotationQuat:
|
| 119 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
| 120 |
+
self._rots.flatten(-2), dim=-1
|
| 121 |
+
)
|
| 122 |
+
q_abs = _sqrt_subgradient(
|
| 123 |
+
torch.stack(
|
| 124 |
+
[
|
| 125 |
+
1.0 + m00 + m11 + m22,
|
| 126 |
+
1.0 + m00 - m11 - m22,
|
| 127 |
+
1.0 - m00 + m11 - m22,
|
| 128 |
+
1.0 - m00 - m11 + m22,
|
| 129 |
+
],
|
| 130 |
+
dim=-1,
|
| 131 |
+
)
|
| 132 |
+
)
|
| 133 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 134 |
+
quat_by_rijk = torch.stack(
|
| 135 |
+
[
|
| 136 |
+
x
|
| 137 |
+
for lst in [
|
| 138 |
+
[q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
|
| 139 |
+
[m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
|
| 140 |
+
[m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
|
| 141 |
+
[m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
|
| 142 |
+
]
|
| 143 |
+
for x in lst
|
| 144 |
+
],
|
| 145 |
+
dim=-1,
|
| 146 |
+
).unflatten(-1, (4, 4))
|
| 147 |
+
|
| 148 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 149 |
+
# the candidate won't be picked.
|
| 150 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 151 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 152 |
+
|
| 153 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 154 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 155 |
+
# We manually implement one_hot so torch.compile works
|
| 156 |
+
one_hot = torch.zeros_like(q_abs, dtype=torch.bool)
|
| 157 |
+
one_hot.scatter_(-1, q_abs.argmax(dim=-1, keepdim=True), True)
|
| 158 |
+
quat = quat_candidates[one_hot, :].reshape(q_abs.shape)
|
| 159 |
+
return RotationQuat(quat)
|
| 160 |
+
|
| 161 |
+
def compose(self, other: RotationMatrix) -> RotationMatrix:
|
| 162 |
+
with fp32_autocast_context(self._rots.device.type):
|
| 163 |
+
return RotationMatrix(self._rots @ other._rots)
|
| 164 |
+
|
| 165 |
+
def convert_compose(self, other: Rotation):
|
| 166 |
+
return self.compose(other.as_matrix())
|
| 167 |
+
|
| 168 |
+
def apply(self, p: torch.Tensor) -> torch.Tensor:
|
| 169 |
+
with fp32_autocast_context(self.device.type):
|
| 170 |
+
if self._rots.shape[-3] == 1:
|
| 171 |
+
# This is a slight speedup over einsum for batched rotations
|
| 172 |
+
return p @ self._rots.transpose(-1, -2).squeeze(-3)
|
| 173 |
+
else:
|
| 174 |
+
# einsum way faster than bmm!
|
| 175 |
+
return torch.einsum("...ij,...j", self._rots, p)
|
| 176 |
+
|
| 177 |
+
def invert(self) -> RotationMatrix:
|
| 178 |
+
return RotationMatrix(self._rots.transpose(-1, -2))
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def tensor(self) -> torch.Tensor:
|
| 182 |
+
return self._rots.flatten(-2)
|
| 183 |
+
|
| 184 |
+
def to_3x3(self) -> torch.Tensor:
|
| 185 |
+
return self._rots
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def from_graham_schmidt(
|
| 189 |
+
x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12
|
| 190 |
+
) -> RotationMatrix:
|
| 191 |
+
# A low eps here is necessary for good stability!
|
| 192 |
+
return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class RotationQuat(Rotation):
|
| 196 |
+
def __init__(self, quats: torch.Tensor, normalized=False):
|
| 197 |
+
assert quats.shape[-1] == 4
|
| 198 |
+
self._normalized = normalized
|
| 199 |
+
# Force float32 as well
|
| 200 |
+
if normalized:
|
| 201 |
+
self._quats = F.normalize(quats.to(torch.float32), dim=-1)
|
| 202 |
+
self._quats = self._quats.where(self._quats[..., :1] >= 0, -self._quats)
|
| 203 |
+
else:
|
| 204 |
+
self._quats = quats.to(torch.float32)
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def identity(cls, shape, **tensor_kwargs):
|
| 208 |
+
q = torch.ones((*shape, 4), **tensor_kwargs)
|
| 209 |
+
mult = torch.tensor([1, 0, 0, 0], device=q.device)
|
| 210 |
+
return RotationQuat(q * mult)
|
| 211 |
+
|
| 212 |
+
@classmethod
|
| 213 |
+
def random(cls, shape, **tensor_kwargs):
|
| 214 |
+
quat = torch.randn((*shape, 4), **tensor_kwargs)
|
| 215 |
+
return RotationQuat(quat, normalized=True)
|
| 216 |
+
|
| 217 |
+
def __getitem__(self, idx: T.Any) -> RotationQuat:
|
| 218 |
+
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
| 219 |
+
return RotationQuat(self._quats[indices + (slice(None),)])
|
| 220 |
+
|
| 221 |
+
@property
|
| 222 |
+
def shape(self) -> torch.Size:
|
| 223 |
+
return self._quats.shape[:-1]
|
| 224 |
+
|
| 225 |
+
def compose(self, other: RotationQuat) -> RotationQuat:
|
| 226 |
+
with fp32_autocast_context(self._quats.device.type):
|
| 227 |
+
return RotationQuat(_quat_mult(self._quats, other._quats))
|
| 228 |
+
|
| 229 |
+
def convert_compose(self, other: Rotation):
|
| 230 |
+
return self.compose(other.as_quat())
|
| 231 |
+
|
| 232 |
+
def as_matrix(self) -> RotationMatrix:
|
| 233 |
+
q = self.normalized().tensor
|
| 234 |
+
r, i, j, k = torch.unbind(q, -1)
|
| 235 |
+
two_s = 2.0 / torch.linalg.norm(q, dim=-1)
|
| 236 |
+
|
| 237 |
+
o = torch.stack(
|
| 238 |
+
(
|
| 239 |
+
1 - two_s * (j * j + k * k),
|
| 240 |
+
two_s * (i * j - k * r),
|
| 241 |
+
two_s * (i * k + j * r),
|
| 242 |
+
two_s * (i * j + k * r),
|
| 243 |
+
1 - two_s * (i * i + k * k),
|
| 244 |
+
two_s * (j * k - i * r),
|
| 245 |
+
two_s * (i * k - j * r),
|
| 246 |
+
two_s * (j * k + i * r),
|
| 247 |
+
1 - two_s * (i * i + j * j),
|
| 248 |
+
),
|
| 249 |
+
-1,
|
| 250 |
+
)
|
| 251 |
+
return RotationMatrix(o.reshape(q.shape[:-1] + (3, 3)))
|
| 252 |
+
|
| 253 |
+
def as_quat(self, normalize: bool = False) -> RotationQuat:
|
| 254 |
+
return self
|
| 255 |
+
|
| 256 |
+
def apply(self, p: torch.Tensor) -> torch.Tensor:
|
| 257 |
+
return _quat_rotation(self.normalized()._quats, p)
|
| 258 |
+
|
| 259 |
+
def invert(self) -> RotationQuat:
|
| 260 |
+
return RotationQuat(_quat_invert(self._quats))
|
| 261 |
+
|
| 262 |
+
@property
|
| 263 |
+
def tensor(self) -> torch.Tensor:
|
| 264 |
+
return self._quats
|
| 265 |
+
|
| 266 |
+
def normalized(self) -> RotationQuat:
|
| 267 |
+
return self if self._normalized else RotationQuat(self._quats, normalized=True)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
@dataclass(frozen=True)
|
| 271 |
+
class Affine3D:
|
| 272 |
+
trans: torch.Tensor
|
| 273 |
+
rot: Rotation
|
| 274 |
+
|
| 275 |
+
def __post_init__(self):
|
| 276 |
+
assert self.trans.shape[:-1] == self.rot.shape
|
| 277 |
+
|
| 278 |
+
@staticmethod
|
| 279 |
+
def identity(
|
| 280 |
+
shape_or_affine: T.Union[tuple[int, ...], "Affine3D"],
|
| 281 |
+
rotation_type: T.Type[Rotation] = RotationMatrix,
|
| 282 |
+
**tensor_kwargs,
|
| 283 |
+
):
|
| 284 |
+
# Creates a new identity Affine3D object with a specified shape
|
| 285 |
+
# or the same shape as another Affine3D object.
|
| 286 |
+
if isinstance(shape_or_affine, Affine3D):
|
| 287 |
+
kwargs = {"dtype": shape_or_affine.dtype, "device": shape_or_affine.device}
|
| 288 |
+
kwargs.update(tensor_kwargs)
|
| 289 |
+
shape = shape_or_affine.shape
|
| 290 |
+
rotation_type = type(shape_or_affine.rot)
|
| 291 |
+
else:
|
| 292 |
+
kwargs = tensor_kwargs
|
| 293 |
+
shape = shape_or_affine
|
| 294 |
+
return Affine3D(
|
| 295 |
+
torch.zeros((*shape, 3), **kwargs), rotation_type.identity(shape, **kwargs)
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
@staticmethod
|
| 299 |
+
def random(
|
| 300 |
+
shape: tuple[int, ...],
|
| 301 |
+
std: float = 1,
|
| 302 |
+
rotation_type: T.Type[Rotation] = RotationMatrix,
|
| 303 |
+
**tensor_kwargs,
|
| 304 |
+
) -> "Affine3D":
|
| 305 |
+
return Affine3D(
|
| 306 |
+
trans=torch.randn((*shape, 3), **tensor_kwargs).mul(std),
|
| 307 |
+
rot=rotation_type.random(shape, **tensor_kwargs),
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
def __getitem__(self, idx: T.Any) -> "Affine3D":
|
| 311 |
+
indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
|
| 312 |
+
return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx])
|
| 313 |
+
|
| 314 |
+
@property
|
| 315 |
+
def shape(self) -> torch.Size:
|
| 316 |
+
return self.trans.shape[:-1]
|
| 317 |
+
|
| 318 |
+
@property
|
| 319 |
+
def dtype(self) -> torch.dtype:
|
| 320 |
+
return self.trans.dtype
|
| 321 |
+
|
| 322 |
+
@property
|
| 323 |
+
def device(self) -> torch.device:
|
| 324 |
+
return self.trans.device
|
| 325 |
+
|
| 326 |
+
@property
|
| 327 |
+
def requires_grad(self) -> bool:
|
| 328 |
+
return self.trans.requires_grad
|
| 329 |
+
|
| 330 |
+
def to(self, **kwargs) -> "Affine3D":
|
| 331 |
+
return Affine3D(self.trans.to(**kwargs), self.rot.to(**kwargs))
|
| 332 |
+
|
| 333 |
+
def detach(self, *args, **kwargs) -> "Affine3D":
|
| 334 |
+
return Affine3D(self.trans.detach(**kwargs), self.rot.detach(**kwargs))
|
| 335 |
+
|
| 336 |
+
def tensor_apply(self, func) -> "Affine3D":
|
| 337 |
+
# Applys a function to the underlying tensor
|
| 338 |
+
return self.from_tensor(
|
| 339 |
+
torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1)
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
def as_matrix(self):
|
| 343 |
+
return Affine3D(trans=self.trans, rot=self.rot.as_matrix())
|
| 344 |
+
|
| 345 |
+
def as_quat(self, normalize: bool = False):
|
| 346 |
+
return Affine3D(trans=self.trans, rot=self.rot.as_quat(normalize))
|
| 347 |
+
|
| 348 |
+
def compose(self, other: "Affine3D", autoconvert: bool = False):
|
| 349 |
+
rot = self.rot
|
| 350 |
+
new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot)
|
| 351 |
+
new_trans = rot.apply(other.trans) + self.trans
|
| 352 |
+
return Affine3D(trans=new_trans, rot=new_rot)
|
| 353 |
+
|
| 354 |
+
def compose_rotation(self, other: Rotation, autoconvert: bool = False):
|
| 355 |
+
return Affine3D(
|
| 356 |
+
trans=self.trans,
|
| 357 |
+
rot=(self.rot.convert_compose if autoconvert else self.rot.compose)(other),
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
def scale(self, v: torch.Tensor | float):
|
| 361 |
+
return Affine3D(self.trans * v, self.rot)
|
| 362 |
+
|
| 363 |
+
def mask(self, mask: torch.Tensor, with_zero=False):
|
| 364 |
+
# Returns a transform where True positions in mask is identity
|
| 365 |
+
if with_zero:
|
| 366 |
+
tensor = self.tensor
|
| 367 |
+
return Affine3D.from_tensor(
|
| 368 |
+
torch.zeros_like(tensor).where(mask[..., None], tensor)
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
identity = self.identity(
|
| 372 |
+
self.shape,
|
| 373 |
+
rotation_type=type(self.rot),
|
| 374 |
+
device=self.device,
|
| 375 |
+
dtype=self.dtype,
|
| 376 |
+
).tensor
|
| 377 |
+
return Affine3D.from_tensor(identity.where(mask[..., None], self.tensor))
|
| 378 |
+
|
| 379 |
+
def apply(self, p: torch.Tensor) -> torch.Tensor:
|
| 380 |
+
return self.rot.apply(p) + self.trans
|
| 381 |
+
|
| 382 |
+
def invert(self):
|
| 383 |
+
inv_rot = self.rot.invert()
|
| 384 |
+
return Affine3D(trans=-inv_rot.apply(self.trans), rot=inv_rot)
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def tensor(self) -> torch.Tensor:
|
| 388 |
+
return torch.cat([self.rot.tensor, self.trans], dim=-1)
|
| 389 |
+
|
| 390 |
+
@staticmethod
|
| 391 |
+
def from_tensor(t: torch.Tensor) -> "Affine3D":
|
| 392 |
+
match t.shape[-1]:
|
| 393 |
+
case 4:
|
| 394 |
+
# Assume tensor 4x4 for backward compat with alphafold
|
| 395 |
+
trans = t[..., :3, 3]
|
| 396 |
+
rot = RotationMatrix(t[..., :3, :3])
|
| 397 |
+
case 6:
|
| 398 |
+
# Assume quaternion representation with real part = 1
|
| 399 |
+
trans = t[..., -3:]
|
| 400 |
+
rot = RotationQuat(F.pad(t[..., :3], (1, 0), value=1))
|
| 401 |
+
case 7:
|
| 402 |
+
trans = t[..., -3:]
|
| 403 |
+
rot = RotationQuat(t[..., :4])
|
| 404 |
+
case 12:
|
| 405 |
+
trans = t[..., -3:]
|
| 406 |
+
rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3)))
|
| 407 |
+
case _:
|
| 408 |
+
raise RuntimeError(
|
| 409 |
+
f"Cannot detect rotation fromat from {t.shape[-1] -3}-d flat vector"
|
| 410 |
+
)
|
| 411 |
+
return Affine3D(trans, rot)
|
| 412 |
+
|
| 413 |
+
@staticmethod
|
| 414 |
+
def from_tensor_pair(t: torch.Tensor, r: torch.Tensor) -> "Affine3D":
|
| 415 |
+
return Affine3D(t, RotationMatrix(r))
|
| 416 |
+
|
| 417 |
+
@staticmethod
|
| 418 |
+
def from_graham_schmidt(
|
| 419 |
+
neg_x_axis: torch.Tensor,
|
| 420 |
+
origin: torch.Tensor,
|
| 421 |
+
xy_plane: torch.Tensor,
|
| 422 |
+
eps: float = 1e-10,
|
| 423 |
+
):
|
| 424 |
+
# The arguments of this function is for parity with AlphaFold
|
| 425 |
+
x_axis = origin - neg_x_axis
|
| 426 |
+
xy_plane = xy_plane - origin
|
| 427 |
+
return Affine3D(
|
| 428 |
+
trans=origin, rot=RotationMatrix.from_graham_schmidt(x_axis, xy_plane, eps)
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
@staticmethod
|
| 432 |
+
def cat(affines: list["Affine3D"], dim: int = 0):
|
| 433 |
+
if dim < 0:
|
| 434 |
+
dim = len(affines[0].shape) + dim
|
| 435 |
+
return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim))
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def _quat_mult(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 439 |
+
"""
|
| 440 |
+
Multiply two quaternions.
|
| 441 |
+
Usual torch rules for broadcasting apply.
|
| 442 |
+
|
| 443 |
+
Args:
|
| 444 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
| 445 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
| 449 |
+
"""
|
| 450 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
| 451 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
| 452 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
| 453 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
| 454 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
| 455 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
| 456 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def _quat_rotation(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
| 460 |
+
"""
|
| 461 |
+
Rotates p by quaternion q. Usual torch rules for broadcasting apply.
|
| 462 |
+
|
| 463 |
+
Args:
|
| 464 |
+
q: Quaternions as tensor of shape (..., 4), real part first.
|
| 465 |
+
p: Points as tensor of shape (..., 3)
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
The rotated version of p, of shape (..., 3)
|
| 469 |
+
"""
|
| 470 |
+
aw, ax, ay, az = torch.unbind(q, -1)
|
| 471 |
+
bx, by, bz = torch.unbind(p, -1)
|
| 472 |
+
# fmt: off
|
| 473 |
+
ow = - ax * bx - ay * by - az * bz
|
| 474 |
+
ox = aw * bx + ay * bz - az * by
|
| 475 |
+
oy = aw * by - ax * bz + az * bx
|
| 476 |
+
oz = aw * bz + ax * by - ay * bx
|
| 477 |
+
# fmt: on
|
| 478 |
+
q_mul_pts = torch.stack((ow, ox, oy, oz), -1)
|
| 479 |
+
return _quat_mult(q_mul_pts, _quat_invert(q))[..., 1:]
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def _quat_invert(q: torch.Tensor):
|
| 483 |
+
return q * torch.tensor([1, -1, -1, -1], device=q.device)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def _sqrt_subgradient(x: torch.Tensor) -> torch.Tensor:
|
| 487 |
+
# Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.
|
| 488 |
+
ret = torch.zeros_like(x)
|
| 489 |
+
positive_mask = x > 0
|
| 490 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 491 |
+
return ret
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12):
|
| 495 |
+
# A low eps here is necessary for good stability!
|
| 496 |
+
with fp32_autocast_context(x_axis.device.type):
|
| 497 |
+
e1 = xy_plane
|
| 498 |
+
|
| 499 |
+
denom = torch.sqrt((x_axis**2).sum(dim=-1, keepdim=True) + eps)
|
| 500 |
+
x_axis = x_axis / denom
|
| 501 |
+
dot = (x_axis * e1).sum(dim=-1, keepdim=True)
|
| 502 |
+
e1 = e1 - x_axis * dot
|
| 503 |
+
denom = torch.sqrt((e1**2).sum(dim=-1, keepdim=True) + eps)
|
| 504 |
+
e1 = e1 / denom
|
| 505 |
+
e2 = torch.cross(x_axis, e1, dim=-1)
|
| 506 |
+
|
| 507 |
+
rots = torch.stack([x_axis, e1, e2], dim=-1)
|
| 508 |
+
|
| 509 |
+
return rots
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def build_affine3d_from_coordinates(
|
| 513 |
+
coords: torch.Tensor, # (N, CA, C).
|
| 514 |
+
) -> tuple[Affine3D, torch.Tensor]:
|
| 515 |
+
_MAX_SUPPORTED_DISTANCE = 1e6
|
| 516 |
+
coord_mask = torch.all(
|
| 517 |
+
torch.all(torch.isfinite(coords) & (coords < _MAX_SUPPORTED_DISTANCE), dim=-1),
|
| 518 |
+
dim=-1,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
def atom3_to_backbone_affine(bb_positions: torch.Tensor) -> Affine3D:
|
| 522 |
+
N, CA, C = bb_positions.unbind(dim=-2)
|
| 523 |
+
return Affine3D.from_graham_schmidt(C, CA, N)
|
| 524 |
+
|
| 525 |
+
coords = coords.clone().float()
|
| 526 |
+
coords[~coord_mask] = 0
|
| 527 |
+
|
| 528 |
+
# NOTE(thayes): If you have already normalized the coordinates, then
|
| 529 |
+
# the black hole affine translations will be zeros and the rotations will be
|
| 530 |
+
# the identity.
|
| 531 |
+
average_per_n_ca_c = coords.masked_fill(~coord_mask[..., None, None], 0).sum(1) / (
|
| 532 |
+
coord_mask.sum(-1)[..., None, None] + 1e-8
|
| 533 |
+
)
|
| 534 |
+
affine_from_average = atom3_to_backbone_affine(
|
| 535 |
+
average_per_n_ca_c.float()
|
| 536 |
+
).as_matrix()
|
| 537 |
+
|
| 538 |
+
B, S, _, _ = coords.shape
|
| 539 |
+
assert isinstance(B, int)
|
| 540 |
+
assert isinstance(S, int)
|
| 541 |
+
affine_rot_mats = affine_from_average.rot.tensor[..., None, :].expand(B, S, 9)
|
| 542 |
+
affine_trans = affine_from_average.trans[..., None, :].expand(B, S, 3)
|
| 543 |
+
|
| 544 |
+
# We use the identity rotation whereever we have no coordinates. This is
|
| 545 |
+
# important because otherwise the rotation matrices will be all zeros, which
|
| 546 |
+
# will cause collapse in the distance/direction attention mechanism.
|
| 547 |
+
identity_rot = RotationMatrix.identity(
|
| 548 |
+
(B, S), dtype=torch.float32, device=coords.device, requires_grad=False
|
| 549 |
+
)
|
| 550 |
+
affine_rot_mats = affine_rot_mats.where(
|
| 551 |
+
coord_mask.any(-1)[..., None, None], identity_rot.tensor
|
| 552 |
+
)
|
| 553 |
+
black_hole_affine = Affine3D(affine_trans, RotationMatrix(affine_rot_mats))
|
| 554 |
+
|
| 555 |
+
affine = atom3_to_backbone_affine(coords.float())
|
| 556 |
+
affine = Affine3D.from_tensor(
|
| 557 |
+
affine.tensor.where(coord_mask[..., None], black_hole_affine.tensor)
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
return affine, coord_mask
|
| 561 |
+
|
esmfold2_aligner.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import Field, replace
|
| 4 |
+
from typing import Any, ClassVar, Protocol, TypeVar
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .esmfold2_protein_structure import compute_affine_and_rmsd
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Alignable(Protocol):
|
| 13 |
+
# Trick to detect whether an object is a dataclass
|
| 14 |
+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
|
| 15 |
+
|
| 16 |
+
@property
|
| 17 |
+
def atom37_positions(self) -> np.ndarray: # type: ignore
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
@property
|
| 21 |
+
def atom37_mask(self) -> np.ndarray: # type: ignore
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
def __len__(self) -> int: ...
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
T = TypeVar("T", bound=Alignable)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Aligner:
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
mobile: Alignable,
|
| 34 |
+
target: Alignable,
|
| 35 |
+
only_use_backbone: bool = False,
|
| 36 |
+
use_reflection: bool = False,
|
| 37 |
+
):
|
| 38 |
+
"""
|
| 39 |
+
Aligns a mobile protein chain against a target protein chain.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
mobile (ProteinChain): Protein chain to be aligned.
|
| 43 |
+
target (ProteinChain): Protein chain target.
|
| 44 |
+
only_use_backbone (bool): Whether to only use backbone atoms.
|
| 45 |
+
use_reflection (bool): Whether to align to target reflection.
|
| 46 |
+
"""
|
| 47 |
+
# Check proteins must have same number of residues
|
| 48 |
+
assert len(mobile) == len(target)
|
| 49 |
+
|
| 50 |
+
# Determine overlapping atoms
|
| 51 |
+
joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype(
|
| 52 |
+
bool
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Backbone atoms are first sites in atom37 representation
|
| 56 |
+
if only_use_backbone:
|
| 57 |
+
joint_atom37_mask[:, 3:] = False
|
| 58 |
+
|
| 59 |
+
# Extract matching atom positions and convert to batched tensors
|
| 60 |
+
mobile_atom_tensor = (
|
| 61 |
+
torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0)
|
| 62 |
+
)
|
| 63 |
+
target_atom_tensor = (
|
| 64 |
+
torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0)
|
| 65 |
+
)
|
| 66 |
+
joint_atom37_mask = (
|
| 67 |
+
torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0)
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# If using reflection flip target
|
| 71 |
+
if use_reflection:
|
| 72 |
+
target_atom_tensor = -target_atom_tensor
|
| 73 |
+
|
| 74 |
+
# Compute alignment and rmsd
|
| 75 |
+
affine3D, rmsd = compute_affine_and_rmsd(
|
| 76 |
+
mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask
|
| 77 |
+
)
|
| 78 |
+
self._affine3D = affine3D
|
| 79 |
+
self._rmsd = rmsd.item()
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def rmsd(self):
|
| 83 |
+
return self._rmsd
|
| 84 |
+
|
| 85 |
+
def apply(self, mobile: T) -> T:
|
| 86 |
+
"""Apply alignment to a protein chain"""
|
| 87 |
+
# Extract atom positions and convert to batched tensors
|
| 88 |
+
mobile_atom_tensor = (
|
| 89 |
+
torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask])
|
| 90 |
+
.type(torch.float32)
|
| 91 |
+
.unsqueeze(0)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Transform atom arrays
|
| 95 |
+
aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0)
|
| 96 |
+
|
| 97 |
+
# Rebuild atom37 positions
|
| 98 |
+
aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan)
|
| 99 |
+
aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor
|
| 100 |
+
|
| 101 |
+
return replace(mobile, atom37_positions=aligned_atom37_positions)
|
| 102 |
+
|
esmfold2_atom_indexer.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from .esmfold2_protein_structure import index_by_atom_name
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class AtomIndexer:
|
| 7 |
+
def __init__(self, structure, property: str, dim: int):
|
| 8 |
+
self.structure = structure
|
| 9 |
+
self.property = property
|
| 10 |
+
self.dim = dim
|
| 11 |
+
|
| 12 |
+
def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
|
| 13 |
+
return index_by_atom_name(
|
| 14 |
+
getattr(self.structure, self.property), atom_names, self.dim
|
| 15 |
+
)
|
| 16 |
+
|
esmfold2_conformers.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CCD conformer loading utilities.
|
| 2 |
+
|
| 3 |
+
Loads idealized conformer coordinates from a CCD pickle file containing RDKit molecules.
|
| 4 |
+
Conformer priority follows AF3 Section 2.8: Computed > Ideal > first available.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import pickle
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
|
| 16 |
+
from .esmfold2_constants import RES_TYPE_TO_CCD
|
| 17 |
+
|
| 18 |
+
if os.environ.get("ESMCFOLD_CCD_PATH"):
|
| 19 |
+
CCD_PICKLE_PATH = Path(os.environ["ESMCFOLD_CCD_PATH"])
|
| 20 |
+
else:
|
| 21 |
+
CCD_PICKLE_PATH = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Lazily loaded CCD dictionary
|
| 25 |
+
_CCD_MOLECULES: dict | None = None
|
| 26 |
+
|
| 27 |
+
# Caches
|
| 28 |
+
_CCD_CONFORMERS: dict[str, dict[str, np.ndarray]] = {}
|
| 29 |
+
_CCD_ATOM_CACHE: dict[str, list[tuple[str, str, int]]] = {}
|
| 30 |
+
_CCD_BONDS_CACHE: dict[str, list[tuple[str, str]]] = {}
|
| 31 |
+
_CCD_LEAVING_ATOMS_CACHE: dict[str, set[str]] = {}
|
| 32 |
+
_IDEALIZED_POS_CACHE: dict[tuple[int, str], np.ndarray | None] = {}
|
| 33 |
+
_LIGAND_IDEALIZED_POS_CACHE: dict[tuple[str, str], np.ndarray | None] = {}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_ccd(cache_dir: Path | str | None = None) -> dict:
|
| 37 |
+
"""Load CCD molecules from pickle file, downloading if needed.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
cache_dir: Directory to cache the downloaded CCD pickle.
|
| 41 |
+
If None, uses CCD_PICKLE_PATH env var or downloads to ~/.cache/esmcfold/.
|
| 42 |
+
"""
|
| 43 |
+
global _CCD_MOLECULES
|
| 44 |
+
if _CCD_MOLECULES is not None:
|
| 45 |
+
return _CCD_MOLECULES
|
| 46 |
+
|
| 47 |
+
# Determine pickle path
|
| 48 |
+
if CCD_PICKLE_PATH is not None and CCD_PICKLE_PATH.exists():
|
| 49 |
+
pkl_path = CCD_PICKLE_PATH
|
| 50 |
+
elif cache_dir is not None:
|
| 51 |
+
cache_dir = Path(cache_dir)
|
| 52 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 53 |
+
pkl_path = cache_dir / "ccd.pkl"
|
| 54 |
+
else:
|
| 55 |
+
try:
|
| 56 |
+
pkl_path = Path(
|
| 57 |
+
hf_hub_download(repo_id="biohub/ESMFold2", filename="ccd.pkl")
|
| 58 |
+
)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
raise FileNotFoundError(
|
| 61 |
+
f"Failed to download CCD pickle file from Hugging Face repository: {e}"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
if not pkl_path.exists():
|
| 65 |
+
raise FileNotFoundError(
|
| 66 |
+
f"CCD pickle file not found: {pkl_path}. Please set the ESMCFOLD_CCD_PATH environment variable to the path of a valid CCD pickle file or download the file from the Hugging Face repository."
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
print(f"Loading CCD dictionary from {pkl_path}")
|
| 70 |
+
with open(pkl_path, "rb") as f:
|
| 71 |
+
_CCD_MOLECULES = pickle.load(f)
|
| 72 |
+
|
| 73 |
+
if _CCD_MOLECULES is None:
|
| 74 |
+
_CCD_MOLECULES = {}
|
| 75 |
+
|
| 76 |
+
return _CCD_MOLECULES
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _get_ccd_molecules() -> dict:
|
| 80 |
+
"""Get CCD molecules, loading lazily on first call."""
|
| 81 |
+
global _CCD_MOLECULES
|
| 82 |
+
if _CCD_MOLECULES is None:
|
| 83 |
+
return load_ccd()
|
| 84 |
+
return _CCD_MOLECULES
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _get_ccd_mol_with_significant_h(comp_id: str):
|
| 88 |
+
"""Get CCD molecule with only chemically significant hydrogens.
|
| 89 |
+
|
| 90 |
+
Returns (mol, conformer) tuple or (None, None) if not available.
|
| 91 |
+
"""
|
| 92 |
+
ccd = _get_ccd_molecules()
|
| 93 |
+
if comp_id not in ccd:
|
| 94 |
+
return None, None
|
| 95 |
+
|
| 96 |
+
mol = ccd[comp_id]
|
| 97 |
+
if mol.GetNumConformers() == 0:
|
| 98 |
+
return None, None
|
| 99 |
+
|
| 100 |
+
# Find the "Computed" conformer (RDKit ETKDGv3), fall back to "Ideal"
|
| 101 |
+
conf_idx = 0
|
| 102 |
+
for i, c in enumerate(mol.GetConformers()):
|
| 103 |
+
props = c.GetPropsAsDict()
|
| 104 |
+
if props.get("name") == "Computed":
|
| 105 |
+
conf_idx = i
|
| 106 |
+
break
|
| 107 |
+
else:
|
| 108 |
+
for i, c in enumerate(mol.GetConformers()):
|
| 109 |
+
props = c.GetPropsAsDict()
|
| 110 |
+
if props.get("name") == "Ideal":
|
| 111 |
+
conf_idx = i
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
from rdkit import Chem
|
| 115 |
+
|
| 116 |
+
mol_no_h = Chem.RemoveHs(mol, sanitize=False)
|
| 117 |
+
|
| 118 |
+
if mol_no_h.GetNumConformers() == 0:
|
| 119 |
+
return None, None
|
| 120 |
+
|
| 121 |
+
return mol_no_h, mol_no_h.GetConformer(
|
| 122 |
+
min(conf_idx, mol_no_h.GetNumConformers() - 1)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_ccd_conformer(comp_id: str) -> dict[str, np.ndarray] | None:
|
| 127 |
+
"""Get idealized conformer as dict of atom_name -> position [3].
|
| 128 |
+
|
| 129 |
+
Conformer priority: Computed > Ideal > first available.
|
| 130 |
+
"""
|
| 131 |
+
if comp_id in _CCD_CONFORMERS:
|
| 132 |
+
cached = _CCD_CONFORMERS[comp_id]
|
| 133 |
+
return cached if cached else None
|
| 134 |
+
|
| 135 |
+
mol, conf = _get_ccd_mol_with_significant_h(comp_id)
|
| 136 |
+
if mol is None or conf is None:
|
| 137 |
+
_CCD_CONFORMERS[comp_id] = {}
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
conformer: dict[str, np.ndarray] = {}
|
| 141 |
+
for atom in mol.GetAtoms():
|
| 142 |
+
props = atom.GetPropsAsDict()
|
| 143 |
+
atom_name = props.get("name")
|
| 144 |
+
if not isinstance(atom_name, str) or not atom_name:
|
| 145 |
+
continue
|
| 146 |
+
idx = atom.GetIdx()
|
| 147 |
+
pos = conf.GetAtomPosition(idx)
|
| 148 |
+
conformer[atom_name] = np.array([pos.x, pos.y, pos.z], dtype=np.float32)
|
| 149 |
+
|
| 150 |
+
_CCD_CONFORMERS[comp_id] = conformer
|
| 151 |
+
return conformer if conformer else None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_idealized_atom_pos(res_type: int, atom_name: str) -> np.ndarray | None:
|
| 155 |
+
"""Get idealized position for a standard residue atom.
|
| 156 |
+
|
| 157 |
+
Uses res_type index to look up CCD component, then returns position.
|
| 158 |
+
Returns None if not found.
|
| 159 |
+
"""
|
| 160 |
+
cache_key = (res_type, atom_name)
|
| 161 |
+
if cache_key in _IDEALIZED_POS_CACHE:
|
| 162 |
+
return _IDEALIZED_POS_CACHE[cache_key]
|
| 163 |
+
|
| 164 |
+
comp_id = RES_TYPE_TO_CCD.get(res_type)
|
| 165 |
+
if comp_id:
|
| 166 |
+
ccd_conformer = get_ccd_conformer(comp_id)
|
| 167 |
+
if ccd_conformer and atom_name in ccd_conformer:
|
| 168 |
+
pos = ccd_conformer[atom_name]
|
| 169 |
+
_IDEALIZED_POS_CACHE[cache_key] = pos
|
| 170 |
+
return pos
|
| 171 |
+
|
| 172 |
+
_IDEALIZED_POS_CACHE[cache_key] = None
|
| 173 |
+
return None
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_ligand_idealized_atom_pos(res_name: str, atom_name: str) -> np.ndarray | None:
|
| 177 |
+
"""Get idealized position for a ligand/modified residue atom.
|
| 178 |
+
|
| 179 |
+
Returns None if not found.
|
| 180 |
+
"""
|
| 181 |
+
cache_key = (res_name, atom_name)
|
| 182 |
+
if cache_key in _LIGAND_IDEALIZED_POS_CACHE:
|
| 183 |
+
return _LIGAND_IDEALIZED_POS_CACHE[cache_key]
|
| 184 |
+
|
| 185 |
+
ccd_conformer = get_ccd_conformer(res_name)
|
| 186 |
+
if ccd_conformer and atom_name in ccd_conformer:
|
| 187 |
+
pos = ccd_conformer[atom_name]
|
| 188 |
+
_LIGAND_IDEALIZED_POS_CACHE[cache_key] = pos
|
| 189 |
+
return pos
|
| 190 |
+
|
| 191 |
+
_LIGAND_IDEALIZED_POS_CACHE[cache_key] = None
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_ligand_ccd_atoms_with_charges(
|
| 196 |
+
comp_id: str,
|
| 197 |
+
) -> list[tuple[str, str, int]] | None:
|
| 198 |
+
"""Get list of (atom_name, element, charge) for a CCD component.
|
| 199 |
+
|
| 200 |
+
Uses RDKit RemoveHs(sanitize=False) to keep chemically significant hydrogens.
|
| 201 |
+
Returns None if CCD data not available.
|
| 202 |
+
"""
|
| 203 |
+
if comp_id in _CCD_ATOM_CACHE:
|
| 204 |
+
cached = _CCD_ATOM_CACHE[comp_id]
|
| 205 |
+
return cached if cached else None
|
| 206 |
+
|
| 207 |
+
mol, _ = _get_ccd_mol_with_significant_h(comp_id)
|
| 208 |
+
if mol is None:
|
| 209 |
+
_CCD_ATOM_CACHE[comp_id] = []
|
| 210 |
+
return None
|
| 211 |
+
|
| 212 |
+
atoms: list[tuple[str, str, int]] = []
|
| 213 |
+
for atom in mol.GetAtoms():
|
| 214 |
+
props = atom.GetPropsAsDict()
|
| 215 |
+
atom_name = props.get("name")
|
| 216 |
+
if not isinstance(atom_name, str) or not atom_name:
|
| 217 |
+
continue
|
| 218 |
+
element = atom.GetSymbol()
|
| 219 |
+
charge = atom.GetFormalCharge()
|
| 220 |
+
atoms.append((atom_name, element, charge))
|
| 221 |
+
|
| 222 |
+
_CCD_ATOM_CACHE[comp_id] = atoms
|
| 223 |
+
return atoms if atoms else None
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_ligand_ccd_bonds(comp_id: str) -> list[tuple[str, str]] | None:
|
| 227 |
+
"""Get list of (atom1_name, atom2_name) bonds for a CCD component.
|
| 228 |
+
|
| 229 |
+
Returns None if CCD data not available.
|
| 230 |
+
"""
|
| 231 |
+
if comp_id in _CCD_BONDS_CACHE:
|
| 232 |
+
cached = _CCD_BONDS_CACHE[comp_id]
|
| 233 |
+
return cached if cached else None
|
| 234 |
+
|
| 235 |
+
mol, _ = _get_ccd_mol_with_significant_h(comp_id)
|
| 236 |
+
if mol is None:
|
| 237 |
+
_CCD_BONDS_CACHE[comp_id] = []
|
| 238 |
+
return None
|
| 239 |
+
|
| 240 |
+
# Get included atom names
|
| 241 |
+
included_atoms = set()
|
| 242 |
+
for atom in mol.GetAtoms():
|
| 243 |
+
props = atom.GetPropsAsDict()
|
| 244 |
+
atom_name = props.get("name")
|
| 245 |
+
if isinstance(atom_name, str) and atom_name:
|
| 246 |
+
included_atoms.add(atom_name)
|
| 247 |
+
|
| 248 |
+
bonds: list[tuple[str, str]] = []
|
| 249 |
+
for bond in mol.GetBonds():
|
| 250 |
+
a1 = bond.GetBeginAtom()
|
| 251 |
+
a2 = bond.GetEndAtom()
|
| 252 |
+
n1 = a1.GetPropsAsDict().get("name")
|
| 253 |
+
n2 = a2.GetPropsAsDict().get("name")
|
| 254 |
+
if (
|
| 255 |
+
isinstance(n1, str)
|
| 256 |
+
and isinstance(n2, str)
|
| 257 |
+
and n1
|
| 258 |
+
and n2
|
| 259 |
+
and n1 in included_atoms
|
| 260 |
+
and n2 in included_atoms
|
| 261 |
+
):
|
| 262 |
+
bonds.append((n1, n2))
|
| 263 |
+
|
| 264 |
+
_CCD_BONDS_CACHE[comp_id] = bonds
|
| 265 |
+
return bonds if bonds else None
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def get_ccd_leaving_atoms(comp_id: str) -> set[str]:
|
| 269 |
+
"""Get set of atom names marked as leaving atoms in CCD.
|
| 270 |
+
|
| 271 |
+
Leaving atoms are removed during polymerization (e.g., OP3 in nucleotides).
|
| 272 |
+
"""
|
| 273 |
+
if comp_id in _CCD_LEAVING_ATOMS_CACHE:
|
| 274 |
+
return _CCD_LEAVING_ATOMS_CACHE[comp_id]
|
| 275 |
+
|
| 276 |
+
ccd = _get_ccd_molecules()
|
| 277 |
+
if comp_id not in ccd:
|
| 278 |
+
_CCD_LEAVING_ATOMS_CACHE[comp_id] = set()
|
| 279 |
+
return set()
|
| 280 |
+
|
| 281 |
+
mol = ccd[comp_id]
|
| 282 |
+
leaving_atoms = set()
|
| 283 |
+
for atom in mol.GetAtoms():
|
| 284 |
+
if atom.HasProp("leaving_atom"):
|
| 285 |
+
if atom.GetProp("leaving_atom") == "1":
|
| 286 |
+
name = atom.GetProp("name") if atom.HasProp("name") else ""
|
| 287 |
+
if name:
|
| 288 |
+
leaving_atoms.add(name)
|
| 289 |
+
|
| 290 |
+
_CCD_LEAVING_ATOMS_CACHE[comp_id] = leaving_atoms
|
| 291 |
+
return leaving_atoms
|
| 292 |
+
|
esmfold2_constants.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Constants for the ESMFold2 input pipeline.
|
| 2 |
+
|
| 3 |
+
Includes molecule types, residue types, vocabularies, atom lists, and element data.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# =============================================================================
|
| 7 |
+
# Molecule types
|
| 8 |
+
# =============================================================================
|
| 9 |
+
|
| 10 |
+
MOL_TYPE_PROTEIN = 0
|
| 11 |
+
MOL_TYPE_DNA = 1
|
| 12 |
+
MOL_TYPE_RNA = 2
|
| 13 |
+
MOL_TYPE_NONPOLYMER = 3
|
| 14 |
+
|
| 15 |
+
# =============================================================================
|
| 16 |
+
# Residue type indices
|
| 17 |
+
# =============================================================================
|
| 18 |
+
|
| 19 |
+
# Standard amino acids (indices 2-21), MSE mapped to MET
|
| 20 |
+
PROTEIN_RESIDUE_TO_RES_TYPE = {
|
| 21 |
+
"ALA": 2,
|
| 22 |
+
"ARG": 3,
|
| 23 |
+
"ASN": 4,
|
| 24 |
+
"ASP": 5,
|
| 25 |
+
"CYS": 6,
|
| 26 |
+
"GLN": 7,
|
| 27 |
+
"GLU": 8,
|
| 28 |
+
"GLY": 9,
|
| 29 |
+
"HIS": 10,
|
| 30 |
+
"ILE": 11,
|
| 31 |
+
"LEU": 12,
|
| 32 |
+
"LYS": 13,
|
| 33 |
+
"MET": 14,
|
| 34 |
+
"PHE": 15,
|
| 35 |
+
"PRO": 16,
|
| 36 |
+
"SER": 17,
|
| 37 |
+
"THR": 18,
|
| 38 |
+
"TRP": 19,
|
| 39 |
+
"TYR": 20,
|
| 40 |
+
"VAL": 21,
|
| 41 |
+
"MSE": 14, # Selenomethionine -> MET
|
| 42 |
+
}
|
| 43 |
+
PROTEIN_UNK_RES_TYPE = 22
|
| 44 |
+
|
| 45 |
+
# RNA nucleotides (indices 23-26, unknown=27)
|
| 46 |
+
RNA_RESIDUE_TO_RES_TYPE = {"A": 23, "G": 24, "C": 25, "U": 26}
|
| 47 |
+
RNA_UNK_RES_TYPE = 27
|
| 48 |
+
|
| 49 |
+
# DNA nucleotides (indices 28-31, unknown=32)
|
| 50 |
+
DNA_RESIDUE_TO_RES_TYPE = {"DA": 28, "DG": 29, "DC": 30, "DT": 31}
|
| 51 |
+
DNA_UNK_RES_TYPE = 32
|
| 52 |
+
|
| 53 |
+
GAP_RES_TYPE = 32
|
| 54 |
+
|
| 55 |
+
# =============================================================================
|
| 56 |
+
# Vocabularies
|
| 57 |
+
# =============================================================================
|
| 58 |
+
|
| 59 |
+
# 3-letter to 1-letter codes for proteins
|
| 60 |
+
PROTEIN_3TO1 = {
|
| 61 |
+
"ALA": "A",
|
| 62 |
+
"ARG": "R",
|
| 63 |
+
"ASN": "N",
|
| 64 |
+
"ASP": "D",
|
| 65 |
+
"CYS": "C",
|
| 66 |
+
"GLN": "Q",
|
| 67 |
+
"GLU": "E",
|
| 68 |
+
"GLY": "G",
|
| 69 |
+
"HIS": "H",
|
| 70 |
+
"ILE": "I",
|
| 71 |
+
"LEU": "L",
|
| 72 |
+
"LYS": "K",
|
| 73 |
+
"MET": "M",
|
| 74 |
+
"PHE": "F",
|
| 75 |
+
"PRO": "P",
|
| 76 |
+
"SER": "S",
|
| 77 |
+
"THR": "T",
|
| 78 |
+
"TRP": "W",
|
| 79 |
+
"TYR": "Y",
|
| 80 |
+
"VAL": "V",
|
| 81 |
+
"MSE": "M",
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# 1-letter to 3-letter codes
|
| 85 |
+
PROTEIN_1TO3 = {v: k for k, v in PROTEIN_3TO1.items() if k != "MSE"}
|
| 86 |
+
PROTEIN_1TO3["X"] = "UNK"
|
| 87 |
+
|
| 88 |
+
# DNA 1-letter to CCD code
|
| 89 |
+
DNA_1TO3 = {"A": "DA", "T": "DT", "C": "DC", "G": "DG"}
|
| 90 |
+
|
| 91 |
+
# RNA 1-letter to CCD code
|
| 92 |
+
RNA_1TO3 = {"A": "A", "U": "U", "C": "C", "G": "G"}
|
| 93 |
+
|
| 94 |
+
# ESM-2 input_ids vocabulary for proteins
|
| 95 |
+
ESM_PROTEIN_VOCAB = {
|
| 96 |
+
"L": 4,
|
| 97 |
+
"A": 5,
|
| 98 |
+
"G": 6,
|
| 99 |
+
"V": 7,
|
| 100 |
+
"S": 8,
|
| 101 |
+
"E": 9,
|
| 102 |
+
"R": 10,
|
| 103 |
+
"T": 11,
|
| 104 |
+
"I": 12,
|
| 105 |
+
"D": 13,
|
| 106 |
+
"P": 14,
|
| 107 |
+
"K": 15,
|
| 108 |
+
"Q": 16,
|
| 109 |
+
"N": 17,
|
| 110 |
+
"F": 18,
|
| 111 |
+
"Y": 19,
|
| 112 |
+
"M": 20,
|
| 113 |
+
"H": 21,
|
| 114 |
+
"W": 22,
|
| 115 |
+
"C": 23,
|
| 116 |
+
"X": 3, # Unknown
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
# For DNA/RNA/ligands
|
| 120 |
+
DNA_RNA_LIGAND_INPUT_ID = 24
|
| 121 |
+
|
| 122 |
+
# MSA tokens
|
| 123 |
+
MSA_PAD_TOKEN_ID = 0
|
| 124 |
+
MSA_GAP_TOKEN_ID = 1 # Gap/insertion token for MSA
|
| 125 |
+
|
| 126 |
+
# res_type int -> CCD component ID (for conformer lookup)
|
| 127 |
+
RES_TYPE_TO_CCD = {
|
| 128 |
+
# Proteins (2-22)
|
| 129 |
+
2: "ALA",
|
| 130 |
+
3: "ARG",
|
| 131 |
+
4: "ASN",
|
| 132 |
+
5: "ASP",
|
| 133 |
+
6: "CYS",
|
| 134 |
+
7: "GLN",
|
| 135 |
+
8: "GLU",
|
| 136 |
+
9: "GLY",
|
| 137 |
+
10: "HIS",
|
| 138 |
+
11: "ILE",
|
| 139 |
+
12: "LEU",
|
| 140 |
+
13: "LYS",
|
| 141 |
+
14: "MET",
|
| 142 |
+
15: "PHE",
|
| 143 |
+
16: "PRO",
|
| 144 |
+
17: "SER",
|
| 145 |
+
18: "THR",
|
| 146 |
+
19: "TRP",
|
| 147 |
+
20: "TYR",
|
| 148 |
+
21: "VAL",
|
| 149 |
+
22: "UNK",
|
| 150 |
+
# RNA (23-27)
|
| 151 |
+
23: "A",
|
| 152 |
+
24: "G",
|
| 153 |
+
25: "C",
|
| 154 |
+
26: "U",
|
| 155 |
+
27: "N",
|
| 156 |
+
# DNA (28-32)
|
| 157 |
+
28: "DA",
|
| 158 |
+
29: "DG",
|
| 159 |
+
30: "DC",
|
| 160 |
+
31: "DT",
|
| 161 |
+
32: "DN",
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
# =============================================================================
|
| 165 |
+
# Charged atoms at physiological pH
|
| 166 |
+
# =============================================================================
|
| 167 |
+
|
| 168 |
+
CHARGED_ATOMS: dict[tuple[str, str], int] = {
|
| 169 |
+
("LYS", "NZ"): 1,
|
| 170 |
+
("ARG", "NH2"): 1,
|
| 171 |
+
("HIS", "ND1"): 1,
|
| 172 |
+
("PO4", "O2"): -1,
|
| 173 |
+
("PO4", "O3"): -1,
|
| 174 |
+
("PO4", "O4"): -1,
|
| 175 |
+
("SO4", "O3"): -1,
|
| 176 |
+
("SO4", "O4"): -1,
|
| 177 |
+
("MG", "MG"): 2,
|
| 178 |
+
("ZN", "ZN"): 2,
|
| 179 |
+
("CA", "CA"): 2,
|
| 180 |
+
("FE2", "FE"): 2,
|
| 181 |
+
("MN", "MN"): 2,
|
| 182 |
+
("CO", "CO"): 2,
|
| 183 |
+
("NCO", "CO"): 3,
|
| 184 |
+
("CU", "CU"): 2,
|
| 185 |
+
("NI", "NI"): 2,
|
| 186 |
+
("K", "K"): 1,
|
| 187 |
+
("NA", "NA"): 1,
|
| 188 |
+
("CD", "CD"): 2,
|
| 189 |
+
("CL", "CL"): -1,
|
| 190 |
+
("ACT", "OXT"): -1,
|
| 191 |
+
("NAD", "O2N"): -1,
|
| 192 |
+
("NAD", "N1N"): 1,
|
| 193 |
+
("NAP", "O2N"): -1,
|
| 194 |
+
("NAP", "N1N"): 1,
|
| 195 |
+
("IMD", "N3"): 1,
|
| 196 |
+
("SAM", "SD"): 1,
|
| 197 |
+
("FE", "FE"): 3,
|
| 198 |
+
("A1BH3", "N3"): 1,
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
# =============================================================================
|
| 202 |
+
# Element atomic numbers (Z=1 to 92)
|
| 203 |
+
# =============================================================================
|
| 204 |
+
|
| 205 |
+
ELEMENT_TO_ATOMIC_NUM = {
|
| 206 |
+
"H": 1,
|
| 207 |
+
"LI": 3,
|
| 208 |
+
"BE": 4,
|
| 209 |
+
"B": 5,
|
| 210 |
+
"C": 6,
|
| 211 |
+
"N": 7,
|
| 212 |
+
"O": 8,
|
| 213 |
+
"F": 9,
|
| 214 |
+
"NE": 10,
|
| 215 |
+
"NA": 11,
|
| 216 |
+
"MG": 12,
|
| 217 |
+
"AL": 13,
|
| 218 |
+
"SI": 14,
|
| 219 |
+
"P": 15,
|
| 220 |
+
"S": 16,
|
| 221 |
+
"CL": 17,
|
| 222 |
+
"AR": 18,
|
| 223 |
+
"K": 19,
|
| 224 |
+
"CA": 20,
|
| 225 |
+
"SC": 21,
|
| 226 |
+
"TI": 22,
|
| 227 |
+
"V": 23,
|
| 228 |
+
"CR": 24,
|
| 229 |
+
"MN": 25,
|
| 230 |
+
"FE": 26,
|
| 231 |
+
"CO": 27,
|
| 232 |
+
"NI": 28,
|
| 233 |
+
"CU": 29,
|
| 234 |
+
"ZN": 30,
|
| 235 |
+
"GA": 31,
|
| 236 |
+
"GE": 32,
|
| 237 |
+
"AS": 33,
|
| 238 |
+
"SE": 34,
|
| 239 |
+
"BR": 35,
|
| 240 |
+
"KR": 36,
|
| 241 |
+
"RB": 37,
|
| 242 |
+
"SR": 38,
|
| 243 |
+
"Y": 39,
|
| 244 |
+
"ZR": 40,
|
| 245 |
+
"NB": 41,
|
| 246 |
+
"MO": 42,
|
| 247 |
+
"TC": 43,
|
| 248 |
+
"RU": 44,
|
| 249 |
+
"RH": 45,
|
| 250 |
+
"PD": 46,
|
| 251 |
+
"AG": 47,
|
| 252 |
+
"CD": 48,
|
| 253 |
+
"IN": 49,
|
| 254 |
+
"SN": 50,
|
| 255 |
+
"SB": 51,
|
| 256 |
+
"TE": 52,
|
| 257 |
+
"I": 53,
|
| 258 |
+
"XE": 54,
|
| 259 |
+
"CS": 55,
|
| 260 |
+
"BA": 56,
|
| 261 |
+
"LA": 57,
|
| 262 |
+
"CE": 58,
|
| 263 |
+
"PR": 59,
|
| 264 |
+
"ND": 60,
|
| 265 |
+
"PM": 61,
|
| 266 |
+
"SM": 62,
|
| 267 |
+
"EU": 63,
|
| 268 |
+
"GD": 64,
|
| 269 |
+
"TB": 65,
|
| 270 |
+
"DY": 66,
|
| 271 |
+
"HO": 67,
|
| 272 |
+
"ER": 68,
|
| 273 |
+
"TM": 69,
|
| 274 |
+
"YB": 70,
|
| 275 |
+
"LU": 71,
|
| 276 |
+
"HF": 72,
|
| 277 |
+
"TA": 73,
|
| 278 |
+
"W": 74,
|
| 279 |
+
"RE": 75,
|
| 280 |
+
"OS": 76,
|
| 281 |
+
"IR": 77,
|
| 282 |
+
"PT": 78,
|
| 283 |
+
"AU": 79,
|
| 284 |
+
"HG": 80,
|
| 285 |
+
"TL": 81,
|
| 286 |
+
"PB": 82,
|
| 287 |
+
"BI": 83,
|
| 288 |
+
"PO": 84,
|
| 289 |
+
"AT": 85,
|
| 290 |
+
"RN": 86,
|
| 291 |
+
"FR": 87,
|
| 292 |
+
"RA": 88,
|
| 293 |
+
"AC": 89,
|
| 294 |
+
"TH": 90,
|
| 295 |
+
"PA": 91,
|
| 296 |
+
"U": 92,
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
# Inverse mapping: atomic number → element symbol
|
| 300 |
+
ELEMENT_NUMBER_TO_SYMBOL = {v: k for k, v in ELEMENT_TO_ATOMIC_NUM.items()}
|
| 301 |
+
|
| 302 |
+
# =============================================================================
|
| 303 |
+
# Standard heavy atoms per residue type
|
| 304 |
+
# =============================================================================
|
| 305 |
+
|
| 306 |
+
PROTEIN_HEAVY_ATOMS = {
|
| 307 |
+
"ALA": ["N", "CA", "C", "O", "CB"],
|
| 308 |
+
"ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
|
| 309 |
+
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
|
| 310 |
+
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
|
| 311 |
+
"CYS": ["N", "CA", "C", "O", "CB", "SG"],
|
| 312 |
+
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
|
| 313 |
+
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
|
| 314 |
+
"GLY": ["N", "CA", "C", "O"],
|
| 315 |
+
"HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
| 316 |
+
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
|
| 317 |
+
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
|
| 318 |
+
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
|
| 319 |
+
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
|
| 320 |
+
"PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
|
| 321 |
+
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
|
| 322 |
+
"SER": ["N", "CA", "C", "O", "CB", "OG"],
|
| 323 |
+
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
|
| 324 |
+
"TRP": [
|
| 325 |
+
"N",
|
| 326 |
+
"CA",
|
| 327 |
+
"C",
|
| 328 |
+
"O",
|
| 329 |
+
"CB",
|
| 330 |
+
"CG",
|
| 331 |
+
"CD1",
|
| 332 |
+
"CD2",
|
| 333 |
+
"NE1",
|
| 334 |
+
"CE2",
|
| 335 |
+
"CE3",
|
| 336 |
+
"CZ2",
|
| 337 |
+
"CZ3",
|
| 338 |
+
"CH2",
|
| 339 |
+
],
|
| 340 |
+
"TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
|
| 341 |
+
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
|
| 342 |
+
"MSE": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
|
| 343 |
+
"UNK": ["N", "CA", "C", "O"],
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
DNA_HEAVY_ATOMS = {
|
| 347 |
+
"DA": [
|
| 348 |
+
"P",
|
| 349 |
+
"OP1",
|
| 350 |
+
"OP2",
|
| 351 |
+
"O5'",
|
| 352 |
+
"C5'",
|
| 353 |
+
"C4'",
|
| 354 |
+
"O4'",
|
| 355 |
+
"C3'",
|
| 356 |
+
"O3'",
|
| 357 |
+
"C2'",
|
| 358 |
+
"C1'",
|
| 359 |
+
"N9",
|
| 360 |
+
"C8",
|
| 361 |
+
"N7",
|
| 362 |
+
"C5",
|
| 363 |
+
"C6",
|
| 364 |
+
"N6",
|
| 365 |
+
"N1",
|
| 366 |
+
"C2",
|
| 367 |
+
"N3",
|
| 368 |
+
"C4",
|
| 369 |
+
],
|
| 370 |
+
"DG": [
|
| 371 |
+
"P",
|
| 372 |
+
"OP1",
|
| 373 |
+
"OP2",
|
| 374 |
+
"O5'",
|
| 375 |
+
"C5'",
|
| 376 |
+
"C4'",
|
| 377 |
+
"O4'",
|
| 378 |
+
"C3'",
|
| 379 |
+
"O3'",
|
| 380 |
+
"C2'",
|
| 381 |
+
"C1'",
|
| 382 |
+
"N9",
|
| 383 |
+
"C8",
|
| 384 |
+
"N7",
|
| 385 |
+
"C5",
|
| 386 |
+
"C6",
|
| 387 |
+
"O6",
|
| 388 |
+
"N1",
|
| 389 |
+
"C2",
|
| 390 |
+
"N2",
|
| 391 |
+
"N3",
|
| 392 |
+
"C4",
|
| 393 |
+
],
|
| 394 |
+
"DC": [
|
| 395 |
+
"P",
|
| 396 |
+
"OP1",
|
| 397 |
+
"OP2",
|
| 398 |
+
"O5'",
|
| 399 |
+
"C5'",
|
| 400 |
+
"C4'",
|
| 401 |
+
"O4'",
|
| 402 |
+
"C3'",
|
| 403 |
+
"O3'",
|
| 404 |
+
"C2'",
|
| 405 |
+
"C1'",
|
| 406 |
+
"N1",
|
| 407 |
+
"C2",
|
| 408 |
+
"O2",
|
| 409 |
+
"N3",
|
| 410 |
+
"C4",
|
| 411 |
+
"N4",
|
| 412 |
+
"C5",
|
| 413 |
+
"C6",
|
| 414 |
+
],
|
| 415 |
+
"DT": [
|
| 416 |
+
"P",
|
| 417 |
+
"OP1",
|
| 418 |
+
"OP2",
|
| 419 |
+
"O5'",
|
| 420 |
+
"C5'",
|
| 421 |
+
"C4'",
|
| 422 |
+
"O4'",
|
| 423 |
+
"C3'",
|
| 424 |
+
"O3'",
|
| 425 |
+
"C2'",
|
| 426 |
+
"C1'",
|
| 427 |
+
"N1",
|
| 428 |
+
"C2",
|
| 429 |
+
"O2",
|
| 430 |
+
"N3",
|
| 431 |
+
"C4",
|
| 432 |
+
"O4",
|
| 433 |
+
"C5",
|
| 434 |
+
"C7",
|
| 435 |
+
"C6",
|
| 436 |
+
],
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
RNA_HEAVY_ATOMS = {
|
| 440 |
+
"A": [
|
| 441 |
+
"P",
|
| 442 |
+
"OP1",
|
| 443 |
+
"OP2",
|
| 444 |
+
"O5'",
|
| 445 |
+
"C5'",
|
| 446 |
+
"C4'",
|
| 447 |
+
"O4'",
|
| 448 |
+
"C3'",
|
| 449 |
+
"O3'",
|
| 450 |
+
"C2'",
|
| 451 |
+
"O2'",
|
| 452 |
+
"C1'",
|
| 453 |
+
"N9",
|
| 454 |
+
"C8",
|
| 455 |
+
"N7",
|
| 456 |
+
"C5",
|
| 457 |
+
"C6",
|
| 458 |
+
"N6",
|
| 459 |
+
"N1",
|
| 460 |
+
"C2",
|
| 461 |
+
"N3",
|
| 462 |
+
"C4",
|
| 463 |
+
],
|
| 464 |
+
"G": [
|
| 465 |
+
"P",
|
| 466 |
+
"OP1",
|
| 467 |
+
"OP2",
|
| 468 |
+
"O5'",
|
| 469 |
+
"C5'",
|
| 470 |
+
"C4'",
|
| 471 |
+
"O4'",
|
| 472 |
+
"C3'",
|
| 473 |
+
"O3'",
|
| 474 |
+
"C2'",
|
| 475 |
+
"O2'",
|
| 476 |
+
"C1'",
|
| 477 |
+
"N9",
|
| 478 |
+
"C8",
|
| 479 |
+
"N7",
|
| 480 |
+
"C5",
|
| 481 |
+
"C6",
|
| 482 |
+
"O6",
|
| 483 |
+
"N1",
|
| 484 |
+
"C2",
|
| 485 |
+
"N2",
|
| 486 |
+
"N3",
|
| 487 |
+
"C4",
|
| 488 |
+
],
|
| 489 |
+
"C": [
|
| 490 |
+
"P",
|
| 491 |
+
"OP1",
|
| 492 |
+
"OP2",
|
| 493 |
+
"O5'",
|
| 494 |
+
"C5'",
|
| 495 |
+
"C4'",
|
| 496 |
+
"O4'",
|
| 497 |
+
"C3'",
|
| 498 |
+
"O3'",
|
| 499 |
+
"C2'",
|
| 500 |
+
"O2'",
|
| 501 |
+
"C1'",
|
| 502 |
+
"N1",
|
| 503 |
+
"C2",
|
| 504 |
+
"O2",
|
| 505 |
+
"N3",
|
| 506 |
+
"C4",
|
| 507 |
+
"N4",
|
| 508 |
+
"C5",
|
| 509 |
+
"C6",
|
| 510 |
+
],
|
| 511 |
+
"U": [
|
| 512 |
+
"P",
|
| 513 |
+
"OP1",
|
| 514 |
+
"OP2",
|
| 515 |
+
"O5'",
|
| 516 |
+
"C5'",
|
| 517 |
+
"C4'",
|
| 518 |
+
"O4'",
|
| 519 |
+
"C3'",
|
| 520 |
+
"O3'",
|
| 521 |
+
"C2'",
|
| 522 |
+
"O2'",
|
| 523 |
+
"C1'",
|
| 524 |
+
"N1",
|
| 525 |
+
"C2",
|
| 526 |
+
"O2",
|
| 527 |
+
"N3",
|
| 528 |
+
"C4",
|
| 529 |
+
"O4",
|
| 530 |
+
"C5",
|
| 531 |
+
"C6",
|
| 532 |
+
],
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
# Unknown nucleotide backbone atoms
|
| 536 |
+
DNA_BACKBONE_ATOMS = [
|
| 537 |
+
"P",
|
| 538 |
+
"OP1",
|
| 539 |
+
"OP2",
|
| 540 |
+
"O5'",
|
| 541 |
+
"C5'",
|
| 542 |
+
"C4'",
|
| 543 |
+
"O4'",
|
| 544 |
+
"C3'",
|
| 545 |
+
"O3'",
|
| 546 |
+
"C2'",
|
| 547 |
+
"C1'",
|
| 548 |
+
]
|
| 549 |
+
RNA_BACKBONE_ATOMS = [
|
| 550 |
+
"P",
|
| 551 |
+
"OP1",
|
| 552 |
+
"OP2",
|
| 553 |
+
"O5'",
|
| 554 |
+
"C5'",
|
| 555 |
+
"C4'",
|
| 556 |
+
"O4'",
|
| 557 |
+
"C3'",
|
| 558 |
+
"O3'",
|
| 559 |
+
"C2'",
|
| 560 |
+
"O2'",
|
| 561 |
+
"C1'",
|
| 562 |
+
]
|
| 563 |
+
|
esmfold2_constants_esm3.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from functools import cache
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from huggingface_hub import snapshot_download
|
| 6 |
+
|
| 7 |
+
SEQUENCE_BOS_TOKEN = 0
|
| 8 |
+
SEQUENCE_PAD_TOKEN = 1
|
| 9 |
+
SEQUENCE_EOS_TOKEN = 2
|
| 10 |
+
SEQUENCE_CHAINBREAK_TOKEN = 31
|
| 11 |
+
SEQUENCE_MASK_TOKEN = 32
|
| 12 |
+
|
| 13 |
+
VQVAE_CODEBOOK_SIZE = 4096
|
| 14 |
+
VQVAE_SPECIAL_TOKENS = {
|
| 15 |
+
"MASK": VQVAE_CODEBOOK_SIZE,
|
| 16 |
+
"EOS": VQVAE_CODEBOOK_SIZE + 1,
|
| 17 |
+
"BOS": VQVAE_CODEBOOK_SIZE + 2,
|
| 18 |
+
"PAD": VQVAE_CODEBOOK_SIZE + 3,
|
| 19 |
+
"CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
|
| 20 |
+
}
|
| 21 |
+
VQVAE_DIRECTION_LOSS_BINS = 16
|
| 22 |
+
VQVAE_PAE_BINS = 64
|
| 23 |
+
VQVAE_MAX_PAE_BIN = 31.0
|
| 24 |
+
VQVAE_PLDDT_BINS = 50
|
| 25 |
+
|
| 26 |
+
STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"]
|
| 27 |
+
STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"]
|
| 28 |
+
STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"]
|
| 29 |
+
STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"]
|
| 30 |
+
STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"]
|
| 31 |
+
STRUCTURE_UNDEFINED_TOKEN = 955
|
| 32 |
+
|
| 33 |
+
SASA_PAD_TOKEN = 0
|
| 34 |
+
|
| 35 |
+
SS8_PAD_TOKEN = 0
|
| 36 |
+
|
| 37 |
+
INTERPRO_PAD_TOKEN = 0
|
| 38 |
+
|
| 39 |
+
RESIDUE_PAD_TOKEN = 0
|
| 40 |
+
|
| 41 |
+
CHAIN_BREAK_STR = "|"
|
| 42 |
+
|
| 43 |
+
SEQUENCE_BOS_STR = "<cls>"
|
| 44 |
+
SEQUENCE_EOS_STR = "<eos>"
|
| 45 |
+
|
| 46 |
+
MASK_STR_SHORT = "_"
|
| 47 |
+
SEQUENCE_MASK_STR = "<mask>"
|
| 48 |
+
SASA_MASK_STR = "<unk>"
|
| 49 |
+
SS8_MASK_STR = "<unk>"
|
| 50 |
+
|
| 51 |
+
# fmt: off
|
| 52 |
+
SEQUENCE_VOCAB = [
|
| 53 |
+
"<cls>", "<pad>", "<eos>", "<unk>",
|
| 54 |
+
"L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
|
| 55 |
+
"Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
|
| 56 |
+
"O", ".", "-", "|",
|
| 57 |
+
"<mask>",
|
| 58 |
+
]
|
| 59 |
+
# fmt: on
|
| 60 |
+
|
| 61 |
+
SEQUENCE_STANDARD_AA_MIN_TOKEN = 4 # L
|
| 62 |
+
SEQUENCE_STANDARD_AA_MAX_TOKEN = 24 # X (exclusive)
|
| 63 |
+
|
| 64 |
+
SSE_8CLASS_VOCAB = "GHITEBSC"
|
| 65 |
+
SSE_3CLASS_VOCAB = "HEC"
|
| 66 |
+
SSE_8CLASS_TO_3CLASS_MAP = {
|
| 67 |
+
"G": "H",
|
| 68 |
+
"H": "H",
|
| 69 |
+
"I": "H",
|
| 70 |
+
"T": "C",
|
| 71 |
+
"E": "E",
|
| 72 |
+
"B": "E",
|
| 73 |
+
"S": "C",
|
| 74 |
+
"C": "C",
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
SASA_DISCRETIZATION_BOUNDARIES = [
|
| 78 |
+
0.8,
|
| 79 |
+
4.0,
|
| 80 |
+
9.6,
|
| 81 |
+
16.4,
|
| 82 |
+
24.5,
|
| 83 |
+
32.9,
|
| 84 |
+
42.0,
|
| 85 |
+
51.5,
|
| 86 |
+
61.2,
|
| 87 |
+
70.9,
|
| 88 |
+
81.6,
|
| 89 |
+
93.3,
|
| 90 |
+
107.2,
|
| 91 |
+
125.4,
|
| 92 |
+
151.4,
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
MAX_RESIDUE_ANNOTATIONS = 16
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
TFIDF_VECTOR_SIZE = 58641
|
| 99 |
+
|
| 100 |
+
FUNCTION_TOKENS_DEPTH = 8
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@staticmethod
|
| 104 |
+
@cache
|
| 105 |
+
def data_root(model: str):
|
| 106 |
+
if "INFRA_PROVIDER" in os.environ:
|
| 107 |
+
return Path("")
|
| 108 |
+
# Try to download from huggingface if it doesn't exist
|
| 109 |
+
if model.startswith("esm3"):
|
| 110 |
+
path = Path(snapshot_download(repo_id="biohub/esm3-sm-open-v1"))
|
| 111 |
+
elif model.startswith("esmc-300"):
|
| 112 |
+
path = Path(snapshot_download(repo_id="biohub/esmc-300m-2024-12"))
|
| 113 |
+
elif model.startswith("esmc-600"):
|
| 114 |
+
path = Path(snapshot_download(repo_id="biohub/esmc-600m-2024-12"))
|
| 115 |
+
elif model.startswith("esmc-6b"):
|
| 116 |
+
path = Path(snapshot_download(repo_id="biohub/esmc-6b-2024-12"))
|
| 117 |
+
else:
|
| 118 |
+
raise ValueError(f"{model=} is an invalid model name.")
|
| 119 |
+
return path
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data"
|
| 123 |
+
|
| 124 |
+
INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list"
|
| 125 |
+
INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
| 126 |
+
INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
|
| 127 |
+
INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
|
| 128 |
+
|
| 129 |
+
LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
|
| 130 |
+
|
| 131 |
+
KEYWORDS_VOCABULARY = (
|
| 132 |
+
IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"
|
| 133 |
+
)
|
| 134 |
+
KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy"
|
| 135 |
+
|
| 136 |
+
RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv"
|
| 137 |
+
INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv"
|
| 138 |
+
|
esmfold2_input_builder.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Sequence, TypeAlias, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from .esmfold2_msa import MSA
|
| 7 |
+
|
| 8 |
+
# fmt: off
|
| 9 |
+
MSAInput: TypeAlias = Union[
|
| 10 |
+
MSA,
|
| 11 |
+
None,
|
| 12 |
+
]
|
| 13 |
+
# fmt: on
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class Modification:
|
| 18 |
+
position: int # zero-indexed
|
| 19 |
+
ccd: str
|
| 20 |
+
smiles: str | None = None # TODO(mlee): add smiles support
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ProteinInput:
|
| 25 |
+
id: str | list[str]
|
| 26 |
+
sequence: str
|
| 27 |
+
modifications: list[Modification] | None = None
|
| 28 |
+
msa: MSAInput = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class RNAInput:
|
| 33 |
+
id: str | list[str]
|
| 34 |
+
sequence: str
|
| 35 |
+
modifications: list[Modification] | None = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class DNAInput:
|
| 40 |
+
id: str | list[str]
|
| 41 |
+
sequence: str
|
| 42 |
+
modifications: list[Modification] | None = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
class LigandInput:
|
| 47 |
+
id: str | list[str]
|
| 48 |
+
smiles: str | None = None
|
| 49 |
+
ccd: list[str] | None = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class DistogramConditioning:
|
| 54 |
+
chain_id: str
|
| 55 |
+
distogram: np.ndarray
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class PocketConditioning:
|
| 60 |
+
binder_chain_id: str
|
| 61 |
+
contacts: list[tuple[str, int]]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class CovalentBond:
|
| 66 |
+
chain_id1: str
|
| 67 |
+
res_idx1: int
|
| 68 |
+
atom_idx1: int
|
| 69 |
+
chain_id2: str
|
| 70 |
+
res_idx2: int
|
| 71 |
+
atom_idx2: int
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class StructurePredictionInput:
|
| 76 |
+
sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
|
| 77 |
+
pocket: PocketConditioning | None = None
|
| 78 |
+
distogram_conditioning: list[DistogramConditioning] | None = None
|
| 79 |
+
covalent_bonds: list[CovalentBond] | None = None
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
|
| 83 |
+
def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
|
| 84 |
+
chain_data: dict[str, Any] = {
|
| 85 |
+
"sequence": seq_input.sequence,
|
| 86 |
+
"id": seq_input.id,
|
| 87 |
+
"type": chain_type,
|
| 88 |
+
}
|
| 89 |
+
if hasattr(seq_input, "modifications") and seq_input.modifications:
|
| 90 |
+
mods = [
|
| 91 |
+
{"position": mod.position, "ccd": mod.ccd}
|
| 92 |
+
for mod in seq_input.modifications
|
| 93 |
+
]
|
| 94 |
+
chain_data["modifications"] = mods
|
| 95 |
+
if not hasattr(seq_input, "msa"):
|
| 96 |
+
pass
|
| 97 |
+
elif seq_input.msa is None:
|
| 98 |
+
chain_data["msa"] = None
|
| 99 |
+
elif isinstance(seq_input.msa, MSA):
|
| 100 |
+
chain_data["msa"] = {"sequences": seq_input.msa.sequences}
|
| 101 |
+
else:
|
| 102 |
+
error_msg = f"MSA must be None or MSA. Got {seq_input.msa} instead."
|
| 103 |
+
raise AttributeError(error_msg)
|
| 104 |
+
return chain_data
|
| 105 |
+
|
| 106 |
+
sequences = []
|
| 107 |
+
for seq_input in all_atom_input.sequences:
|
| 108 |
+
if isinstance(seq_input, ProteinInput):
|
| 109 |
+
sequences.append(create_chain_data(seq_input, "protein"))
|
| 110 |
+
elif isinstance(seq_input, RNAInput):
|
| 111 |
+
sequences.append(create_chain_data(seq_input, "rna"))
|
| 112 |
+
elif isinstance(seq_input, DNAInput):
|
| 113 |
+
sequences.append(create_chain_data(seq_input, "dna"))
|
| 114 |
+
elif isinstance(seq_input, LigandInput):
|
| 115 |
+
sequences.append(
|
| 116 |
+
{
|
| 117 |
+
"smiles": seq_input.smiles,
|
| 118 |
+
"id": seq_input.id,
|
| 119 |
+
"ccd": seq_input.ccd,
|
| 120 |
+
"type": "ligand",
|
| 121 |
+
}
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")
|
| 125 |
+
|
| 126 |
+
result: dict[str, Any] = {"sequences": sequences}
|
| 127 |
+
|
| 128 |
+
if all_atom_input.covalent_bonds is not None:
|
| 129 |
+
result["covalent_bonds"] = [
|
| 130 |
+
{
|
| 131 |
+
"chain_id1": bond.chain_id1,
|
| 132 |
+
"res_idx1": bond.res_idx1,
|
| 133 |
+
"atom_idx1": bond.atom_idx1,
|
| 134 |
+
"chain_id2": bond.chain_id2,
|
| 135 |
+
"res_idx2": bond.res_idx2,
|
| 136 |
+
"atom_idx2": bond.atom_idx2,
|
| 137 |
+
}
|
| 138 |
+
for bond in all_atom_input.covalent_bonds
|
| 139 |
+
]
|
| 140 |
+
|
| 141 |
+
if all_atom_input.pocket is not None:
|
| 142 |
+
result["pocket"] = {
|
| 143 |
+
"binder_chain_id": all_atom_input.pocket.binder_chain_id,
|
| 144 |
+
"contacts": all_atom_input.pocket.contacts,
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
if all_atom_input.distogram_conditioning is not None:
|
| 148 |
+
result["distogram_conditioning"] = [
|
| 149 |
+
{"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()}
|
| 150 |
+
for disto in all_atom_input.distogram_conditioning
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
return result
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def deserialize_structure_prediction_input(
|
| 157 |
+
data: dict[str, Any],
|
| 158 |
+
) -> StructurePredictionInput:
|
| 159 |
+
"""Inverse of :func:`serialize_structure_prediction_input`.
|
| 160 |
+
|
| 161 |
+
Reconstructs a :class:`StructurePredictionInput` from the JSON-safe dict
|
| 162 |
+
produced by ``serialize_structure_prediction_input``. Values round-trip;
|
| 163 |
+
``DistogramConditioning.distogram`` dtype follows from JSON (``int64``
|
| 164 |
+
for integer entries, ``float64`` for floats) — cast back to the original
|
| 165 |
+
dtype if downstream code requires a specific one.
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
def _mods(chain: dict[str, Any]) -> list[Modification] | None:
|
| 169 |
+
raw = chain.get("modifications")
|
| 170 |
+
if not raw:
|
| 171 |
+
return None
|
| 172 |
+
return [Modification(position=m["position"], ccd=m["ccd"]) for m in raw]
|
| 173 |
+
|
| 174 |
+
def _msa(chain: dict[str, Any]) -> MSAInput:
|
| 175 |
+
if "msa" not in chain or chain["msa"] is None:
|
| 176 |
+
return None
|
| 177 |
+
msa_blk = chain["msa"]
|
| 178 |
+
if isinstance(msa_blk, str):
|
| 179 |
+
raise ValueError(f"Unexpected MSA string value: {msa_blk!r}")
|
| 180 |
+
return MSA.from_sequences(msa_blk["sequences"])
|
| 181 |
+
|
| 182 |
+
sequences: list[ProteinInput | RNAInput | DNAInput | LigandInput] = []
|
| 183 |
+
for chain in data["sequences"]:
|
| 184 |
+
t = chain["type"]
|
| 185 |
+
if t == "protein":
|
| 186 |
+
sequences.append(
|
| 187 |
+
ProteinInput(
|
| 188 |
+
id=chain["id"],
|
| 189 |
+
sequence=chain["sequence"],
|
| 190 |
+
modifications=_mods(chain),
|
| 191 |
+
msa=_msa(chain),
|
| 192 |
+
)
|
| 193 |
+
)
|
| 194 |
+
elif t == "rna":
|
| 195 |
+
sequences.append(
|
| 196 |
+
RNAInput(
|
| 197 |
+
id=chain["id"],
|
| 198 |
+
sequence=chain["sequence"],
|
| 199 |
+
modifications=_mods(chain),
|
| 200 |
+
)
|
| 201 |
+
)
|
| 202 |
+
elif t == "dna":
|
| 203 |
+
sequences.append(
|
| 204 |
+
DNAInput(
|
| 205 |
+
id=chain["id"],
|
| 206 |
+
sequence=chain["sequence"],
|
| 207 |
+
modifications=_mods(chain),
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
elif t == "ligand":
|
| 211 |
+
sequences.append(
|
| 212 |
+
LigandInput(
|
| 213 |
+
id=chain["id"], smiles=chain.get("smiles"), ccd=chain.get("ccd")
|
| 214 |
+
)
|
| 215 |
+
)
|
| 216 |
+
else:
|
| 217 |
+
raise ValueError(f"Unsupported sequence type: {t!r}")
|
| 218 |
+
|
| 219 |
+
pocket: PocketConditioning | None = None
|
| 220 |
+
if (pocket_blk := data.get("pocket")) is not None:
|
| 221 |
+
pocket = PocketConditioning(
|
| 222 |
+
binder_chain_id=pocket_blk["binder_chain_id"],
|
| 223 |
+
contacts=[tuple(c) for c in pocket_blk["contacts"]],
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
distogram_conditioning: list[DistogramConditioning] | None = None
|
| 227 |
+
if (disto_blk := data.get("distogram_conditioning")) is not None:
|
| 228 |
+
distogram_conditioning = [
|
| 229 |
+
DistogramConditioning(
|
| 230 |
+
chain_id=d["chain_id"], distogram=np.asarray(d["distogram"])
|
| 231 |
+
)
|
| 232 |
+
for d in disto_blk
|
| 233 |
+
]
|
| 234 |
+
|
| 235 |
+
covalent_bonds: list[CovalentBond] | None = None
|
| 236 |
+
if (bonds_blk := data.get("covalent_bonds")) is not None:
|
| 237 |
+
covalent_bonds = [
|
| 238 |
+
CovalentBond(
|
| 239 |
+
chain_id1=b["chain_id1"],
|
| 240 |
+
res_idx1=b["res_idx1"],
|
| 241 |
+
atom_idx1=b["atom_idx1"],
|
| 242 |
+
chain_id2=b["chain_id2"],
|
| 243 |
+
res_idx2=b["res_idx2"],
|
| 244 |
+
atom_idx2=b["atom_idx2"],
|
| 245 |
+
)
|
| 246 |
+
for b in bonds_blk
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
return StructurePredictionInput(
|
| 250 |
+
sequences=sequences,
|
| 251 |
+
pocket=pocket,
|
| 252 |
+
distogram_conditioning=distogram_conditioning,
|
| 253 |
+
covalent_bonds=covalent_bonds,
|
| 254 |
+
)
|
| 255 |
+
|
esmfold2_metrics.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch.amp import autocast # type: ignore
|
| 7 |
+
|
| 8 |
+
from . import esmfold2_residue_constants
|
| 9 |
+
from .esmfold2_misc import binpack, unbinpack
|
| 10 |
+
from .esmfold2_protein_structure import (
|
| 11 |
+
compute_alignment_tensors,
|
| 12 |
+
compute_gdt_ts_no_alignment,
|
| 13 |
+
compute_rmsd_no_alignment,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def contact_precision(
|
| 18 |
+
predictions: Tensor,
|
| 19 |
+
targets: Tensor,
|
| 20 |
+
src_lengths: Tensor | None = None,
|
| 21 |
+
minsep: int = 6,
|
| 22 |
+
maxsep: int | None = None,
|
| 23 |
+
override_length: int | None = None, # for casp
|
| 24 |
+
):
|
| 25 |
+
"""Computes contact precisions.
|
| 26 |
+
|
| 27 |
+
For protein contact prediction, precision is measured for the top (L/K) highest confidence predictions,
|
| 28 |
+
with L being the length of the protein sequence and K generally being equal to 1 or 5.
|
| 29 |
+
|
| 30 |
+
K = 5 measures the predictions of the very highest confidence contacts, while K = 1 is a more general measure
|
| 31 |
+
over all relatively high confidence predictions.
|
| 32 |
+
|
| 33 |
+
Since there are roughly ~L true contacts in a protein, this is a reasonable cutoff.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
predictions (Tensor): Tensor of probabilities of size (B, L, L)
|
| 38 |
+
targets (Tensor): Tensor of true contacts of size (B, L, L)
|
| 39 |
+
src_lengths (Tensor, optional): Lengths of each sample in the batch, if using variable lengths.
|
| 40 |
+
If not provided, inferred from the size of the predictions.
|
| 41 |
+
minsep (int): Minimum separation distance to consider. We often want to measure contacts at a
|
| 42 |
+
certain range. Typical ranges are short [6, 12), medium [12, 24), and long [24, inf).
|
| 43 |
+
maxsep (int, optional): Used in conjunction with minsep to specify a contact range. If not provided uses
|
| 44 |
+
assumes no maximum range
|
| 45 |
+
override_length (int, optional): Used for casp evaluation where sometimes the "true" length is not
|
| 46 |
+
the same as the length of the input. Kept for posterity, we probably don't need this argument.
|
| 47 |
+
"""
|
| 48 |
+
if predictions.dim() == 2:
|
| 49 |
+
predictions = predictions.unsqueeze(0)
|
| 50 |
+
if targets.dim() == 2:
|
| 51 |
+
targets = targets.unsqueeze(0)
|
| 52 |
+
|
| 53 |
+
# Check sizes
|
| 54 |
+
if predictions.size() != targets.size():
|
| 55 |
+
raise ValueError(
|
| 56 |
+
f"Size mismatch. Received predictions of size {predictions.size()}, "
|
| 57 |
+
f"targets of size {targets.size()}"
|
| 58 |
+
)
|
| 59 |
+
device = predictions.device
|
| 60 |
+
|
| 61 |
+
batch_size, seqlen, _ = predictions.size()
|
| 62 |
+
|
| 63 |
+
# Step 1) Construct a mask of size [B, L, L] to mask invalid contacts
|
| 64 |
+
seqlen_range = torch.arange(seqlen, device=device)
|
| 65 |
+
sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
|
| 66 |
+
sep = sep.unsqueeze(0)
|
| 67 |
+
# Mask contacts that are closer than minsep
|
| 68 |
+
valid_mask = sep >= minsep
|
| 69 |
+
# Mask contacts where target is negative (padding or unknown)
|
| 70 |
+
valid_mask = valid_mask & (targets >= 0) # negative targets are invalid
|
| 71 |
+
|
| 72 |
+
# Mask contacts that are farther than maxsep, if provided
|
| 73 |
+
if maxsep is not None:
|
| 74 |
+
valid_mask &= sep < maxsep
|
| 75 |
+
|
| 76 |
+
if src_lengths is not None:
|
| 77 |
+
# If the lengths of the individual sequences are provided, mask positions
|
| 78 |
+
# that are farther than the end of the sequence.
|
| 79 |
+
valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
|
| 80 |
+
valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
|
| 81 |
+
else:
|
| 82 |
+
src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)
|
| 83 |
+
|
| 84 |
+
# Fill in the logit tensor with -inf for all invalid positions
|
| 85 |
+
predictions = predictions.masked_fill(~valid_mask, float("-inf"))
|
| 86 |
+
|
| 87 |
+
# Step 2) Select the top half of the prediction (should be symmetric)
|
| 88 |
+
x_ind, y_ind = np.triu_indices(seqlen, minsep)
|
| 89 |
+
predictions_upper = predictions[:, x_ind, y_ind]
|
| 90 |
+
targets_upper = targets[:, x_ind, y_ind]
|
| 91 |
+
|
| 92 |
+
# Step 3) Select the topk values in each batch where k = L (length of sequence)
|
| 93 |
+
topk = seqlen if override_length is None else max(seqlen, override_length)
|
| 94 |
+
# Indices are the indices into the predictions corresponding to the most confident predictions
|
| 95 |
+
indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
|
| 96 |
+
# topk_targets are the target values corresponding to the above indices
|
| 97 |
+
topk_targets = targets_upper[torch.arange(batch_size).unsqueeze(1), indices]
|
| 98 |
+
if topk_targets.size(1) < topk:
|
| 99 |
+
# If there aren't enough targets, pad to the output.
|
| 100 |
+
topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])
|
| 101 |
+
|
| 102 |
+
# Step 4) Sum the accuracy at of the top-i predictions for i in 1, L
|
| 103 |
+
# topk_targets => 1/0 true vs. false contact, sorted by confidence of prediction
|
| 104 |
+
# cmumulative sum => Number of correct answers for the top-i predictions.
|
| 105 |
+
cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)
|
| 106 |
+
|
| 107 |
+
# Step 5) Find the gather indices. This should be P@(L / K) for varous values of K
|
| 108 |
+
# The values will differ for each batch.
|
| 109 |
+
gather_lengths = src_lengths.unsqueeze(1)
|
| 110 |
+
if override_length is not None:
|
| 111 |
+
gather_lengths = override_length * torch.ones_like(
|
| 112 |
+
gather_lengths, device=device
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# This gets you (0.1 * L, 0.2 * L, 0.3 * L, etc.)
|
| 116 |
+
gather_indices = (
|
| 117 |
+
(torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths).type(
|
| 118 |
+
torch.long
|
| 119 |
+
)
|
| 120 |
+
- 1
|
| 121 |
+
).clamp_min(0)
|
| 122 |
+
|
| 123 |
+
# Step 6) Gather the results and divide by the number of guesses to get the precision.
|
| 124 |
+
binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
|
| 125 |
+
binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
|
| 126 |
+
binned_cumulative_dist
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
# Select specific P@L/k. pl5 is index 1 b/c that corresponds to L * 0.2 in
|
| 130 |
+
# gather_indices above
|
| 131 |
+
pl5 = binned_precisions[:, 1]
|
| 132 |
+
# pl2 = binned_precisions[:, 4]
|
| 133 |
+
pl = binned_precisions[:, 9]
|
| 134 |
+
# AUC is the integral wrt K of P@L/K for K in range(1, L)
|
| 135 |
+
auc = binned_precisions.mean(-1)
|
| 136 |
+
|
| 137 |
+
return {"AUC": auc, "P@L": pl, "P@L5": pl5}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def compute_lddt(
|
| 141 |
+
all_atom_pred_pos: torch.Tensor,
|
| 142 |
+
all_atom_positions: torch.Tensor,
|
| 143 |
+
all_atom_mask: torch.Tensor,
|
| 144 |
+
pairwise_all_atom_mask: torch.Tensor | None = None,
|
| 145 |
+
cutoff: float | torch.Tensor = 15.0,
|
| 146 |
+
eps: float = 1e-10,
|
| 147 |
+
per_residue: bool = True,
|
| 148 |
+
sequence_id: torch.Tensor | None = None,
|
| 149 |
+
) -> torch.Tensor:
|
| 150 |
+
"""
|
| 151 |
+
Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
|
| 152 |
+
Nstates:
|
| 153 |
+
all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
|
| 154 |
+
Natoms:
|
| 155 |
+
LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
|
| 159 |
+
all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
|
| 160 |
+
all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
|
| 161 |
+
pairwise_all_atom_mask (Tensor[float], [B x (L * Natoms x L * Natoms)], optional): Tensor of masks, indicating whether a pair of atoms should be considered in the LDDT calculation.
|
| 162 |
+
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
|
| 163 |
+
per_residue (bool): Whether to return per-residue or full-protein lddt.
|
| 164 |
+
sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
LDDT Tensor:
|
| 168 |
+
if per_residue:
|
| 169 |
+
Tensor[float], [(Nstates x) B x (L * Natoms)]
|
| 170 |
+
else:
|
| 171 |
+
Tensor[float], [(Nstates x) B]
|
| 172 |
+
"""
|
| 173 |
+
all_atom_mask = all_atom_mask[..., None] # add a dimension for broadcasting
|
| 174 |
+
dmat_true = torch.sqrt(
|
| 175 |
+
eps
|
| 176 |
+
+ torch.sum(
|
| 177 |
+
(all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
|
| 178 |
+
** 2,
|
| 179 |
+
dim=-1,
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
dmat_pred = torch.sqrt(
|
| 184 |
+
eps
|
| 185 |
+
+ torch.sum(
|
| 186 |
+
(all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
|
| 187 |
+
dim=-1,
|
| 188 |
+
)
|
| 189 |
+
)
|
| 190 |
+
mask = all_atom_mask * rearrange(all_atom_mask, "... a b -> ... b a")
|
| 191 |
+
if pairwise_all_atom_mask is not None:
|
| 192 |
+
mask = mask * pairwise_all_atom_mask
|
| 193 |
+
|
| 194 |
+
if sequence_id is not None:
|
| 195 |
+
# TODO: This will work for lddt_ca, but not for regular lddt
|
| 196 |
+
# Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
|
| 197 |
+
# Leaving for now because it won't fail silently so should be ook.
|
| 198 |
+
seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
|
| 199 |
+
mask = mask * seqid_mask.type_as(mask)
|
| 200 |
+
|
| 201 |
+
return compute_lddt_from_dmat(
|
| 202 |
+
dmat_pred, dmat_true, mask, cutoff=cutoff, eps=eps, per_residue=per_residue
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def compute_lddt_from_dmat(
|
| 207 |
+
dmat_pred: torch.Tensor,
|
| 208 |
+
dmat_true: torch.Tensor,
|
| 209 |
+
pairwise_mask: torch.Tensor,
|
| 210 |
+
cutoff: float | torch.Tensor = 15.0,
|
| 211 |
+
eps: float = 1e-10,
|
| 212 |
+
per_residue: bool = True,
|
| 213 |
+
):
|
| 214 |
+
"""
|
| 215 |
+
Compute LDDT from pre-computed distance matrices.
|
| 216 |
+
This is useful when you want to compute LDDT with multiple different masks or cutoffs, e.g. for different molecule types (protein, nucleic acid, etc.).
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
dmat_pred (Tensor[float], [B x L x L]): Predicted distance matrix
|
| 220 |
+
dmat_true (Tensor[float], [B x L x L]): True distance matrix
|
| 221 |
+
pairwise_mask (Tensor[float], [B x L x L]): Pairwise mask indicating which pairs of atoms to consider
|
| 222 |
+
cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
|
| 223 |
+
per_residue (bool): Whether to return per-residue or full-protein lddt.
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
LDDT Tensor:
|
| 227 |
+
if per_residue:
|
| 228 |
+
Tensor[float], [B x L]
|
| 229 |
+
else:
|
| 230 |
+
Tensor[float], [B]
|
| 231 |
+
"""
|
| 232 |
+
n = dmat_true.size(-1)
|
| 233 |
+
dists_to_score = (
|
| 234 |
+
(dmat_true < cutoff)
|
| 235 |
+
* pairwise_mask
|
| 236 |
+
* (1.0 - torch.eye(n, device=dmat_true.device))
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
dist_l1 = torch.abs(dmat_true - dmat_pred)
|
| 240 |
+
score = (
|
| 241 |
+
(dist_l1 < 0.5).type(dist_l1.dtype)
|
| 242 |
+
+ (dist_l1 < 1.0).type(dist_l1.dtype)
|
| 243 |
+
+ (dist_l1 < 2.0).type(dist_l1.dtype)
|
| 244 |
+
+ (dist_l1 < 4.0).type(dist_l1.dtype)
|
| 245 |
+
)
|
| 246 |
+
score = score * 0.25
|
| 247 |
+
|
| 248 |
+
dims = (-1,) if per_residue else (-2, -1)
|
| 249 |
+
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
|
| 250 |
+
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
|
| 251 |
+
return score
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def compute_lddt_ca(
|
| 255 |
+
all_atom_pred_pos: torch.Tensor,
|
| 256 |
+
all_atom_positions: torch.Tensor,
|
| 257 |
+
all_atom_mask: torch.Tensor,
|
| 258 |
+
cutoff: float = 15.0,
|
| 259 |
+
eps: float = 1e-10,
|
| 260 |
+
per_residue: bool = True,
|
| 261 |
+
sequence_id: torch.Tensor | None = None,
|
| 262 |
+
) -> torch.Tensor:
|
| 263 |
+
ca_pos = residue_constants.atom_order["CA"]
|
| 264 |
+
if all_atom_pred_pos.dim() != 3:
|
| 265 |
+
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
|
| 266 |
+
all_atom_positions = all_atom_positions[..., ca_pos, :]
|
| 267 |
+
all_atom_mask = all_atom_mask[..., ca_pos]
|
| 268 |
+
|
| 269 |
+
return compute_lddt(
|
| 270 |
+
all_atom_pred_pos,
|
| 271 |
+
all_atom_positions,
|
| 272 |
+
all_atom_mask,
|
| 273 |
+
cutoff=cutoff,
|
| 274 |
+
eps=eps,
|
| 275 |
+
per_residue=per_residue,
|
| 276 |
+
sequence_id=sequence_id,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# NOTE(roshan): no_grad required for stack_variable_length_tensors apparently... let's revisit if we want to backprop
|
| 281 |
+
@torch.no_grad()
|
| 282 |
+
@autocast("cuda", enabled=False)
|
| 283 |
+
def compute_rmsd(
|
| 284 |
+
mobile: torch.Tensor,
|
| 285 |
+
target: torch.Tensor,
|
| 286 |
+
atom_exists_mask: torch.Tensor | None = None,
|
| 287 |
+
sequence_id: torch.Tensor | None = None,
|
| 288 |
+
reduction: str = "batch",
|
| 289 |
+
):
|
| 290 |
+
"""
|
| 291 |
+
Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 295 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 296 |
+
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
|
| 297 |
+
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
|
| 298 |
+
- reduction (str): One of "batch", "per_sample", "per_residue".
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
If reduction == "batch":
|
| 302 |
+
(torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
|
| 303 |
+
If reduction == "per_sample":
|
| 304 |
+
(torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch
|
| 305 |
+
If reduction == "per_residue":
|
| 306 |
+
(torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for residue in the batch
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
(centered_mobile, _, centered_target, _, rotation_matrix, num_valid_atoms) = (
|
| 310 |
+
compute_alignment_tensors(
|
| 311 |
+
mobile=mobile,
|
| 312 |
+
target=target,
|
| 313 |
+
atom_exists_mask=atom_exists_mask,
|
| 314 |
+
sequence_id=sequence_id,
|
| 315 |
+
)
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# Apply transformation to centered structure
|
| 319 |
+
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
|
| 320 |
+
|
| 321 |
+
# Compute rmsd for centered structures
|
| 322 |
+
rmsd = compute_rmsd_no_alignment(
|
| 323 |
+
rotated_mobile, centered_target, num_valid_atoms, reduction=reduction
|
| 324 |
+
)
|
| 325 |
+
if reduction == "per_residue" and sequence_id is not None:
|
| 326 |
+
rmsd = binpack(rmsd, sequence_id, pad_value=0)
|
| 327 |
+
return rmsd
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def compute_gdt_ts(
|
| 331 |
+
mobile: torch.Tensor,
|
| 332 |
+
target: torch.Tensor,
|
| 333 |
+
atom_exists_mask: torch.Tensor | None = None,
|
| 334 |
+
sequence_id: torch.Tensor | None = None,
|
| 335 |
+
reduction: str = "per_sample",
|
| 336 |
+
):
|
| 337 |
+
"""
|
| 338 |
+
Compute GDT_TS between two batches of structures with support for masking invalid atoms using PyTorch.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 342 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 343 |
+
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
|
| 344 |
+
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
|
| 345 |
+
- reduction (str): One of "batch", "per_sample", "per_residue".
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
If reduction == "batch":
|
| 349 |
+
(torch.Tensor): 0-dim, GDT_TS between the structures for each batch
|
| 350 |
+
If reduction == "per_sample":
|
| 351 |
+
(torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
|
| 352 |
+
"""
|
| 353 |
+
if atom_exists_mask is None:
|
| 354 |
+
atom_exists_mask = torch.isfinite(target).all(dim=-1)
|
| 355 |
+
(centered_mobile, _, centered_target, _, rotation_matrix, _) = (
|
| 356 |
+
compute_alignment_tensors(
|
| 357 |
+
mobile=mobile,
|
| 358 |
+
target=target,
|
| 359 |
+
atom_exists_mask=atom_exists_mask,
|
| 360 |
+
sequence_id=sequence_id,
|
| 361 |
+
)
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Apply transformation to centered structure
|
| 365 |
+
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
|
| 366 |
+
|
| 367 |
+
# the coordinate tensors returned by `compute_alignment_tensors` are unbinpacked and contain zeros for invalid positions
|
| 368 |
+
# so `compute_gdt_ts_no_alignment` requires `atom_exists_mask` to be passed and be unbinpacked
|
| 369 |
+
if sequence_id is not None:
|
| 370 |
+
atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=False)
|
| 371 |
+
return compute_gdt_ts_no_alignment(
|
| 372 |
+
rotated_mobile, centered_target, atom_exists_mask, reduction
|
| 373 |
+
)
|
| 374 |
+
|
esmfold2_misc.py
ADDED
|
@@ -0,0 +1,505 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from contextlib import nullcontext
|
| 6 |
+
from dataclasses import is_dataclass
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from typing import (
|
| 9 |
+
Any,
|
| 10 |
+
ContextManager,
|
| 11 |
+
Generator,
|
| 12 |
+
Iterable,
|
| 13 |
+
Protocol,
|
| 14 |
+
Sequence,
|
| 15 |
+
TypeVar,
|
| 16 |
+
runtime_checkable,
|
| 17 |
+
)
|
| 18 |
+
from warnings import warn
|
| 19 |
+
|
| 20 |
+
import huggingface_hub
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import zstd
|
| 24 |
+
|
| 25 |
+
from .esmfold2_constants_esm3 import CHAIN_BREAK_STR
|
| 26 |
+
from .esmfold2_utils_types import FunctionAnnotation
|
| 27 |
+
|
| 28 |
+
MAX_SUPPORTED_DISTANCE = 1e6
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
TSequence = TypeVar("TSequence", bound=Sequence)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@runtime_checkable
|
| 35 |
+
class Concatable(Protocol):
|
| 36 |
+
@classmethod
|
| 37 |
+
def concat(cls, objs: list[Concatable]) -> Concatable: ...
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def slice_python_object_as_numpy(
|
| 41 |
+
obj: TSequence, idx: int | list[int] | slice | np.ndarray
|
| 42 |
+
) -> TSequence:
|
| 43 |
+
"""
|
| 44 |
+
Slice a python object (like a list, string, or tuple) as if it was a numpy object.
|
| 45 |
+
|
| 46 |
+
Example:
|
| 47 |
+
>>> obj = "ABCDE"
|
| 48 |
+
>>> slice_python_object_as_numpy(obj, [1, 3, 4])
|
| 49 |
+
"BDE"
|
| 50 |
+
|
| 51 |
+
>>> obj = [1, 2, 3, 4, 5]
|
| 52 |
+
>>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
|
| 53 |
+
[1, 2, 3]
|
| 54 |
+
"""
|
| 55 |
+
if np.isscalar(idx):
|
| 56 |
+
idx = [int(idx)] # type: ignore
|
| 57 |
+
|
| 58 |
+
if isinstance(idx, np.ndarray) and idx.dtype == bool:
|
| 59 |
+
sliced_obj = [obj[i] for i in np.where(idx)[0]]
|
| 60 |
+
elif isinstance(idx, slice):
|
| 61 |
+
sliced_obj = obj[idx]
|
| 62 |
+
else:
|
| 63 |
+
sliced_obj = [obj[i] for i in idx] # type: ignore
|
| 64 |
+
|
| 65 |
+
match obj, sliced_obj:
|
| 66 |
+
case str(), list():
|
| 67 |
+
sliced_obj = "".join(sliced_obj)
|
| 68 |
+
case _:
|
| 69 |
+
sliced_obj = obj.__class__(sliced_obj) # type: ignore
|
| 70 |
+
|
| 71 |
+
return sliced_obj # type: ignore
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def slice_any_object(
|
| 75 |
+
obj: TSequence, idx: int | list[int] | slice | np.ndarray
|
| 76 |
+
) -> TSequence:
|
| 77 |
+
"""
|
| 78 |
+
Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so.
|
| 79 |
+
|
| 80 |
+
If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing.
|
| 81 |
+
|
| 82 |
+
Example:
|
| 83 |
+
>>> obj = "ABCDE"
|
| 84 |
+
>>> slice_any_object(obj, [1, 3, 4])
|
| 85 |
+
"BDE"
|
| 86 |
+
|
| 87 |
+
>>> obj = np.array([1, 2, 3, 4, 5])
|
| 88 |
+
>>> slice_any_object(obj, np.arange(5) < 3)
|
| 89 |
+
np.array([1, 2, 3])
|
| 90 |
+
|
| 91 |
+
>>> obj = ProteinChain.from_rcsb("1a3a", "A")
|
| 92 |
+
>>> slice_any_object(obj, np.arange(len(obj)) < 10)
|
| 93 |
+
# ProteinChain w/ length 10
|
| 94 |
+
|
| 95 |
+
"""
|
| 96 |
+
if isinstance(obj, (np.ndarray, torch.Tensor)):
|
| 97 |
+
return obj[idx] # type: ignore
|
| 98 |
+
elif is_dataclass(obj):
|
| 99 |
+
# if passing a dataclass, assume it implements a custom slice
|
| 100 |
+
return obj[idx] # type: ignore
|
| 101 |
+
else:
|
| 102 |
+
return slice_python_object_as_numpy(obj, idx)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def rbf(values, v_min, v_max, n_bins=16):
|
| 106 |
+
"""
|
| 107 |
+
Returns RBF encodings in a new dimension at the end.
|
| 108 |
+
"""
|
| 109 |
+
rbf_centers = torch.linspace(
|
| 110 |
+
v_min, v_max, n_bins, device=values.device, dtype=values.dtype
|
| 111 |
+
)
|
| 112 |
+
rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
|
| 113 |
+
rbf_std = (v_max - v_min) / n_bins
|
| 114 |
+
z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
|
| 115 |
+
return torch.exp(-(z**2))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
| 119 |
+
ranges = []
|
| 120 |
+
for i, s in enumerate(data.shape[:no_batch_dims]):
|
| 121 |
+
r = torch.arange(s)
|
| 122 |
+
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
| 123 |
+
ranges.append(r)
|
| 124 |
+
|
| 125 |
+
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
|
| 126 |
+
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
| 127 |
+
ranges.extend(remaining_dims)
|
| 128 |
+
return data[ranges]
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def knn_graph(
|
| 136 |
+
coords: torch.Tensor,
|
| 137 |
+
coord_mask: torch.Tensor,
|
| 138 |
+
padding_mask: torch.Tensor,
|
| 139 |
+
sequence_id: torch.Tensor,
|
| 140 |
+
*,
|
| 141 |
+
no_knn: int,
|
| 142 |
+
):
|
| 143 |
+
L = coords.shape[-2]
|
| 144 |
+
num_by_dist = min(no_knn, L)
|
| 145 |
+
device = coords.device
|
| 146 |
+
|
| 147 |
+
coords = coords.nan_to_num()
|
| 148 |
+
coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
|
| 149 |
+
padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
|
| 150 |
+
if sequence_id is not None:
|
| 151 |
+
padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze(
|
| 152 |
+
sequence_id, 2
|
| 153 |
+
)
|
| 154 |
+
dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
|
| 155 |
+
arange = torch.arange(L, device=device)
|
| 156 |
+
seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()
|
| 157 |
+
# We only support up to a certain distance, above that, we use sequence distance
|
| 158 |
+
# instead. This is so that when a large portion of the structure is masked out,
|
| 159 |
+
# the edges are built according to sequence distance.
|
| 160 |
+
max_dist = MAX_SUPPORTED_DISTANCE
|
| 161 |
+
if not (dists[~coord_mask] < max_dist).all():
|
| 162 |
+
raise ValueError(
|
| 163 |
+
f"Coordinate pairwise distances exceed max supported distance ({max_dist}). "
|
| 164 |
+
)
|
| 165 |
+
struct_then_seq_dist = (
|
| 166 |
+
seq_dists.to(dists.dtype)
|
| 167 |
+
.mul(1e2)
|
| 168 |
+
.add(max_dist)
|
| 169 |
+
.where(coord_mask, dists)
|
| 170 |
+
.masked_fill(padding_pairwise_mask, torch.inf)
|
| 171 |
+
)
|
| 172 |
+
dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
|
| 173 |
+
# This is a L x L tensor, where we index by rows first,
|
| 174 |
+
# and columns are the edges we should pick.
|
| 175 |
+
chosen_edges = edges[..., :num_by_dist]
|
| 176 |
+
chosen_mask = dists[..., :num_by_dist].isfinite()
|
| 177 |
+
return chosen_edges, chosen_mask
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def stack_variable_length_tensors(
|
| 181 |
+
sequences: Sequence[torch.Tensor],
|
| 182 |
+
constant_value: int | float = 0,
|
| 183 |
+
dtype: torch.dtype | None = None,
|
| 184 |
+
) -> torch.Tensor:
|
| 185 |
+
"""Automatically stack tensors together, padding variable lengths with the
|
| 186 |
+
value in constant_value. Handles an arbitrary number of dimensions.
|
| 187 |
+
|
| 188 |
+
Examples:
|
| 189 |
+
>>> tensor1, tensor2 = torch.ones([2]), torch.ones([5])
|
| 190 |
+
>>> stack_variable_length_tensors(tensor1, tensor2)
|
| 191 |
+
tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
|
| 192 |
+
|
| 193 |
+
>>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3])
|
| 194 |
+
>>> stack_variable_length_tensors(tensor1, tensor2)
|
| 195 |
+
tensor of shape [2, 5, 4]
|
| 196 |
+
"""
|
| 197 |
+
batch_size = len(sequences)
|
| 198 |
+
shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
|
| 199 |
+
|
| 200 |
+
if dtype is None:
|
| 201 |
+
dtype = sequences[0].dtype
|
| 202 |
+
device = sequences[0].device
|
| 203 |
+
|
| 204 |
+
array = torch.full(shape, constant_value, dtype=dtype, device=device)
|
| 205 |
+
for arr, seq in zip(array, sequences):
|
| 206 |
+
arrslice = tuple(slice(dim) for dim in seq.shape)
|
| 207 |
+
arr[arrslice] = seq
|
| 208 |
+
|
| 209 |
+
return array
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def binpack(
|
| 213 |
+
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
|
| 214 |
+
):
|
| 215 |
+
"""
|
| 216 |
+
Args:
|
| 217 |
+
tensor (Tensor): [B, L, ...]
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
Tensor: [B_binpacked, L_binpacked, ...]
|
| 221 |
+
"""
|
| 222 |
+
if sequence_id is None:
|
| 223 |
+
return tensor
|
| 224 |
+
|
| 225 |
+
num_sequences = sequence_id.max(dim=-1).values + 1
|
| 226 |
+
|
| 227 |
+
dims = sequence_id.shape + tensor.shape[2:]
|
| 228 |
+
output_tensor = torch.full(
|
| 229 |
+
dims, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
idx = 0
|
| 233 |
+
for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
|
| 234 |
+
zip(sequence_id, num_sequences)
|
| 235 |
+
):
|
| 236 |
+
for seqid in range(batch_num_sequences):
|
| 237 |
+
mask = batch_seqid == seqid
|
| 238 |
+
output_tensor[batch_idx, mask] = tensor[idx, : mask.sum()]
|
| 239 |
+
idx += 1
|
| 240 |
+
return output_tensor
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def unbinpack(
|
| 244 |
+
tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
|
| 245 |
+
):
|
| 246 |
+
"""
|
| 247 |
+
Args:
|
| 248 |
+
tensor (Tensor): [B, L, ...]
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
Tensor: [B_unbinpacked, L_unbinpack, ...]
|
| 252 |
+
"""
|
| 253 |
+
if sequence_id is None:
|
| 254 |
+
return tensor
|
| 255 |
+
|
| 256 |
+
unpacked_tensors = []
|
| 257 |
+
num_sequences = sequence_id.max(dim=-1).values + 1
|
| 258 |
+
for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
|
| 259 |
+
zip(sequence_id, num_sequences)
|
| 260 |
+
):
|
| 261 |
+
for seqid in range(batch_num_sequences):
|
| 262 |
+
mask = batch_seqid == seqid
|
| 263 |
+
unpacked = tensor[batch_idx, mask]
|
| 264 |
+
unpacked_tensors.append(unpacked)
|
| 265 |
+
return stack_variable_length_tensors(unpacked_tensors, pad_value)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore
|
| 269 |
+
"""
|
| 270 |
+
Returns an autocast context manager that disables downcasting by AMP.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
device_type: The device type ('cpu' or 'cuda')
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
An autocast context manager with the specified behavior.
|
| 277 |
+
"""
|
| 278 |
+
if device_type == "cpu":
|
| 279 |
+
return torch.amp.autocast(device_type, enabled=False) # type: ignore
|
| 280 |
+
elif device_type == "mps":
|
| 281 |
+
# For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
|
| 282 |
+
return nullcontext()
|
| 283 |
+
elif device_type == "cuda":
|
| 284 |
+
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
|
| 285 |
+
else:
|
| 286 |
+
raise ValueError(f"Unsupported device type: {device_type}")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]:
|
| 290 |
+
"""Merge overlapping ranges into sorted, non-overlapping segments.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
ranges: collection of ranges to merge.
|
| 294 |
+
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
|
| 295 |
+
no larger than this size.
|
| 296 |
+
Returns:
|
| 297 |
+
non-overlapping ranges merged from the inputs, sorted by position.
|
| 298 |
+
"""
|
| 299 |
+
ranges = sorted(ranges, key=lambda r: r.start)
|
| 300 |
+
merge_gap_max = merge_gap_max if merge_gap_max is not None else 0
|
| 301 |
+
assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}"
|
| 302 |
+
|
| 303 |
+
merged = []
|
| 304 |
+
for r in ranges:
|
| 305 |
+
if not merged:
|
| 306 |
+
merged.append(r)
|
| 307 |
+
else:
|
| 308 |
+
last = merged[-1]
|
| 309 |
+
if last.stop + merge_gap_max >= r.start:
|
| 310 |
+
merged[-1] = range(last.start, max(last.stop, r.stop))
|
| 311 |
+
else:
|
| 312 |
+
merged.append(r)
|
| 313 |
+
return merged
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def merge_annotations(
|
| 317 |
+
annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
|
| 318 |
+
) -> list[FunctionAnnotation]:
|
| 319 |
+
"""Merges annotations into non-overlapping segments.
|
| 320 |
+
|
| 321 |
+
Args:
|
| 322 |
+
annotations: annotations to merge.
|
| 323 |
+
merge_gap_max: optionally merge neighboring ranges that are separated by a gap
|
| 324 |
+
no larger than this size.
|
| 325 |
+
Returns:
|
| 326 |
+
non-overlapping annotations with gaps merged.
|
| 327 |
+
"""
|
| 328 |
+
grouped: dict[str, list[range]] = defaultdict(list)
|
| 329 |
+
for a in annotations:
|
| 330 |
+
# +1 since FunctionAnnotation.end is inlcusive.
|
| 331 |
+
grouped[a.label].append(range(a.start, a.end + 1))
|
| 332 |
+
|
| 333 |
+
merged = []
|
| 334 |
+
for label, ranges in grouped.items():
|
| 335 |
+
merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
|
| 336 |
+
for range_ in merged_ranges:
|
| 337 |
+
annotation = FunctionAnnotation(
|
| 338 |
+
label=label,
|
| 339 |
+
start=range_.start,
|
| 340 |
+
end=range_.stop - 1, # convert range.stop exclusive -> inclusive.
|
| 341 |
+
)
|
| 342 |
+
merged.append(annotation)
|
| 343 |
+
return merged
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def replace_inf(data):
|
| 347 |
+
if data is None:
|
| 348 |
+
return None
|
| 349 |
+
array = np.asarray(data, dtype=np.float32)
|
| 350 |
+
array = np.where(np.isinf(array), 1000, array)
|
| 351 |
+
return array.tolist()
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
|
| 355 |
+
if x is None:
|
| 356 |
+
return None
|
| 357 |
+
if isinstance(x, torch.Tensor):
|
| 358 |
+
return x
|
| 359 |
+
if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x):
|
| 360 |
+
return torch.stack(x)
|
| 361 |
+
if convert_none_to_nan:
|
| 362 |
+
x = np.asarray(x, dtype=np.float32)
|
| 363 |
+
x = np.where(x is None, np.nan, x)
|
| 364 |
+
return torch.tensor(x)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
|
| 368 |
+
if x is None:
|
| 369 |
+
return None
|
| 370 |
+
if not convert_nan_to_none:
|
| 371 |
+
return x.tolist()
|
| 372 |
+
|
| 373 |
+
# Handle both torch.tensor and np.ndarray input.
|
| 374 |
+
if isinstance(x, torch.Tensor):
|
| 375 |
+
nan_mask = torch.isnan(x).cpu().numpy()
|
| 376 |
+
np_arr = x.cpu().numpy().astype(object)
|
| 377 |
+
elif isinstance(x, np.ndarray):
|
| 378 |
+
nan_mask = np.isnan(x)
|
| 379 |
+
np_arr = x.astype(object)
|
| 380 |
+
else:
|
| 381 |
+
raise TypeError("maybe_list can only work with torch.tensor or np.ndarray.")
|
| 382 |
+
|
| 383 |
+
np_arr[nan_mask] = None
|
| 384 |
+
return np_arr.tolist()
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def huggingfacehub_login():
|
| 388 |
+
"""Authenticates with the Hugging Face Hub using the HF_TOKEN environment
|
| 389 |
+
variable, else by prompting the user"""
|
| 390 |
+
token = os.environ.get("HF_TOKEN")
|
| 391 |
+
huggingface_hub.login(token=token)
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
|
| 395 |
+
chain_boundaries = [0]
|
| 396 |
+
for i, aa in enumerate(sequence):
|
| 397 |
+
if aa == CHAIN_BREAK_STR:
|
| 398 |
+
if i == (len(sequence) - 1):
|
| 399 |
+
raise ValueError(
|
| 400 |
+
"Encountered chain break token at end of sequence, this is unexpected."
|
| 401 |
+
)
|
| 402 |
+
if i == (len(sequence) - 2):
|
| 403 |
+
warn(
|
| 404 |
+
"Encountered chain break token at penultimate position, this is unexpected."
|
| 405 |
+
)
|
| 406 |
+
chain_boundaries.append(i)
|
| 407 |
+
chain_boundaries.append(i + 1)
|
| 408 |
+
chain_boundaries.append(len(sequence))
|
| 409 |
+
assert len(chain_boundaries) % 2 == 0
|
| 410 |
+
chain_boundaries = np.array(chain_boundaries).reshape(-1, 2)
|
| 411 |
+
return chain_boundaries
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def deserialize_tensors(b: bytes) -> Any:
|
| 415 |
+
buf = BytesIO(zstd.ZSTD_uncompress(b))
|
| 416 |
+
d = torch.load(buf, map_location="cpu", weights_only=False)
|
| 417 |
+
return d
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def join_lists(
|
| 421 |
+
lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None
|
| 422 |
+
) -> list[Any]:
|
| 423 |
+
"""Joins multiple lists with separator element. Like str.join but for lists.
|
| 424 |
+
|
| 425 |
+
Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4]
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
lists: Lists of elements to chain
|
| 429 |
+
separator: separators to intsert between chained output.
|
| 430 |
+
Returns:
|
| 431 |
+
Joined lists.
|
| 432 |
+
"""
|
| 433 |
+
if not lists:
|
| 434 |
+
return []
|
| 435 |
+
joined = []
|
| 436 |
+
joined.extend(lists[0])
|
| 437 |
+
for l in lists[1:]:
|
| 438 |
+
if separator:
|
| 439 |
+
joined.extend(separator)
|
| 440 |
+
joined.extend(l)
|
| 441 |
+
return joined
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def iterate_with_intermediate(
|
| 445 |
+
lists: Iterable, intermediate
|
| 446 |
+
) -> Generator[Any, None, None]:
|
| 447 |
+
"""
|
| 448 |
+
Iterate over the iterable, yielding the intermediate value between
|
| 449 |
+
every element of the intermediate. Useful for joining objects with
|
| 450 |
+
separator tokens.
|
| 451 |
+
"""
|
| 452 |
+
it = iter(lists)
|
| 453 |
+
yield next(it)
|
| 454 |
+
for l in it:
|
| 455 |
+
yield intermediate
|
| 456 |
+
yield l
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def concat_objects(objs: Sequence[Any], separator: Any | None = None):
|
| 460 |
+
"""
|
| 461 |
+
Concat objects with each other using a separator token.
|
| 462 |
+
|
| 463 |
+
Supports:
|
| 464 |
+
- Concatable (objects that implement `concat` classmethod)
|
| 465 |
+
- strings
|
| 466 |
+
- lists
|
| 467 |
+
- numpy arrays
|
| 468 |
+
- torch Tensors
|
| 469 |
+
|
| 470 |
+
Example:
|
| 471 |
+
>>> foo = "abc"
|
| 472 |
+
>>> bar = "def"
|
| 473 |
+
>>> concat_objects([foo, bar], "|")
|
| 474 |
+
"abc|def"
|
| 475 |
+
"""
|
| 476 |
+
match objs[0]:
|
| 477 |
+
case Concatable():
|
| 478 |
+
return objs[0].__class__.concat(objs) # type: ignore
|
| 479 |
+
case str():
|
| 480 |
+
assert isinstance(
|
| 481 |
+
separator, str
|
| 482 |
+
), "Trying to join strings but separator is not a string"
|
| 483 |
+
return separator.join(objs)
|
| 484 |
+
case list():
|
| 485 |
+
if separator is not None:
|
| 486 |
+
return join_lists(objs, [separator])
|
| 487 |
+
else:
|
| 488 |
+
return join_lists(objs)
|
| 489 |
+
case np.ndarray():
|
| 490 |
+
if separator is not None:
|
| 491 |
+
return np.concatenate(
|
| 492 |
+
list(iterate_with_intermediate(objs, np.array([separator])))
|
| 493 |
+
)
|
| 494 |
+
else:
|
| 495 |
+
return np.concatenate(objs)
|
| 496 |
+
case torch.Tensor():
|
| 497 |
+
if separator is not None:
|
| 498 |
+
return torch.cat(
|
| 499 |
+
list(iterate_with_intermediate(objs, torch.tensor([separator])))
|
| 500 |
+
)
|
| 501 |
+
else:
|
| 502 |
+
return torch.cat(objs) # type: ignore
|
| 503 |
+
case _:
|
| 504 |
+
raise TypeError(type(objs[0]))
|
| 505 |
+
|
esmfold2_mmcif_parsing.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import functools
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import biotite.structure as bs
|
| 11 |
+
import biotite.structure.io.pdbx as pdbx
|
| 12 |
+
|
| 13 |
+
from . import esmfold2_residue_constants
|
| 14 |
+
|
| 15 |
+
# Define PathOrBuffer for the opensource version
|
| 16 |
+
PathOrBuffer = Union[str, os.PathLike, io.StringIO]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class NoProteinError(Exception):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class Residue:
|
| 25 |
+
residue_number: int | None = None
|
| 26 |
+
insertion_code: str = ""
|
| 27 |
+
hetflag: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class MmcifHeader:
|
| 32 |
+
release_date: datetime | None = None
|
| 33 |
+
resolution: float | None = None
|
| 34 |
+
structure_method: str = "UNKNOWN"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MmcifWrapper:
|
| 38 |
+
def __init__(self, id: str | None = None):
|
| 39 |
+
self.id: str = id or ""
|
| 40 |
+
self.raw: pdbx.CIFFile | None = None
|
| 41 |
+
self.structure: bs.AtomArray
|
| 42 |
+
self.header: MmcifHeader = MmcifHeader()
|
| 43 |
+
self.entities: dict[int, list[str]] = {}
|
| 44 |
+
self.chain_to_seqres: dict[str, str] = {}
|
| 45 |
+
self.seqres_to_structure: dict[str, dict[int, Residue]] = {}
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def read(cls, path: PathOrBuffer, id: str | None = None) -> MmcifWrapper:
|
| 49 |
+
obj = cls(id=id)
|
| 50 |
+
obj._load(path)
|
| 51 |
+
return obj
|
| 52 |
+
|
| 53 |
+
def _load(self, path: PathOrBuffer, fileid: str | None = None):
|
| 54 |
+
"""Load mmCIF data from file."""
|
| 55 |
+
self.raw = pdbx.CIFFile.read(path)
|
| 56 |
+
|
| 57 |
+
self._parse_structure()
|
| 58 |
+
self._parse_header()
|
| 59 |
+
self._parse_entities()
|
| 60 |
+
self._parse_sequences()
|
| 61 |
+
|
| 62 |
+
def _parse_structure(self):
|
| 63 |
+
"""Parse the atomic structure from mmCIF."""
|
| 64 |
+
try:
|
| 65 |
+
structure = pdbx.get_structure(self.raw, model=1)
|
| 66 |
+
if structure is None or not isinstance(structure, bs.AtomArray):
|
| 67 |
+
raise NoProteinError("No structure found in mmCIF file")
|
| 68 |
+
if len(structure) == 0:
|
| 69 |
+
raise NoProteinError("Empty structure in mmCIF file")
|
| 70 |
+
self.structure = structure
|
| 71 |
+
except Exception as e:
|
| 72 |
+
raise ValueError(f"Failed to parse structure: {e}")
|
| 73 |
+
|
| 74 |
+
def _parse_header(self):
|
| 75 |
+
"""Parse header information from mmCIF."""
|
| 76 |
+
if not self.raw:
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
# Get the first (and usually only) block
|
| 81 |
+
block = self.raw.block
|
| 82 |
+
|
| 83 |
+
# Parse release date
|
| 84 |
+
if "pdbx_database_status" in block:
|
| 85 |
+
status_cat = block["pdbx_database_status"]
|
| 86 |
+
if "recvd_initial_deposition_date" in status_cat:
|
| 87 |
+
date_str = status_cat["recvd_initial_deposition_date"].as_item()
|
| 88 |
+
if date_str and date_str != "?":
|
| 89 |
+
try:
|
| 90 |
+
self.header.release_date = datetime.strptime(
|
| 91 |
+
date_str, "%Y-%m-%d"
|
| 92 |
+
)
|
| 93 |
+
except ValueError:
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
# Parse resolution
|
| 97 |
+
if "refine" in block:
|
| 98 |
+
refine_cat = block["refine"]
|
| 99 |
+
if "ls_d_res_high" in refine_cat:
|
| 100 |
+
res_str = refine_cat["ls_d_res_high"].as_item()
|
| 101 |
+
if res_str and res_str != "?":
|
| 102 |
+
try:
|
| 103 |
+
self.header.resolution = float(res_str)
|
| 104 |
+
except ValueError:
|
| 105 |
+
pass
|
| 106 |
+
|
| 107 |
+
# Parse structure method
|
| 108 |
+
if "exptl" in block:
|
| 109 |
+
exptl_cat = block["exptl"]
|
| 110 |
+
if "method" in exptl_cat:
|
| 111 |
+
method = exptl_cat["method"].as_item()
|
| 112 |
+
if method and method != "?":
|
| 113 |
+
self.header.structure_method = method.upper()
|
| 114 |
+
|
| 115 |
+
except Exception:
|
| 116 |
+
# If parsing fails, keep default values
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
def _parse_entities(self):
|
| 120 |
+
"""Parse entity information and map to chains."""
|
| 121 |
+
if not self.raw:
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
block = self.raw.block
|
| 126 |
+
|
| 127 |
+
# Parse entity information
|
| 128 |
+
if "entity" in block:
|
| 129 |
+
entity_cat = block["entity"]
|
| 130 |
+
entity_ids = entity_cat["id"].as_array(str)
|
| 131 |
+
entity_types = entity_cat["type"].as_array(str)
|
| 132 |
+
|
| 133 |
+
# Initialize entities dict with all entities (not just polymers)
|
| 134 |
+
for i, (entity_id, entity_type) in enumerate(
|
| 135 |
+
zip(entity_ids, entity_types)
|
| 136 |
+
):
|
| 137 |
+
self.entities[int(entity_id)] = []
|
| 138 |
+
|
| 139 |
+
# Map polymer chains to entities using entity_poly
|
| 140 |
+
if "entity_poly" in block:
|
| 141 |
+
poly_cat = block["entity_poly"]
|
| 142 |
+
entity_ids = poly_cat["entity_id"].as_array(str)
|
| 143 |
+
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
|
| 144 |
+
|
| 145 |
+
for entity_id, chain_list in zip(entity_ids, chain_lists):
|
| 146 |
+
entity_id = int(entity_id)
|
| 147 |
+
# Chain list is comma-separated
|
| 148 |
+
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
|
| 149 |
+
if entity_id in self.entities:
|
| 150 |
+
self.entities[entity_id] = chains
|
| 151 |
+
|
| 152 |
+
# Map non-polymer chains using struct_asym for entities not covered by entity_poly
|
| 153 |
+
if "struct_asym" in block:
|
| 154 |
+
asym_cat = block["struct_asym"]
|
| 155 |
+
asym_ids = asym_cat["id"].as_array(str)
|
| 156 |
+
entity_ids = asym_cat["entity_id"].as_array(str)
|
| 157 |
+
|
| 158 |
+
for asym_id, entity_id in zip(asym_ids, entity_ids):
|
| 159 |
+
entity_id = int(entity_id)
|
| 160 |
+
# Only add if entity exists but has no chains yet (non-polymer entities)
|
| 161 |
+
if entity_id in self.entities and not self.entities[entity_id]:
|
| 162 |
+
self.entities[entity_id].append(asym_id)
|
| 163 |
+
|
| 164 |
+
except Exception:
|
| 165 |
+
# If parsing fails, try to infer from structure
|
| 166 |
+
if (
|
| 167 |
+
self.structure
|
| 168 |
+
and hasattr(self.structure, "chain_id")
|
| 169 |
+
and self.structure.chain_id is not None
|
| 170 |
+
and hasattr(self.structure.chain_id, "__iter__")
|
| 171 |
+
):
|
| 172 |
+
chain_ids = list(set(self.structure.chain_id))
|
| 173 |
+
self.entities = {1: chain_ids}
|
| 174 |
+
|
| 175 |
+
def _parse_sequences(self):
|
| 176 |
+
"""Parse sequence information from mmCIF."""
|
| 177 |
+
if not self.raw:
|
| 178 |
+
return
|
| 179 |
+
|
| 180 |
+
block = self.raw.block
|
| 181 |
+
|
| 182 |
+
# Parse polymer sequences
|
| 183 |
+
if "entity_poly" in block:
|
| 184 |
+
poly_cat = block["entity_poly"]
|
| 185 |
+
entity_ids = poly_cat["entity_id"].as_array(str)
|
| 186 |
+
sequences = poly_cat["pdbx_seq_one_letter_code_can"].as_array(str)
|
| 187 |
+
chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
|
| 188 |
+
|
| 189 |
+
for entity_id, sequence, chain_list in zip(
|
| 190 |
+
entity_ids, sequences, chain_lists
|
| 191 |
+
):
|
| 192 |
+
# Clean up sequence (remove whitespace and newlines)
|
| 193 |
+
clean_seq = "".join(sequence.split())
|
| 194 |
+
chains = [c.strip() for c in chain_list.split(",") if c.strip()]
|
| 195 |
+
|
| 196 |
+
for chain_id in chains:
|
| 197 |
+
self.chain_to_seqres[chain_id] = clean_seq
|
| 198 |
+
|
| 199 |
+
# Parse sequence to structure mapping
|
| 200 |
+
if "pdbx_poly_seq_scheme" in block:
|
| 201 |
+
seq_cat = block["pdbx_poly_seq_scheme"]
|
| 202 |
+
asym_ids = seq_cat["asym_id"].as_array(str) # Internal chain IDs
|
| 203 |
+
seq_positions = seq_cat["seq_id"].as_array(str)
|
| 204 |
+
auth_seq_nums = seq_cat["auth_seq_num"].as_array(str)
|
| 205 |
+
ins_codes = (
|
| 206 |
+
seq_cat["pdb_ins_code"].as_array(str)
|
| 207 |
+
if "pdb_ins_code" in seq_cat
|
| 208 |
+
else [""] * len(asym_ids)
|
| 209 |
+
)
|
| 210 |
+
hetflags = (
|
| 211 |
+
seq_cat["hetflag"].as_array(str)
|
| 212 |
+
if "hetflag" in seq_cat
|
| 213 |
+
else ["N"] * len(asym_ids)
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Get author chain IDs if available
|
| 217 |
+
auth_chain_ids = (
|
| 218 |
+
seq_cat["pdb_strand_id"].as_array(str)
|
| 219 |
+
if "pdb_strand_id" in seq_cat
|
| 220 |
+
else asym_ids # Fallback to internal IDs
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Build mapping from internal chain ID to author chain ID
|
| 224 |
+
asym_to_auth_mapping = {}
|
| 225 |
+
for asym_id, auth_id in zip(asym_ids, auth_chain_ids):
|
| 226 |
+
asym_to_auth_mapping[asym_id] = auth_id
|
| 227 |
+
|
| 228 |
+
# Group by internal chain ID first, then map to author chain ID
|
| 229 |
+
chain_data = {}
|
| 230 |
+
for asym_id, seq_pos, auth_seq, ins_code, hetflag in zip(
|
| 231 |
+
asym_ids, seq_positions, auth_seq_nums, ins_codes, hetflags
|
| 232 |
+
):
|
| 233 |
+
if asym_id not in chain_data:
|
| 234 |
+
chain_data[asym_id] = {}
|
| 235 |
+
|
| 236 |
+
try:
|
| 237 |
+
seq_index = int(seq_pos) - 1 # Convert to 0-based indexing
|
| 238 |
+
res_num = int(auth_seq) if auth_seq != "?" else None
|
| 239 |
+
except ValueError:
|
| 240 |
+
continue
|
| 241 |
+
|
| 242 |
+
if res_num is not None:
|
| 243 |
+
# Convert mmCIF "." and "?" to empty string
|
| 244 |
+
clean_ins_code = "" if ins_code in [".", "?"] else ins_code
|
| 245 |
+
else:
|
| 246 |
+
clean_ins_code = ""
|
| 247 |
+
res_num = None
|
| 248 |
+
|
| 249 |
+
is_het = hetflag.upper() == "Y" # type: ignore
|
| 250 |
+
chain_data[asym_id][seq_index] = Residue(
|
| 251 |
+
residue_number=res_num,
|
| 252 |
+
insertion_code=clean_ins_code, # type: ignore
|
| 253 |
+
hetflag=is_het,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
# Handle cases where multiple residues have the same auth_seq_num
|
| 257 |
+
# by adjusting residue numbers to be unique within each chain
|
| 258 |
+
for asym_id, residue_data in chain_data.items():
|
| 259 |
+
# Check if there are duplicate residue numbers in this chain
|
| 260 |
+
positions_with_same_num = {}
|
| 261 |
+
for seq_idx, res_at_pos in residue_data.items():
|
| 262 |
+
if res_at_pos.residue_number is not None:
|
| 263 |
+
res_num = res_at_pos.residue_number
|
| 264 |
+
if res_num not in positions_with_same_num:
|
| 265 |
+
positions_with_same_num[res_num] = []
|
| 266 |
+
positions_with_same_num[res_num].append(seq_idx)
|
| 267 |
+
|
| 268 |
+
# Fix duplicate residue numbers by making them sequential
|
| 269 |
+
for res_num, seq_indices in positions_with_same_num.items():
|
| 270 |
+
if len(seq_indices) > 1:
|
| 271 |
+
# Multiple residues have the same residue number
|
| 272 |
+
# Make them sequential starting from the original number
|
| 273 |
+
seq_indices.sort() # Ensure consistent ordering
|
| 274 |
+
for i, seq_idx in enumerate(seq_indices):
|
| 275 |
+
original_pos = residue_data[seq_idx]
|
| 276 |
+
new_pos = Residue(
|
| 277 |
+
residue_number=res_num + i,
|
| 278 |
+
insertion_code=original_pos.insertion_code,
|
| 279 |
+
hetflag=original_pos.hetflag,
|
| 280 |
+
)
|
| 281 |
+
residue_data[seq_idx] = new_pos
|
| 282 |
+
|
| 283 |
+
# Create ordered mappings using author chain IDs
|
| 284 |
+
for asym_id in chain_data:
|
| 285 |
+
auth_chain_id = asym_to_auth_mapping.get(asym_id, asym_id)
|
| 286 |
+
if auth_chain_id in self.chain_to_seqres:
|
| 287 |
+
seq_len = len(self.chain_to_seqres[auth_chain_id])
|
| 288 |
+
ordered_mapping = {}
|
| 289 |
+
|
| 290 |
+
for i in range(seq_len):
|
| 291 |
+
if i in chain_data[asym_id]:
|
| 292 |
+
ordered_mapping[i] = chain_data[asym_id][i]
|
| 293 |
+
else:
|
| 294 |
+
# Missing residue - no structure coordinates
|
| 295 |
+
ordered_mapping[i] = Residue(
|
| 296 |
+
residue_number=None, insertion_code="", hetflag=False
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
self.seqres_to_structure[auth_chain_id] = ordered_mapping
|
| 300 |
+
else:
|
| 301 |
+
# Handle case where auth_chain_id is not in chain_to_seqres
|
| 302 |
+
# This can happen if the chain is not a polymer or if there's a parsing issue
|
| 303 |
+
# Create a basic mapping based on the chain_data
|
| 304 |
+
if chain_data[asym_id]:
|
| 305 |
+
# Sort by sequence index to create ordered mapping
|
| 306 |
+
sorted_indices = sorted(chain_data[asym_id].keys())
|
| 307 |
+
ordered_mapping = {}
|
| 308 |
+
for i, seq_idx in enumerate(sorted_indices):
|
| 309 |
+
ordered_mapping[i] = chain_data[asym_id][seq_idx]
|
| 310 |
+
self.seqres_to_structure[auth_chain_id] = ordered_mapping
|
| 311 |
+
|
| 312 |
+
# Ensure all chains have complete mappings
|
| 313 |
+
for chain_id in self.chain_to_seqres:
|
| 314 |
+
if chain_id not in self.seqres_to_structure:
|
| 315 |
+
seq_len = len(self.chain_to_seqres[chain_id])
|
| 316 |
+
self.seqres_to_structure[chain_id] = {
|
| 317 |
+
i: Residue(residue_number=None, insertion_code="", hetflag=False)
|
| 318 |
+
for i in range(seq_len)
|
| 319 |
+
}
|
| 320 |
+
else:
|
| 321 |
+
# Fill in any missing indices
|
| 322 |
+
seq_len = len(self.chain_to_seqres[chain_id])
|
| 323 |
+
mapping = self.seqres_to_structure[chain_id]
|
| 324 |
+
for i in range(seq_len):
|
| 325 |
+
if i not in mapping:
|
| 326 |
+
mapping[i] = Residue(
|
| 327 |
+
residue_number=None, insertion_code="", hetflag=False
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Fallback: create basic mappings from structure for missing chains
|
| 331 |
+
if (
|
| 332 |
+
self.structure
|
| 333 |
+
and hasattr(self.structure, "chain_id")
|
| 334 |
+
and self.structure.chain_id is not None
|
| 335 |
+
and hasattr(self.structure.chain_id, "__iter__")
|
| 336 |
+
):
|
| 337 |
+
for chain_id in set(self.structure.chain_id):
|
| 338 |
+
if chain_id not in self.seqres_to_structure:
|
| 339 |
+
chain_structure = self.structure[
|
| 340 |
+
self.structure.chain_id == chain_id
|
| 341 |
+
]
|
| 342 |
+
if (
|
| 343 |
+
hasattr(chain_structure, "res_id")
|
| 344 |
+
and chain_structure.res_id is not None
|
| 345 |
+
and hasattr(chain_structure.res_id, "__iter__")
|
| 346 |
+
):
|
| 347 |
+
residue_ids = list(set(chain_structure.res_id))
|
| 348 |
+
residue_ids.sort()
|
| 349 |
+
|
| 350 |
+
self.seqres_to_structure[chain_id] = {
|
| 351 |
+
i: Residue(
|
| 352 |
+
residue_number=res_id, insertion_code="", hetflag=False
|
| 353 |
+
)
|
| 354 |
+
for i, res_id in enumerate(residue_ids)
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
def _parse_nonpoly_from_mmcif(self) -> dict[tuple, bs.AtomArray]:
|
| 358 |
+
"""Parse non-polymer coordinates from mmCIF block data."""
|
| 359 |
+
nonpoly_coords = {}
|
| 360 |
+
|
| 361 |
+
# Get non-polymer entities from the mmCIF block
|
| 362 |
+
assert self.raw is not None
|
| 363 |
+
block = self.raw.block
|
| 364 |
+
nonpoly_entities = set()
|
| 365 |
+
|
| 366 |
+
# Find non-polymer entities
|
| 367 |
+
if "entity" in block:
|
| 368 |
+
entity_cat = block["entity"]
|
| 369 |
+
entity_ids = entity_cat["id"].as_array(str)
|
| 370 |
+
entity_types = entity_cat["type"].as_array(str)
|
| 371 |
+
|
| 372 |
+
for entity_id, entity_type in zip(entity_ids, entity_types):
|
| 373 |
+
if entity_type.upper() in ["NON-POLYMER", "WATER", "BRANCHED"]:
|
| 374 |
+
nonpoly_entities.add(entity_id)
|
| 375 |
+
|
| 376 |
+
# Map entities to chains for non-polymers
|
| 377 |
+
entity_to_chains = {}
|
| 378 |
+
if "pdbx_entity_nonpoly" in block:
|
| 379 |
+
nonpoly_cat = block["pdbx_entity_nonpoly"]
|
| 380 |
+
entity_ids = nonpoly_cat["entity_id"].as_array(str)
|
| 381 |
+
comp_ids = nonpoly_cat["comp_id"].as_array(str)
|
| 382 |
+
|
| 383 |
+
for entity_id, comp_id in zip(entity_ids, comp_ids):
|
| 384 |
+
if entity_id in nonpoly_entities:
|
| 385 |
+
entity_to_chains[entity_id] = comp_id
|
| 386 |
+
|
| 387 |
+
# Get atom site information for non-polymers
|
| 388 |
+
if "atom_site" in block:
|
| 389 |
+
atom_cat = block["atom_site"]
|
| 390 |
+
atom_chain_ids = atom_cat["label_asym_id"].as_array(str)
|
| 391 |
+
atom_entity_ids = atom_cat["label_entity_id"].as_array(str)
|
| 392 |
+
atom_comp_ids = atom_cat["label_comp_id"].as_array(str)
|
| 393 |
+
|
| 394 |
+
# Group non-polymer atoms by entity and chain
|
| 395 |
+
nonpoly_atom_groups = {}
|
| 396 |
+
for i, (chain_id, entity_id, comp_id) in enumerate(
|
| 397 |
+
zip(atom_chain_ids, atom_entity_ids, atom_comp_ids)
|
| 398 |
+
):
|
| 399 |
+
if entity_id in nonpoly_entities:
|
| 400 |
+
key = (comp_id, chain_id)
|
| 401 |
+
if key not in nonpoly_atom_groups:
|
| 402 |
+
nonpoly_atom_groups[key] = []
|
| 403 |
+
nonpoly_atom_groups[key].append(i)
|
| 404 |
+
|
| 405 |
+
# Extract coordinates for each non-polymer group
|
| 406 |
+
for (comp_id, chain_id), atom_indices in nonpoly_atom_groups.items():
|
| 407 |
+
# Match atoms by comparing chain_id and residue name
|
| 408 |
+
structure_mask = (self.structure.chain_id == chain_id) & (
|
| 409 |
+
self.structure.res_name == comp_id
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
if structure_mask.any():
|
| 413 |
+
nonpoly_array = self.structure[structure_mask]
|
| 414 |
+
if (
|
| 415 |
+
isinstance(nonpoly_array, (bs.AtomArray, bs.AtomArrayStack))
|
| 416 |
+
and len(nonpoly_array) > 0
|
| 417 |
+
):
|
| 418 |
+
nonpoly_coords[(comp_id, chain_id)] = nonpoly_array
|
| 419 |
+
|
| 420 |
+
return nonpoly_coords
|
| 421 |
+
|
| 422 |
+
def _parse_nonpoly_fallback(self) -> dict[tuple, bs.AtomArray]:
|
| 423 |
+
"""Fallback method to extract heteroatoms directly from structure."""
|
| 424 |
+
nonpoly_coords = {}
|
| 425 |
+
|
| 426 |
+
if not (self.structure and hasattr(self.structure, "chain_id")):
|
| 427 |
+
return nonpoly_coords
|
| 428 |
+
|
| 429 |
+
# Create set of standard residues from residue_constants
|
| 430 |
+
standard_residues = set(residue_constants.resnames[:-1]) # Exclude 'UNK'
|
| 431 |
+
standard_residues.update({"A", "C", "G", "T", "U"}) # Add nucleic acids
|
| 432 |
+
|
| 433 |
+
if hasattr(self.structure, "chain_id") and self.structure.chain_id is not None:
|
| 434 |
+
for chain_id in set(self.structure.chain_id):
|
| 435 |
+
chain_structure = self.structure[self.structure.chain_id == chain_id]
|
| 436 |
+
|
| 437 |
+
# Find non-standard residues
|
| 438 |
+
if (
|
| 439 |
+
hasattr(chain_structure, "res_name")
|
| 440 |
+
and chain_structure.res_name is not None
|
| 441 |
+
and hasattr(chain_structure.res_name, "__iter__")
|
| 442 |
+
):
|
| 443 |
+
for res_name in set(chain_structure.res_name):
|
| 444 |
+
if res_name not in standard_residues:
|
| 445 |
+
res_mask = (chain_structure.chain_id == chain_id) & (
|
| 446 |
+
chain_structure.res_name == res_name
|
| 447 |
+
)
|
| 448 |
+
if res_mask.any() and isinstance(
|
| 449 |
+
chain_structure, (bs.AtomArray, bs.AtomArrayStack)
|
| 450 |
+
):
|
| 451 |
+
nonpoly_array = chain_structure[res_mask]
|
| 452 |
+
nonpoly_coords[(res_name, chain_id)] = nonpoly_array
|
| 453 |
+
|
| 454 |
+
return nonpoly_coords
|
| 455 |
+
|
| 456 |
+
@functools.cached_property
|
| 457 |
+
def non_polymer_coords(self) -> dict[tuple, bs.AtomArray]:
|
| 458 |
+
"""
|
| 459 |
+
Extract non-polymer coordinates (ligands, cofactors, etc.) from mmCIF structure.
|
| 460 |
+
|
| 461 |
+
Returns a dictionary mapping (nonpolymer_info, chain_id) tuples to AtomArrays.
|
| 462 |
+
"""
|
| 463 |
+
if not self.structure or not self.raw:
|
| 464 |
+
return {}
|
| 465 |
+
|
| 466 |
+
try:
|
| 467 |
+
return self._parse_nonpoly_from_mmcif()
|
| 468 |
+
except Exception:
|
| 469 |
+
return self._parse_nonpoly_fallback()
|
| 470 |
+
|
esmfold2_molecular_complex.py
ADDED
|
@@ -0,0 +1,1226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from dataclasses import asdict, dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from subprocess import check_output
|
| 9 |
+
from tempfile import TemporaryDirectory
|
| 10 |
+
from typing import TYPE_CHECKING, Any
|
| 11 |
+
|
| 12 |
+
import biotite.structure as bs
|
| 13 |
+
import biotite.structure.io.pdbx as pdbx
|
| 14 |
+
import brotli
|
| 15 |
+
import msgpack
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from biotite.structure.io.pdbx import (
|
| 19 |
+
CIFCategory,
|
| 20 |
+
CIFColumn,
|
| 21 |
+
CIFData,
|
| 22 |
+
CIFFile,
|
| 23 |
+
set_structure,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from . import esmfold2_residue_constants
|
| 27 |
+
from .esmfold2_metrics import compute_lddt, compute_rmsd
|
| 28 |
+
from .esmfold2_protein_complex import ProteinComplex, ProteinComplexMetadata
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class MolecularComplexResult:
|
| 33 |
+
"""Result of molecular complex folding"""
|
| 34 |
+
|
| 35 |
+
complex: MolecularComplex
|
| 36 |
+
plddt: torch.Tensor | None = None
|
| 37 |
+
ptm: float | None = None
|
| 38 |
+
iptm: float | None = None
|
| 39 |
+
pae: torch.Tensor | None = None
|
| 40 |
+
distogram: torch.Tensor | None = None
|
| 41 |
+
pair_chains_iptm: torch.Tensor | None = None
|
| 42 |
+
output_embedding_sequence: torch.Tensor | None = None
|
| 43 |
+
output_embedding_pair_pooled: torch.Tensor | None = None
|
| 44 |
+
residue_index: torch.Tensor | None = None
|
| 45 |
+
entity_id: torch.Tensor | None = None
|
| 46 |
+
sae_features: np.ndarray | None = None # [L, n_features]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass
|
| 50 |
+
class MolecularComplexMetadata:
|
| 51 |
+
"""Metadata for MolecularComplex objects."""
|
| 52 |
+
|
| 53 |
+
entity_lookup: dict[int, str]
|
| 54 |
+
chain_lookup: dict[int, str]
|
| 55 |
+
assembly_composition: dict[str, list[str]] | None = None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@dataclass
|
| 59 |
+
class Molecule:
|
| 60 |
+
"""Represents a single molecule/token within a MolecularComplex."""
|
| 61 |
+
|
| 62 |
+
token: str
|
| 63 |
+
token_idx: int
|
| 64 |
+
atom_positions: np.ndarray # [N_atoms, 3]
|
| 65 |
+
atom_elements: np.ndarray # [N_atoms] element strings
|
| 66 |
+
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
|
| 67 |
+
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
|
| 68 |
+
residue_type: int = 0
|
| 69 |
+
molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
|
| 70 |
+
confidence: float = 0.0
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@dataclass(frozen=True)
|
| 74 |
+
class MolecularComplex:
|
| 75 |
+
"""
|
| 76 |
+
Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
|
| 77 |
+
|
| 78 |
+
Uses a flat atom representation with token-based sequence indexing, supporting all atom types
|
| 79 |
+
beyond the traditional atom37 protein representation.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
id: str
|
| 83 |
+
sequence: list[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
|
| 84 |
+
|
| 85 |
+
# Flat atom arrays - simplified representation
|
| 86 |
+
atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates
|
| 87 |
+
atom_elements: np.ndarray # [N_atoms] element strings
|
| 88 |
+
|
| 89 |
+
# Token-to-atom mapping for efficient access
|
| 90 |
+
token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
|
| 91 |
+
|
| 92 |
+
# Chain information
|
| 93 |
+
chain_id: np.ndarray # [N_tokens] chain identifier for each token
|
| 94 |
+
|
| 95 |
+
# Confidence data
|
| 96 |
+
plddt: np.ndarray # Per-token confidence scores [N_tokens]
|
| 97 |
+
|
| 98 |
+
# Metadata
|
| 99 |
+
metadata: MolecularComplexMetadata
|
| 100 |
+
|
| 101 |
+
# Optional atom names and hetero flags (preserved from original structures)
|
| 102 |
+
atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
|
| 103 |
+
atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
|
| 104 |
+
|
| 105 |
+
def __post_init__(self):
|
| 106 |
+
"""Validate array dimensions."""
|
| 107 |
+
n_tokens = len(self.sequence)
|
| 108 |
+
n_atoms = len(self.atom_positions)
|
| 109 |
+
assert (
|
| 110 |
+
self.token_to_atoms.shape[0] == n_tokens
|
| 111 |
+
), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
|
| 112 |
+
assert (
|
| 113 |
+
self.chain_id.shape[0] == n_tokens
|
| 114 |
+
), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens"
|
| 115 |
+
assert (
|
| 116 |
+
self.plddt.shape[0] == n_tokens
|
| 117 |
+
), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
|
| 118 |
+
if self.atom_names is not None:
|
| 119 |
+
assert (
|
| 120 |
+
self.atom_names.shape[0] == n_atoms
|
| 121 |
+
), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms"
|
| 122 |
+
if self.atom_hetero is not None:
|
| 123 |
+
assert (
|
| 124 |
+
self.atom_hetero.shape[0] == n_atoms
|
| 125 |
+
), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms"
|
| 126 |
+
|
| 127 |
+
def __len__(self) -> int:
|
| 128 |
+
"""Return number of tokens."""
|
| 129 |
+
return len(self.sequence)
|
| 130 |
+
|
| 131 |
+
def __getitem__(self, idx: int) -> Molecule:
|
| 132 |
+
"""Access individual molecules/tokens by index."""
|
| 133 |
+
if idx >= len(self.sequence) or idx < 0:
|
| 134 |
+
raise IndexError(
|
| 135 |
+
f"Token index {idx} out of range for {len(self.sequence)} tokens"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
token = self.sequence[idx]
|
| 139 |
+
start_atom, end_atom = self.token_to_atoms[idx]
|
| 140 |
+
|
| 141 |
+
# Extract atom data for this token
|
| 142 |
+
token_atom_positions = self.atom_positions[start_atom:end_atom]
|
| 143 |
+
token_atom_elements = self.atom_elements[start_atom:end_atom]
|
| 144 |
+
token_atom_names = None
|
| 145 |
+
if self.atom_names is not None:
|
| 146 |
+
token_atom_names = self.atom_names[start_atom:end_atom]
|
| 147 |
+
token_atom_hetero = None
|
| 148 |
+
if self.atom_hetero is not None:
|
| 149 |
+
token_atom_hetero = self.atom_hetero[start_atom:end_atom]
|
| 150 |
+
|
| 151 |
+
# Default values for residue/molecule type (would be extended based on actual implementation)
|
| 152 |
+
residue_type = 0 # Default to standard residue
|
| 153 |
+
molecule_type = 0 # Default to protein
|
| 154 |
+
|
| 155 |
+
return Molecule(
|
| 156 |
+
token=token,
|
| 157 |
+
token_idx=idx,
|
| 158 |
+
atom_positions=token_atom_positions,
|
| 159 |
+
atom_elements=token_atom_elements,
|
| 160 |
+
atom_names=token_atom_names,
|
| 161 |
+
atom_hetero=token_atom_hetero,
|
| 162 |
+
residue_type=residue_type,
|
| 163 |
+
molecule_type=molecule_type,
|
| 164 |
+
confidence=self.plddt[idx],
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
@property
|
| 168 |
+
def atom_coordinates(self) -> np.ndarray:
|
| 169 |
+
"""Get flat array of all atom coordinates [N_atoms, 3]."""
|
| 170 |
+
return self.atom_positions
|
| 171 |
+
|
| 172 |
+
# Conversion methods
|
| 173 |
+
@classmethod
|
| 174 |
+
def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
|
| 175 |
+
"""Convert a ProteinComplex to MolecularComplex.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
pc: ProteinComplex object with atom37 representation
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
MolecularComplex with flat atom arrays and token-based indexing
|
| 182 |
+
"""
|
| 183 |
+
from . import esmfold2_residue_constants
|
| 184 |
+
|
| 185 |
+
# Extract sequence without chain breaks
|
| 186 |
+
sequence_no_breaks = pc.sequence.replace("|", "")
|
| 187 |
+
sequence_tokens = [
|
| 188 |
+
residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks
|
| 189 |
+
]
|
| 190 |
+
|
| 191 |
+
# Convert atom37 to flat arrays
|
| 192 |
+
flat_positions = []
|
| 193 |
+
flat_elements = []
|
| 194 |
+
flat_names = []
|
| 195 |
+
flat_hetero = []
|
| 196 |
+
token_to_atoms = []
|
| 197 |
+
|
| 198 |
+
atom_idx = 0
|
| 199 |
+
|
| 200 |
+
for i, aa in enumerate(pc.sequence):
|
| 201 |
+
if aa == "|":
|
| 202 |
+
# Skip chain break tokens
|
| 203 |
+
continue
|
| 204 |
+
|
| 205 |
+
# Get atom37 positions and mask for this residue.
|
| 206 |
+
# ProteinComplex arrays are indexed by sequence position (including |),
|
| 207 |
+
# so use `i` not a separate residue counter.
|
| 208 |
+
res_positions = pc.atom37_positions[i] # [37, 3]
|
| 209 |
+
res_mask = pc.atom37_mask[i] # [37]
|
| 210 |
+
|
| 211 |
+
# Track start position for this token
|
| 212 |
+
token_start = atom_idx
|
| 213 |
+
|
| 214 |
+
# Process each atom type in atom37 representation
|
| 215 |
+
for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
|
| 216 |
+
if res_mask[atom_type_idx]: # Atom is present
|
| 217 |
+
# Add position
|
| 218 |
+
flat_positions.append(res_positions[atom_type_idx])
|
| 219 |
+
|
| 220 |
+
# Determine element from atom name
|
| 221 |
+
element = (
|
| 222 |
+
atom_name[0] if atom_name else "C"
|
| 223 |
+
) # First character is element
|
| 224 |
+
flat_elements.append(element)
|
| 225 |
+
|
| 226 |
+
# Add atom name
|
| 227 |
+
flat_names.append(atom_name)
|
| 228 |
+
|
| 229 |
+
# Add hetero flag (all proteins are non-hetero)
|
| 230 |
+
flat_hetero.append(False)
|
| 231 |
+
|
| 232 |
+
atom_idx += 1
|
| 233 |
+
|
| 234 |
+
# Record token-to-atom mapping [start_idx, end_idx)
|
| 235 |
+
token_to_atoms.append([token_start, atom_idx])
|
| 236 |
+
|
| 237 |
+
# Convert to numpy arrays
|
| 238 |
+
atom_positions = np.array(flat_positions, dtype=np.float32)
|
| 239 |
+
atom_elements = np.array(flat_elements, dtype=object)
|
| 240 |
+
atom_names = np.array(flat_names, dtype=object)
|
| 241 |
+
atom_hetero = np.array(flat_hetero, dtype=bool)
|
| 242 |
+
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
|
| 243 |
+
|
| 244 |
+
# Extract confidence scores and chain_ids (skip chain breaks)
|
| 245 |
+
confidence_scores = []
|
| 246 |
+
chain_ids = []
|
| 247 |
+
for seq_idx, aa in enumerate(pc.sequence):
|
| 248 |
+
if aa != "|":
|
| 249 |
+
confidence_scores.append(pc.confidence[seq_idx])
|
| 250 |
+
chain_ids.append(pc.chain_id[seq_idx])
|
| 251 |
+
|
| 252 |
+
confidence_array = np.array(confidence_scores, dtype=np.float32)
|
| 253 |
+
chain_id_array = np.array(chain_ids, dtype=np.int64)
|
| 254 |
+
|
| 255 |
+
# Create metadata - convert entity IDs to strings for MolecularComplexMetadata
|
| 256 |
+
entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
|
| 257 |
+
metadata = MolecularComplexMetadata(
|
| 258 |
+
entity_lookup=entity_lookup_str,
|
| 259 |
+
chain_lookup=pc.metadata.chain_lookup,
|
| 260 |
+
assembly_composition=pc.metadata.assembly_composition,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
return cls(
|
| 264 |
+
id=pc.id,
|
| 265 |
+
sequence=sequence_tokens,
|
| 266 |
+
atom_positions=atom_positions,
|
| 267 |
+
atom_elements=atom_elements,
|
| 268 |
+
token_to_atoms=token_to_atoms_array,
|
| 269 |
+
chain_id=chain_id_array,
|
| 270 |
+
plddt=confidence_array,
|
| 271 |
+
metadata=metadata,
|
| 272 |
+
atom_names=atom_names,
|
| 273 |
+
atom_hetero=atom_hetero,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
def to_protein_complex(self) -> ProteinComplex:
|
| 277 |
+
"""Convert MolecularComplex back to ProteinComplex format.
|
| 278 |
+
|
| 279 |
+
Extracts only protein tokens and converts from flat atom representation
|
| 280 |
+
back to atom37 format used by ProteinComplex.
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
ProteinComplex with protein residues only, excluding ligands/nucleic acids
|
| 284 |
+
"""
|
| 285 |
+
from . import esmfold2_residue_constants
|
| 286 |
+
|
| 287 |
+
# No need for element mapping - already using element characters
|
| 288 |
+
|
| 289 |
+
# Filter for protein tokens only (skip ligands, nucleic acids)
|
| 290 |
+
protein_tokens = []
|
| 291 |
+
protein_indices = []
|
| 292 |
+
|
| 293 |
+
for i, token in enumerate(self.sequence):
|
| 294 |
+
# Check if token is a standard 3-letter amino acid code
|
| 295 |
+
if token in residue_constants.restype_3to1:
|
| 296 |
+
protein_tokens.append(token)
|
| 297 |
+
protein_indices.append(i)
|
| 298 |
+
|
| 299 |
+
if not protein_tokens:
|
| 300 |
+
raise ValueError("No protein tokens found in MolecularComplex")
|
| 301 |
+
|
| 302 |
+
n_residues = len(protein_tokens)
|
| 303 |
+
|
| 304 |
+
# Initialize atom37 arrays
|
| 305 |
+
atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
|
| 306 |
+
atom37_mask = np.zeros((n_residues, 37), dtype=bool)
|
| 307 |
+
|
| 308 |
+
# Extract confidence scores and chain_ids for protein residues only
|
| 309 |
+
protein_confidence = self.plddt[protein_indices]
|
| 310 |
+
protein_chain_ids = self.chain_id[protein_indices]
|
| 311 |
+
|
| 312 |
+
# Convert tokens back to single-letter sequence with chain breaks
|
| 313 |
+
single_letter_residues = []
|
| 314 |
+
prev_chain_id = None
|
| 315 |
+
|
| 316 |
+
for i, (token, chain_id_val) in enumerate(
|
| 317 |
+
zip(protein_tokens, protein_chain_ids)
|
| 318 |
+
):
|
| 319 |
+
# Add chain break if we're switching to a new chain
|
| 320 |
+
if prev_chain_id is not None and chain_id_val != prev_chain_id:
|
| 321 |
+
single_letter_residues.append("|")
|
| 322 |
+
single_letter_residues.append(residue_constants.restype_3to1[token])
|
| 323 |
+
prev_chain_id = chain_id_val
|
| 324 |
+
|
| 325 |
+
single_letter_sequence = "".join(single_letter_residues)
|
| 326 |
+
|
| 327 |
+
# Calculate final sequence length (includes chain breaks)
|
| 328 |
+
sequence_length = len(single_letter_sequence)
|
| 329 |
+
|
| 330 |
+
# Convert flat atoms back to atom37 representation using atom names
|
| 331 |
+
for res_idx, token_idx in enumerate(protein_indices):
|
| 332 |
+
token = self.sequence[token_idx]
|
| 333 |
+
start_atom, end_atom = self.token_to_atoms[token_idx]
|
| 334 |
+
|
| 335 |
+
res_atom_positions = self.atom_positions[start_atom:end_atom]
|
| 336 |
+
res_atom_names = (
|
| 337 |
+
np.array(self.atom_names[start_atom:end_atom], dtype=str)
|
| 338 |
+
if self.atom_names is not None
|
| 339 |
+
else np.array([], dtype=str)
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Build a mapping from normalized atom name -> position for this residue
|
| 343 |
+
# Normalize to uppercase and strip whitespace for robust matching
|
| 344 |
+
name_to_pos: dict[str, np.ndarray] = {}
|
| 345 |
+
for i, nm in enumerate(res_atom_names):
|
| 346 |
+
key = nm.upper().strip()
|
| 347 |
+
# Prefer first occurrence; ignore duplicates/altlocs
|
| 348 |
+
if key not in name_to_pos:
|
| 349 |
+
name_to_pos[key] = res_atom_positions[i]
|
| 350 |
+
|
| 351 |
+
# Place atoms into atom37 by matching stored atom names to atom37 indices.
|
| 352 |
+
# This handles all atoms present in the flat representation, not just
|
| 353 |
+
# the canonical residue_atoms for this residue type. This preserves
|
| 354 |
+
# atoms that were in the original atom37_mask even if they're atypical
|
| 355 |
+
# for the residue (e.g., from alternate conformations or data quirks).
|
| 356 |
+
for atom_name_str, pos in name_to_pos.items():
|
| 357 |
+
idx37 = residue_constants.atom_order.get(atom_name_str)
|
| 358 |
+
if idx37 is not None:
|
| 359 |
+
atom37_positions[res_idx, idx37] = pos
|
| 360 |
+
atom37_mask[res_idx, idx37] = True
|
| 361 |
+
|
| 362 |
+
# Create arrays that match sequence length (including chain breaks)
|
| 363 |
+
# Initialize arrays with proper size
|
| 364 |
+
chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
|
| 365 |
+
entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
|
| 366 |
+
sym_id_expanded = np.zeros(sequence_length, dtype=np.int64)
|
| 367 |
+
residue_index_expanded = np.zeros(sequence_length, dtype=np.int64)
|
| 368 |
+
insertion_code_expanded = np.array([""] * sequence_length, dtype=object)
|
| 369 |
+
confidence_expanded = np.zeros(sequence_length, dtype=np.float32)
|
| 370 |
+
atom37_positions_expanded = np.full(
|
| 371 |
+
(sequence_length, 37, 3), np.nan, dtype=np.float32
|
| 372 |
+
)
|
| 373 |
+
atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool)
|
| 374 |
+
|
| 375 |
+
# Map residue data to sequence positions (skipping chain breaks)
|
| 376 |
+
residue_idx = 0
|
| 377 |
+
residue_counter_per_chain = {}
|
| 378 |
+
|
| 379 |
+
for seq_pos, char in enumerate(single_letter_sequence):
|
| 380 |
+
if char != "|":
|
| 381 |
+
# This is a residue position
|
| 382 |
+
chain_id_val = protein_chain_ids[residue_idx]
|
| 383 |
+
|
| 384 |
+
chain_id_expanded[seq_pos] = chain_id_val
|
| 385 |
+
entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping
|
| 386 |
+
|
| 387 |
+
# Track residue numbering per chain
|
| 388 |
+
if chain_id_val not in residue_counter_per_chain:
|
| 389 |
+
residue_counter_per_chain[chain_id_val] = 1
|
| 390 |
+
else:
|
| 391 |
+
residue_counter_per_chain[chain_id_val] += 1
|
| 392 |
+
|
| 393 |
+
residue_index_expanded[seq_pos] = residue_counter_per_chain[
|
| 394 |
+
chain_id_val
|
| 395 |
+
]
|
| 396 |
+
confidence_expanded[seq_pos] = protein_confidence[residue_idx]
|
| 397 |
+
atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx]
|
| 398 |
+
atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx]
|
| 399 |
+
|
| 400 |
+
residue_idx += 1
|
| 401 |
+
# Chain break positions keep default values (-1, False, etc.)
|
| 402 |
+
|
| 403 |
+
# Use the expanded arrays
|
| 404 |
+
chain_id = chain_id_expanded
|
| 405 |
+
entity_id = entity_id_expanded
|
| 406 |
+
sym_id = sym_id_expanded
|
| 407 |
+
residue_index = residue_index_expanded
|
| 408 |
+
insertion_code = insertion_code_expanded
|
| 409 |
+
protein_confidence = confidence_expanded
|
| 410 |
+
atom37_positions = atom37_positions_expanded
|
| 411 |
+
atom37_mask = atom37_mask_expanded
|
| 412 |
+
|
| 413 |
+
# Create protein complex metadata preserving chain information
|
| 414 |
+
# Convert MolecularComplex metadata to ProteinComplex format
|
| 415 |
+
unique_chain_ids = np.unique(protein_chain_ids)
|
| 416 |
+
entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids}
|
| 417 |
+
chain_lookup = {
|
| 418 |
+
int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid)))
|
| 419 |
+
for cid in unique_chain_ids
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
protein_metadata = ProteinComplexMetadata(
|
| 423 |
+
entity_lookup=entity_lookup,
|
| 424 |
+
chain_lookup=chain_lookup,
|
| 425 |
+
assembly_composition=self.metadata.assembly_composition,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
return ProteinComplex(
|
| 429 |
+
id=self.id,
|
| 430 |
+
sequence=single_letter_sequence,
|
| 431 |
+
entity_id=entity_id,
|
| 432 |
+
chain_id=chain_id,
|
| 433 |
+
sym_id=sym_id,
|
| 434 |
+
residue_index=residue_index,
|
| 435 |
+
insertion_code=insertion_code,
|
| 436 |
+
atom37_positions=atom37_positions,
|
| 437 |
+
atom37_mask=atom37_mask,
|
| 438 |
+
confidence=protein_confidence,
|
| 439 |
+
metadata=protein_metadata,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
@classmethod
|
| 443 |
+
def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
|
| 444 |
+
"""Read MolecularComplex from mmcif file or string.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
inp: Path to mmCIF file or mmCIF content as string
|
| 448 |
+
id: Optional identifier to assign to the complex
|
| 449 |
+
|
| 450 |
+
Returns:
|
| 451 |
+
MolecularComplex with all molecules (proteins, ligands, nucleic acids)
|
| 452 |
+
"""
|
| 453 |
+
from io import StringIO
|
| 454 |
+
|
| 455 |
+
# Check if input is a file path or mmCIF string content
|
| 456 |
+
if os.path.exists(inp):
|
| 457 |
+
# Input is a file path
|
| 458 |
+
mmcif_file = pdbx.CIFFile.read(inp)
|
| 459 |
+
else:
|
| 460 |
+
# Input is mmCIF string content
|
| 461 |
+
mmcif_file = pdbx.CIFFile.read(StringIO(inp))
|
| 462 |
+
|
| 463 |
+
# Get structure - handle missing model information gracefully
|
| 464 |
+
try:
|
| 465 |
+
structure = pdbx.get_structure(
|
| 466 |
+
mmcif_file, model=1, extra_fields=["b_factor"]
|
| 467 |
+
)
|
| 468 |
+
except (KeyError, ValueError):
|
| 469 |
+
# Fallback for mmCIF files without model information
|
| 470 |
+
try:
|
| 471 |
+
structure = pdbx.get_structure(mmcif_file)
|
| 472 |
+
except Exception:
|
| 473 |
+
# Last resort: use the first available model or all atoms
|
| 474 |
+
structure = pdbx.get_structure(mmcif_file, model=None)
|
| 475 |
+
# Type hint for pyright - structure is an AtomArray which is iterable
|
| 476 |
+
if TYPE_CHECKING:
|
| 477 |
+
structure: Any = structure
|
| 478 |
+
|
| 479 |
+
# Read label_asym_id from the raw CIF atom_site category.
|
| 480 |
+
# Biotite's atom.chain_id uses auth_asym_id, which collapses ligands
|
| 481 |
+
# onto their parent protein chain. label_asym_id gives each entity a
|
| 482 |
+
# distinct chain identifier.
|
| 483 |
+
block = mmcif_file.block
|
| 484 |
+
label_asym_ids: list[str] | None = None
|
| 485 |
+
if "atom_site" in block:
|
| 486 |
+
atom_site = block["atom_site"]
|
| 487 |
+
if "label_asym_id" in atom_site:
|
| 488 |
+
_col = atom_site["label_asym_id"]
|
| 489 |
+
_raw = (
|
| 490 |
+
_col.as_array(str)
|
| 491 |
+
if hasattr(_col, "as_array")
|
| 492 |
+
else np.array(list(_col), dtype=str) # type: ignore[arg-type]
|
| 493 |
+
)
|
| 494 |
+
# biotite's get_structure(model=1) filters to model 1 AND
|
| 495 |
+
# removes alternate conformations. We must apply the same
|
| 496 |
+
# filters to label_asym_id to keep arrays aligned.
|
| 497 |
+
keep = np.ones(len(_raw), dtype=bool)
|
| 498 |
+
if "pdbx_PDB_model_num" in atom_site:
|
| 499 |
+
_mc = atom_site["pdbx_PDB_model_num"]
|
| 500 |
+
_models = (
|
| 501 |
+
_mc.as_array(str)
|
| 502 |
+
if hasattr(_mc, "as_array")
|
| 503 |
+
else np.array(list(_mc), dtype=str) # type: ignore[arg-type]
|
| 504 |
+
)
|
| 505 |
+
keep &= _models == "1"
|
| 506 |
+
if "label_alt_id" in atom_site:
|
| 507 |
+
_ac = atom_site["label_alt_id"]
|
| 508 |
+
_alts = (
|
| 509 |
+
_ac.as_array(str)
|
| 510 |
+
if hasattr(_ac, "as_array")
|
| 511 |
+
else np.array(list(_ac), dtype=str) # type: ignore[arg-type]
|
| 512 |
+
)
|
| 513 |
+
keep &= np.isin(_alts, [".", "?", "", "A"])
|
| 514 |
+
filtered = _raw[keep]
|
| 515 |
+
if len(filtered) == len(structure):
|
| 516 |
+
label_asym_ids = filtered.tolist()
|
| 517 |
+
# If lengths still don't match, fall back to atom.chain_id
|
| 518 |
+
|
| 519 |
+
# Get entity information from mmCIF
|
| 520 |
+
entity_info = {}
|
| 521 |
+
try:
|
| 522 |
+
if "entity" in block:
|
| 523 |
+
entity_category = block["entity"]
|
| 524 |
+
if "id" in entity_category and "type" in entity_category:
|
| 525 |
+
entity_ids = entity_category["id"]
|
| 526 |
+
entity_types = entity_category["type"]
|
| 527 |
+
# Convert CIFColumn to list for iteration
|
| 528 |
+
if hasattr(entity_ids, "__iter__") and hasattr(
|
| 529 |
+
entity_types, "__iter__"
|
| 530 |
+
):
|
| 531 |
+
# Type annotation to help pyright understand these are iterable
|
| 532 |
+
entity_ids_list = list(entity_ids) # type: ignore
|
| 533 |
+
entity_types_list = list(entity_types) # type: ignore
|
| 534 |
+
for eid, etype in zip(entity_ids_list, entity_types_list):
|
| 535 |
+
entity_info[eid] = etype
|
| 536 |
+
except Exception:
|
| 537 |
+
pass
|
| 538 |
+
|
| 539 |
+
# Initialize arrays for flat atom representation
|
| 540 |
+
sequence_tokens = []
|
| 541 |
+
flat_positions = []
|
| 542 |
+
flat_elements = []
|
| 543 |
+
flat_names = []
|
| 544 |
+
flat_hetero = []
|
| 545 |
+
token_to_atoms = []
|
| 546 |
+
confidence_scores = []
|
| 547 |
+
chain_ids = [] # Track chain IDs for each token
|
| 548 |
+
|
| 549 |
+
atom_idx = 0
|
| 550 |
+
|
| 551 |
+
# Group atoms by chain and residue.
|
| 552 |
+
# Use label_asym_id (distinct per entity) when available, otherwise
|
| 553 |
+
# fall back to biotite's chain_id (auth_asym_id).
|
| 554 |
+
chain_residue_groups: dict[str, dict[tuple[int, str], dict]] = {}
|
| 555 |
+
for atom_i, atom in enumerate(structure):
|
| 556 |
+
chain_id = (
|
| 557 |
+
label_asym_ids[atom_i] if label_asym_ids is not None else atom.chain_id
|
| 558 |
+
)
|
| 559 |
+
res_id = atom.res_id
|
| 560 |
+
res_name = atom.res_name
|
| 561 |
+
|
| 562 |
+
if chain_id not in chain_residue_groups:
|
| 563 |
+
chain_residue_groups[chain_id] = {}
|
| 564 |
+
# Key by (res_id, res_name) to distinguish residues that share
|
| 565 |
+
# the same res_id but have different res_name (e.g. a protein
|
| 566 |
+
# residue and a ligand that were on the same auth chain).
|
| 567 |
+
res_key = (res_id, res_name)
|
| 568 |
+
if res_key not in chain_residue_groups[chain_id]:
|
| 569 |
+
chain_residue_groups[chain_id][res_key] = {
|
| 570 |
+
"atoms": [],
|
| 571 |
+
"res_name": res_name,
|
| 572 |
+
"is_hetero": atom.hetero,
|
| 573 |
+
}
|
| 574 |
+
chain_residue_groups[chain_id][res_key]["atoms"].append(atom)
|
| 575 |
+
|
| 576 |
+
# Create a mapping from chain_id to numeric indices
|
| 577 |
+
chain_id_to_numeric = {
|
| 578 |
+
chain_id: idx
|
| 579 |
+
for idx, chain_id in enumerate(sorted(chain_residue_groups.keys()))
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
# Process each chain and residue
|
| 583 |
+
for chain_id in sorted(chain_residue_groups.keys()):
|
| 584 |
+
residues = chain_residue_groups[chain_id]
|
| 585 |
+
numeric_chain_id = chain_id_to_numeric[chain_id]
|
| 586 |
+
|
| 587 |
+
for res_key in sorted(residues.keys()):
|
| 588 |
+
residue_data = residues[res_key]
|
| 589 |
+
res_name = residue_data["res_name"]
|
| 590 |
+
atoms = residue_data["atoms"]
|
| 591 |
+
is_hetero = residue_data["is_hetero"]
|
| 592 |
+
|
| 593 |
+
# Skip water molecules
|
| 594 |
+
if res_name == "HOH":
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
# Determine token name
|
| 598 |
+
if not is_hetero and res_name in residue_constants.restype_3to1:
|
| 599 |
+
# Standard amino acid
|
| 600 |
+
token_name = res_name
|
| 601 |
+
elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]:
|
| 602 |
+
# Nucleotide
|
| 603 |
+
token_name = res_name
|
| 604 |
+
else:
|
| 605 |
+
# Ligand or other molecule
|
| 606 |
+
token_name = res_name
|
| 607 |
+
|
| 608 |
+
sequence_tokens.append(token_name)
|
| 609 |
+
chain_ids.append(
|
| 610 |
+
numeric_chain_id
|
| 611 |
+
) # Store the numeric chain ID for this token
|
| 612 |
+
token_start = atom_idx
|
| 613 |
+
|
| 614 |
+
# Add all atoms from this residue
|
| 615 |
+
for atom in atoms:
|
| 616 |
+
flat_positions.append(atom.coord)
|
| 617 |
+
|
| 618 |
+
# Get element character
|
| 619 |
+
element = atom.element
|
| 620 |
+
flat_elements.append(element)
|
| 621 |
+
|
| 622 |
+
# Get atom name
|
| 623 |
+
atom_name = atom.atom_name
|
| 624 |
+
flat_names.append(atom_name)
|
| 625 |
+
|
| 626 |
+
# Get hetero flag
|
| 627 |
+
hetero_flag = atom.hetero
|
| 628 |
+
flat_hetero.append(hetero_flag)
|
| 629 |
+
|
| 630 |
+
atom_idx += 1
|
| 631 |
+
|
| 632 |
+
# Record token-to-atom mapping
|
| 633 |
+
token_to_atoms.append([token_start, atom_idx])
|
| 634 |
+
|
| 635 |
+
# Add confidence score (B-factor if available, otherwise 1.0)
|
| 636 |
+
bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
|
| 637 |
+
confidence_scores.append(min(bfactor / 100.0, 1.0))
|
| 638 |
+
|
| 639 |
+
# Convert to numpy arrays
|
| 640 |
+
if not flat_positions:
|
| 641 |
+
# Create minimal arrays if no atoms found
|
| 642 |
+
atom_positions = np.zeros((0, 3), dtype=np.float32)
|
| 643 |
+
atom_elements = np.zeros(0, dtype=object)
|
| 644 |
+
atom_names = np.zeros(0, dtype=object)
|
| 645 |
+
atom_hetero = np.zeros(0, dtype=bool)
|
| 646 |
+
token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
|
| 647 |
+
chain_id_array = (
|
| 648 |
+
np.array(chain_ids, dtype=np.int64)
|
| 649 |
+
if chain_ids
|
| 650 |
+
else np.zeros(len(sequence_tokens), dtype=np.int64)
|
| 651 |
+
)
|
| 652 |
+
else:
|
| 653 |
+
atom_positions = np.array(flat_positions, dtype=np.float32)
|
| 654 |
+
atom_elements = np.array(flat_elements, dtype=object)
|
| 655 |
+
atom_names = np.array(flat_names, dtype=object)
|
| 656 |
+
atom_hetero = np.array(flat_hetero, dtype=bool)
|
| 657 |
+
token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
|
| 658 |
+
chain_id_array = np.array(chain_ids, dtype=np.int64)
|
| 659 |
+
|
| 660 |
+
confidence_array = np.array(confidence_scores, dtype=np.float32)
|
| 661 |
+
|
| 662 |
+
# Create metadata using the chain_id_to_numeric mapping
|
| 663 |
+
if chain_residue_groups:
|
| 664 |
+
chain_lookup = {
|
| 665 |
+
numeric_id: chain_id
|
| 666 |
+
for chain_id, numeric_id in chain_id_to_numeric.items()
|
| 667 |
+
}
|
| 668 |
+
else:
|
| 669 |
+
chain_lookup = {}
|
| 670 |
+
|
| 671 |
+
metadata = MolecularComplexMetadata(
|
| 672 |
+
entity_lookup=entity_info,
|
| 673 |
+
chain_lookup=chain_lookup,
|
| 674 |
+
assembly_composition=None,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# Set complex ID - if input was a path, use the stem; otherwise use default
|
| 678 |
+
if os.path.exists(inp):
|
| 679 |
+
complex_id = id or Path(inp).stem
|
| 680 |
+
else:
|
| 681 |
+
complex_id = id or "complex_from_string"
|
| 682 |
+
|
| 683 |
+
return cls(
|
| 684 |
+
id=complex_id,
|
| 685 |
+
sequence=sequence_tokens,
|
| 686 |
+
atom_positions=atom_positions,
|
| 687 |
+
atom_elements=atom_elements,
|
| 688 |
+
token_to_atoms=token_to_atoms_array,
|
| 689 |
+
chain_id=chain_id_array,
|
| 690 |
+
plddt=confidence_array,
|
| 691 |
+
metadata=metadata,
|
| 692 |
+
atom_names=atom_names,
|
| 693 |
+
atom_hetero=atom_hetero,
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
def _get_entity_mapping(
|
| 697 |
+
self,
|
| 698 |
+
) -> tuple[dict[str, list[str]], dict[str, int], dict[int, tuple[str, ...]]]:
|
| 699 |
+
"""Compute chain→sequence, chain→entity_id, and entity_id→sequence mappings.
|
| 700 |
+
|
| 701 |
+
Returns:
|
| 702 |
+
(chain_sequences, chain_to_entity, entity_sequences)
|
| 703 |
+
"""
|
| 704 |
+
chain_sequences: dict[str, list[str]] = {}
|
| 705 |
+
for token_idx in range(len(self.token_to_atoms)):
|
| 706 |
+
chain_id_numeric = self.chain_id[token_idx]
|
| 707 |
+
chain_id_str = self.metadata.chain_lookup.get(
|
| 708 |
+
int(chain_id_numeric), chr(65 + int(chain_id_numeric))
|
| 709 |
+
)
|
| 710 |
+
if chain_id_str not in chain_sequences:
|
| 711 |
+
chain_sequences[chain_id_str] = []
|
| 712 |
+
chain_sequences[chain_id_str].append(self.sequence[token_idx])
|
| 713 |
+
|
| 714 |
+
sequence_to_entity: dict[tuple[str, ...], int] = {}
|
| 715 |
+
chain_to_entity: dict[str, int] = {}
|
| 716 |
+
entity_sequences: dict[int, tuple[str, ...]] = {}
|
| 717 |
+
entity_id_counter = 1
|
| 718 |
+
for chain_id_str, sequence in chain_sequences.items():
|
| 719 |
+
seq_tuple = tuple(sequence)
|
| 720 |
+
if seq_tuple not in sequence_to_entity:
|
| 721 |
+
sequence_to_entity[seq_tuple] = entity_id_counter
|
| 722 |
+
entity_sequences[entity_id_counter] = seq_tuple
|
| 723 |
+
entity_id_counter += 1
|
| 724 |
+
chain_to_entity[chain_id_str] = sequence_to_entity[seq_tuple]
|
| 725 |
+
|
| 726 |
+
return chain_sequences, chain_to_entity, entity_sequences
|
| 727 |
+
|
| 728 |
+
def _add_entity_information(
|
| 729 |
+
self, cif_file: CIFFile, entity_sequences: dict[int, tuple[str, ...]]
|
| 730 |
+
) -> None:
|
| 731 |
+
"""Add _entity category to CIF file so OST can identify ligands vs polymers."""
|
| 732 |
+
|
| 733 |
+
entity_ids: list[str] = []
|
| 734 |
+
entity_types: list[str] = []
|
| 735 |
+
entity_descriptions: list[str] = []
|
| 736 |
+
for eid in sorted(entity_sequences.keys()):
|
| 737 |
+
seq = entity_sequences[eid]
|
| 738 |
+
entity_ids.append(str(eid))
|
| 739 |
+
has_protein = any(t in residue_constants.restype_3to1 for t in seq)
|
| 740 |
+
has_na = any(
|
| 741 |
+
t in ("A", "T", "G", "C", "U", "DA", "DT", "DG", "DC") for t in seq
|
| 742 |
+
)
|
| 743 |
+
if has_protein or has_na:
|
| 744 |
+
entity_types.append("polymer")
|
| 745 |
+
if has_protein:
|
| 746 |
+
entity_descriptions.append(f"Polymer entity {eid} (protein)")
|
| 747 |
+
else:
|
| 748 |
+
entity_descriptions.append(f"Polymer entity {eid} (nucleic acid)")
|
| 749 |
+
else:
|
| 750 |
+
entity_types.append("non-polymer")
|
| 751 |
+
entity_descriptions.append(f"Non-polymer entity {eid}")
|
| 752 |
+
|
| 753 |
+
if entity_ids:
|
| 754 |
+
cif_file.block["entity"] = CIFCategory(
|
| 755 |
+
name="entity",
|
| 756 |
+
columns={
|
| 757 |
+
"id": CIFColumn(
|
| 758 |
+
data=CIFData(array=np.array(entity_ids), dtype=np.str_)
|
| 759 |
+
),
|
| 760 |
+
"type": CIFColumn(
|
| 761 |
+
data=CIFData(array=np.array(entity_types), dtype=np.str_)
|
| 762 |
+
),
|
| 763 |
+
"pdbx_description": CIFColumn(
|
| 764 |
+
data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
|
| 765 |
+
),
|
| 766 |
+
},
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
# Add _struct_asym to map chain IDs to entity IDs
|
| 770 |
+
_, chain_to_entity, _ = self._get_entity_mapping()
|
| 771 |
+
if chain_to_entity:
|
| 772 |
+
asym_ids = sorted(chain_to_entity.keys())
|
| 773 |
+
asym_entity_ids = [str(chain_to_entity[c]) for c in asym_ids]
|
| 774 |
+
cif_file.block["struct_asym"] = CIFCategory(
|
| 775 |
+
name="struct_asym",
|
| 776 |
+
columns={
|
| 777 |
+
"id": CIFColumn(
|
| 778 |
+
data=CIFData(array=np.array(asym_ids), dtype=np.str_)
|
| 779 |
+
),
|
| 780 |
+
"entity_id": CIFColumn(
|
| 781 |
+
data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
|
| 782 |
+
),
|
| 783 |
+
},
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
def to_mmcif(self) -> str:
|
| 787 |
+
"""Write MolecularComplex to mmcif string using biotite.
|
| 788 |
+
|
| 789 |
+
Returns:
|
| 790 |
+
String representation of the complex in mmCIF format
|
| 791 |
+
"""
|
| 792 |
+
# Pre-allocate AtomArray
|
| 793 |
+
n_atoms = len(self.atom_positions)
|
| 794 |
+
atom_array = bs.AtomArray(length=n_atoms)
|
| 795 |
+
|
| 796 |
+
# Set coordinates directly (already vectorized)
|
| 797 |
+
atom_array.coord = self.atom_positions
|
| 798 |
+
|
| 799 |
+
# Pre-allocate per-atom arrays
|
| 800 |
+
atom_res_ids = np.zeros(n_atoms, dtype=np.int32)
|
| 801 |
+
atom_chain_ids = np.empty(n_atoms, dtype=object)
|
| 802 |
+
atom_res_names = np.empty(n_atoms, dtype=object)
|
| 803 |
+
atom_hetero = np.zeros(n_atoms, dtype=bool)
|
| 804 |
+
atom_bfactors = np.zeros(n_atoms, dtype=np.float32)
|
| 805 |
+
atom_names = np.empty(n_atoms, dtype=object)
|
| 806 |
+
|
| 807 |
+
# Build entity mappings: chains with identical sequences share entity ID
|
| 808 |
+
_, chain_to_entity, entity_sequences = self._get_entity_mapping()
|
| 809 |
+
|
| 810 |
+
atom_entity_ids = np.zeros(n_atoms, dtype=np.int32)
|
| 811 |
+
|
| 812 |
+
# Track residue IDs per chain
|
| 813 |
+
chain_res_counters: dict[int, int] = {}
|
| 814 |
+
|
| 815 |
+
# Vectorized expansion of token-level to atom-level annotations
|
| 816 |
+
for token_idx, (start, end) in enumerate(self.token_to_atoms):
|
| 817 |
+
token = self.sequence[token_idx]
|
| 818 |
+
chain_id_numeric = self.chain_id[token_idx]
|
| 819 |
+
chain_id_str = self.metadata.chain_lookup.get(
|
| 820 |
+
int(chain_id_numeric), chr(65 + int(chain_id_numeric))
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
# Track residue numbering per chain
|
| 824 |
+
if chain_id_numeric not in chain_res_counters:
|
| 825 |
+
chain_res_counters[chain_id_numeric] = 1
|
| 826 |
+
res_id = chain_res_counters[chain_id_numeric]
|
| 827 |
+
chain_res_counters[chain_id_numeric] += 1
|
| 828 |
+
|
| 829 |
+
# Determine if protein
|
| 830 |
+
is_protein = token in residue_constants.restype_3to1
|
| 831 |
+
|
| 832 |
+
# Get atom names for this residue
|
| 833 |
+
if self.atom_names is not None:
|
| 834 |
+
# Use stored atom names (preserves original names from mmCIF)
|
| 835 |
+
names = list(self.atom_names[start:end])
|
| 836 |
+
elif is_protein:
|
| 837 |
+
# Fallback: use standard protein atom names
|
| 838 |
+
standard_names = residue_constants.residue_atoms.get(
|
| 839 |
+
token, ["N", "CA", "C", "O"]
|
| 840 |
+
)
|
| 841 |
+
names = standard_names[: end - start]
|
| 842 |
+
# Pad if needed
|
| 843 |
+
while len(names) < (end - start):
|
| 844 |
+
names.append(f"X{len(names)+1}")
|
| 845 |
+
else:
|
| 846 |
+
# Fallback: generate names for ligands/nucleic acids
|
| 847 |
+
names = [f"C{i+1}" for i in range(end - start)]
|
| 848 |
+
|
| 849 |
+
# Vectorized assignment for this token's atoms
|
| 850 |
+
atom_res_ids[start:end] = res_id
|
| 851 |
+
atom_chain_ids[start:end] = chain_id_str
|
| 852 |
+
atom_res_names[start:end] = token
|
| 853 |
+
# Use stored hetero flags if available, otherwise guess based on protein status
|
| 854 |
+
if self.atom_hetero is not None:
|
| 855 |
+
atom_hetero[start:end] = self.atom_hetero[start:end]
|
| 856 |
+
else:
|
| 857 |
+
atom_hetero[start:end] = not is_protein
|
| 858 |
+
atom_bfactors[start:end] = self.plddt[token_idx] * 100.0
|
| 859 |
+
atom_names[start:end] = names
|
| 860 |
+
atom_entity_ids[start:end] = chain_to_entity.get(chain_id_str, 1)
|
| 861 |
+
|
| 862 |
+
# Set all AtomArray attributes at once (convert object arrays to proper string arrays)
|
| 863 |
+
# res_name uses U8 to accommodate CCD codes up to 5 characters (e.g., A1AZ2);
|
| 864 |
+
# chain_id uses U16 because chain names like ``ligand_1`` / ``ligand_2`` /
|
| 865 |
+
# auth-asym ids of arbitrary length are possible.
|
| 866 |
+
atom_array.res_id = atom_res_ids
|
| 867 |
+
atom_array.chain_id = np.array(atom_chain_ids, dtype="U16")
|
| 868 |
+
atom_array.res_name = np.array(atom_res_names, dtype="U8")
|
| 869 |
+
atom_array.hetero = atom_hetero
|
| 870 |
+
atom_array.atom_name = np.array(atom_names, dtype="U4")
|
| 871 |
+
atom_array.add_annotation("b_factor", dtype=float)
|
| 872 |
+
atom_array.b_factor = atom_bfactors
|
| 873 |
+
atom_array.add_annotation("entity_id", dtype=int)
|
| 874 |
+
atom_array.entity_id = atom_entity_ids
|
| 875 |
+
|
| 876 |
+
# Use existing elements or infer them from atom names
|
| 877 |
+
if self.atom_elements is not None and len(self.atom_elements) == n_atoms:
|
| 878 |
+
# Convert object array to proper string array for biotite
|
| 879 |
+
atom_array.element = np.array(self.atom_elements, dtype="U4")
|
| 880 |
+
else:
|
| 881 |
+
# Use biotite's built-in element inference
|
| 882 |
+
atom_array.element = bs.infer_elements(atom_array)
|
| 883 |
+
|
| 884 |
+
# Create CIF file and set structure
|
| 885 |
+
cif_file = CIFFile()
|
| 886 |
+
set_structure(cif_file, atom_array, data_block=self.id)
|
| 887 |
+
|
| 888 |
+
# Manually fix label_entity_id (biotite doesn't use entity_id annotation correctly)
|
| 889 |
+
if "atom_site" in cif_file.block:
|
| 890 |
+
atom_site = cif_file.block["atom_site"]
|
| 891 |
+
if "label_asym_id" in atom_site and "label_entity_id" in atom_site:
|
| 892 |
+
label_asym_ids = atom_site["label_asym_id"]
|
| 893 |
+
if hasattr(label_asym_ids, "as_array"):
|
| 894 |
+
chain_ids_list = label_asym_ids.as_array(str).tolist()
|
| 895 |
+
elif hasattr(label_asym_ids, "__iter__"):
|
| 896 |
+
chain_ids_list = list(label_asym_ids) # type: ignore[arg-type]
|
| 897 |
+
else:
|
| 898 |
+
chain_ids_list = []
|
| 899 |
+
updated_entity_ids = [
|
| 900 |
+
str(chain_to_entity.get(cid, 1)) for cid in chain_ids_list
|
| 901 |
+
]
|
| 902 |
+
if updated_entity_ids:
|
| 903 |
+
atom_site["label_entity_id"] = CIFColumn(
|
| 904 |
+
data=CIFData(array=np.array(updated_entity_ids), dtype=np.str_)
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
# Add _entity category for OST compatibility
|
| 908 |
+
self._add_entity_information(cif_file, entity_sequences)
|
| 909 |
+
|
| 910 |
+
# Convert to string
|
| 911 |
+
output = io.StringIO()
|
| 912 |
+
cif_file.write(output)
|
| 913 |
+
return output.getvalue()
|
| 914 |
+
|
| 915 |
+
def dockq(self, native: "MolecularComplex") -> Any:
|
| 916 |
+
"""Compute DockQ score against native structure.
|
| 917 |
+
|
| 918 |
+
Args:
|
| 919 |
+
native: Native MolecularComplex to compute DockQ against
|
| 920 |
+
|
| 921 |
+
Returns:
|
| 922 |
+
DockQ result containing score and alignment information
|
| 923 |
+
"""
|
| 924 |
+
# Imports moved to top of file
|
| 925 |
+
|
| 926 |
+
# Convert both complexes to ProteinComplex format for DockQ computation
|
| 927 |
+
# This extracts only the protein portion and converts to PDB format
|
| 928 |
+
try:
|
| 929 |
+
self_pc = self.to_protein_complex()
|
| 930 |
+
native_pc = native.to_protein_complex()
|
| 931 |
+
except ValueError as e:
|
| 932 |
+
raise ValueError(
|
| 933 |
+
f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
# Normalize chain IDs for PDB compatibility
|
| 937 |
+
self_pc = self_pc.normalize_chain_ids_for_pdb()
|
| 938 |
+
native_pc = native_pc.normalize_chain_ids_for_pdb()
|
| 939 |
+
|
| 940 |
+
# Use the existing ProteinComplex.dockq() method
|
| 941 |
+
try:
|
| 942 |
+
dockq_result = self_pc.dockq(native_pc)
|
| 943 |
+
return dockq_result
|
| 944 |
+
except Exception:
|
| 945 |
+
# Fallback to manual DockQ computation if ProteinComplex.dockq() fails
|
| 946 |
+
return self._compute_dockq_manual(native)
|
| 947 |
+
|
| 948 |
+
def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
|
| 949 |
+
"""Manual DockQ computation fallback."""
|
| 950 |
+
# Imports moved to top of file
|
| 951 |
+
|
| 952 |
+
# Convert both complexes to ProteinComplex format
|
| 953 |
+
try:
|
| 954 |
+
self_pc = self.to_protein_complex()
|
| 955 |
+
native_pc = native.to_protein_complex()
|
| 956 |
+
except ValueError as e:
|
| 957 |
+
raise ValueError(
|
| 958 |
+
f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
# Normalize chain IDs for PDB compatibility
|
| 962 |
+
self_pc = self_pc.normalize_chain_ids_for_pdb()
|
| 963 |
+
native_pc = native_pc.normalize_chain_ids_for_pdb()
|
| 964 |
+
|
| 965 |
+
# Write temporary PDB files and run DockQ
|
| 966 |
+
with TemporaryDirectory() as tdir:
|
| 967 |
+
dir_path = Path(tdir)
|
| 968 |
+
self_pdb = dir_path / "self.pdb"
|
| 969 |
+
native_pdb = dir_path / "native.pdb"
|
| 970 |
+
|
| 971 |
+
# Write PDB files
|
| 972 |
+
self_pc.to_pdb(self_pdb)
|
| 973 |
+
native_pc.to_pdb(native_pdb)
|
| 974 |
+
|
| 975 |
+
# Run DockQ
|
| 976 |
+
try:
|
| 977 |
+
output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
|
| 978 |
+
output_text = output.decode()
|
| 979 |
+
|
| 980 |
+
# Parse DockQ output
|
| 981 |
+
lines = output_text.split("\n")
|
| 982 |
+
|
| 983 |
+
# Find the total DockQ score
|
| 984 |
+
dockq_score = None
|
| 985 |
+
for line in lines:
|
| 986 |
+
if "Total DockQ" in line:
|
| 987 |
+
match = re.search(r"Total DockQ.*: ([\d.]+)", line)
|
| 988 |
+
if match:
|
| 989 |
+
dockq_score = float(match.group(1))
|
| 990 |
+
break
|
| 991 |
+
|
| 992 |
+
if dockq_score is None:
|
| 993 |
+
# Try to find individual DockQ scores
|
| 994 |
+
for line in lines:
|
| 995 |
+
if line.startswith("DockQ") and ":" in line:
|
| 996 |
+
try:
|
| 997 |
+
dockq_score = float(line.split(":")[1].strip())
|
| 998 |
+
break
|
| 999 |
+
except (ValueError, IndexError):
|
| 1000 |
+
continue
|
| 1001 |
+
|
| 1002 |
+
if dockq_score is None:
|
| 1003 |
+
raise ValueError("Could not parse DockQ score from output")
|
| 1004 |
+
|
| 1005 |
+
# Return a simple result structure
|
| 1006 |
+
return {
|
| 1007 |
+
"total_dockq": dockq_score,
|
| 1008 |
+
"raw_output": output_text,
|
| 1009 |
+
"aligned": self, # Return self as aligned structure
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
except FileNotFoundError:
|
| 1013 |
+
raise RuntimeError(
|
| 1014 |
+
"DockQ is not installed. Please install DockQ to use this method."
|
| 1015 |
+
)
|
| 1016 |
+
except Exception as e:
|
| 1017 |
+
raise RuntimeError(f"DockQ computation failed: {e}")
|
| 1018 |
+
|
| 1019 |
+
def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
|
| 1020 |
+
"""Compute RMSD against target structure.
|
| 1021 |
+
|
| 1022 |
+
Args:
|
| 1023 |
+
target: Target MolecularComplex to compute RMSD against
|
| 1024 |
+
**kwargs: Additional arguments passed to compute_rmsd
|
| 1025 |
+
|
| 1026 |
+
Returns:
|
| 1027 |
+
float: RMSD value between the two structures
|
| 1028 |
+
"""
|
| 1029 |
+
# Imports moved to top of file
|
| 1030 |
+
|
| 1031 |
+
# Ensure both complexes have the same number of tokens
|
| 1032 |
+
if len(self) != len(target):
|
| 1033 |
+
raise ValueError(
|
| 1034 |
+
f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
# Extract center positions for each token (using centroid of atoms)
|
| 1038 |
+
mobile_coords = []
|
| 1039 |
+
target_coords = []
|
| 1040 |
+
atom_mask = []
|
| 1041 |
+
|
| 1042 |
+
for i in range(len(self)):
|
| 1043 |
+
# Get atom positions for this token
|
| 1044 |
+
mobile_start, mobile_end = self.token_to_atoms[i]
|
| 1045 |
+
target_start, target_end = target.token_to_atoms[i]
|
| 1046 |
+
|
| 1047 |
+
# Extract atom positions
|
| 1048 |
+
mobile_atoms = self.atom_positions[mobile_start:mobile_end]
|
| 1049 |
+
target_atoms = target.atom_positions[target_start:target_end]
|
| 1050 |
+
|
| 1051 |
+
# Check if both tokens have atoms
|
| 1052 |
+
if len(mobile_atoms) == 0 or len(target_atoms) == 0:
|
| 1053 |
+
# Skip tokens with no atoms
|
| 1054 |
+
continue
|
| 1055 |
+
|
| 1056 |
+
# For simplicity, use the centroid of atoms as the representative position
|
| 1057 |
+
mobile_center = mobile_atoms.mean(axis=0)
|
| 1058 |
+
target_center = target_atoms.mean(axis=0)
|
| 1059 |
+
|
| 1060 |
+
mobile_coords.append(mobile_center)
|
| 1061 |
+
target_coords.append(target_center)
|
| 1062 |
+
atom_mask.append(True)
|
| 1063 |
+
|
| 1064 |
+
if len(mobile_coords) == 0:
|
| 1065 |
+
raise ValueError("No valid atoms found for RMSD computation")
|
| 1066 |
+
|
| 1067 |
+
# Convert to tensors
|
| 1068 |
+
mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
|
| 1069 |
+
0
|
| 1070 |
+
) # [1, N, 3]
|
| 1071 |
+
target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
|
| 1072 |
+
0
|
| 1073 |
+
) # [1, N, 3]
|
| 1074 |
+
mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
|
| 1075 |
+
|
| 1076 |
+
# Compute RMSD using existing infrastructure
|
| 1077 |
+
rmsd_value = compute_rmsd(
|
| 1078 |
+
mobile=mobile_tensor,
|
| 1079 |
+
target=target_tensor,
|
| 1080 |
+
atom_exists_mask=mask_tensor,
|
| 1081 |
+
reduction="batch",
|
| 1082 |
+
**kwargs,
|
| 1083 |
+
)
|
| 1084 |
+
|
| 1085 |
+
return float(rmsd_value)
|
| 1086 |
+
|
| 1087 |
+
def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
|
| 1088 |
+
"""Compute LDDT score against target structure.
|
| 1089 |
+
|
| 1090 |
+
Args:
|
| 1091 |
+
target: Target MolecularComplex to compute LDDT against
|
| 1092 |
+
**kwargs: Additional arguments passed to compute_lddt
|
| 1093 |
+
|
| 1094 |
+
Returns:
|
| 1095 |
+
float: LDDT value between the two structures
|
| 1096 |
+
"""
|
| 1097 |
+
# Imports moved to top of file
|
| 1098 |
+
|
| 1099 |
+
# Ensure both complexes have the same number of tokens
|
| 1100 |
+
if len(self) != len(target):
|
| 1101 |
+
raise ValueError(
|
| 1102 |
+
f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
# Extract center positions for each token (using centroid of atoms)
|
| 1106 |
+
mobile_coords = []
|
| 1107 |
+
target_coords = []
|
| 1108 |
+
atom_mask = []
|
| 1109 |
+
|
| 1110 |
+
for i in range(len(self)):
|
| 1111 |
+
# Get atom positions for this token
|
| 1112 |
+
mobile_start, mobile_end = self.token_to_atoms[i]
|
| 1113 |
+
target_start, target_end = target.token_to_atoms[i]
|
| 1114 |
+
|
| 1115 |
+
# Extract atom positions
|
| 1116 |
+
mobile_atoms = self.atom_positions[mobile_start:mobile_end]
|
| 1117 |
+
target_atoms = target.atom_positions[target_start:target_end]
|
| 1118 |
+
|
| 1119 |
+
# Check if both tokens have atoms
|
| 1120 |
+
if len(mobile_atoms) == 0 or len(target_atoms) == 0:
|
| 1121 |
+
# Skip tokens with no atoms
|
| 1122 |
+
mobile_coords.append(np.full(3, np.nan))
|
| 1123 |
+
target_coords.append(np.full(3, np.nan))
|
| 1124 |
+
atom_mask.append(False)
|
| 1125 |
+
continue
|
| 1126 |
+
|
| 1127 |
+
# For simplicity, use the centroid of atoms as the representative position
|
| 1128 |
+
mobile_center = mobile_atoms.mean(axis=0)
|
| 1129 |
+
target_center = target_atoms.mean(axis=0)
|
| 1130 |
+
|
| 1131 |
+
mobile_coords.append(mobile_center)
|
| 1132 |
+
target_coords.append(target_center)
|
| 1133 |
+
atom_mask.append(True)
|
| 1134 |
+
|
| 1135 |
+
if not any(atom_mask):
|
| 1136 |
+
raise ValueError("No valid atoms found for LDDT computation")
|
| 1137 |
+
|
| 1138 |
+
# Convert to tensors
|
| 1139 |
+
mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
|
| 1140 |
+
0
|
| 1141 |
+
) # [1, N, 3]
|
| 1142 |
+
target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
|
| 1143 |
+
0
|
| 1144 |
+
) # [1, N, 3]
|
| 1145 |
+
mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
|
| 1146 |
+
|
| 1147 |
+
# Compute LDDT using existing infrastructure
|
| 1148 |
+
lddt_value = compute_lddt(
|
| 1149 |
+
all_atom_pred_pos=mobile_tensor,
|
| 1150 |
+
all_atom_positions=target_tensor,
|
| 1151 |
+
all_atom_mask=mask_tensor,
|
| 1152 |
+
per_residue=False, # Return overall LDDT score
|
| 1153 |
+
**kwargs,
|
| 1154 |
+
)
|
| 1155 |
+
|
| 1156 |
+
return float(lddt_value)
|
| 1157 |
+
|
| 1158 |
+
def state_dict(self):
|
| 1159 |
+
"""This state dict is optimized for storage, so it turns things to fp16 whenever
|
| 1160 |
+
possible and converts numpy arrays to lists for JSON serialization.
|
| 1161 |
+
"""
|
| 1162 |
+
dct = {k: v for k, v in vars(self).items()}
|
| 1163 |
+
for k, v in dct.items():
|
| 1164 |
+
if isinstance(v, np.ndarray):
|
| 1165 |
+
match v.dtype:
|
| 1166 |
+
case np.int64:
|
| 1167 |
+
dct[k] = v.astype(np.int32).tolist()
|
| 1168 |
+
case np.float64 | np.float32:
|
| 1169 |
+
dct[k] = v.astype(np.float16).tolist()
|
| 1170 |
+
case _:
|
| 1171 |
+
dct[k] = v.tolist()
|
| 1172 |
+
elif isinstance(v, MolecularComplexMetadata):
|
| 1173 |
+
dct[k] = asdict(v)
|
| 1174 |
+
|
| 1175 |
+
return dct
|
| 1176 |
+
|
| 1177 |
+
def to_blob(self) -> bytes:
|
| 1178 |
+
return brotli.compress(msgpack.dumps(self.state_dict()), quality=5)
|
| 1179 |
+
|
| 1180 |
+
@classmethod
|
| 1181 |
+
def from_state_dict(cls, dct):
|
| 1182 |
+
for k, v in dct.items():
|
| 1183 |
+
if isinstance(v, list) and k in [
|
| 1184 |
+
"atom_positions",
|
| 1185 |
+
"atom_elements",
|
| 1186 |
+
"atom_names",
|
| 1187 |
+
"atom_hetero",
|
| 1188 |
+
"token_to_atoms",
|
| 1189 |
+
"chain_id",
|
| 1190 |
+
"plddt",
|
| 1191 |
+
]:
|
| 1192 |
+
dct[k] = np.array(v)
|
| 1193 |
+
|
| 1194 |
+
for k, v in dct.items():
|
| 1195 |
+
if isinstance(v, np.ndarray):
|
| 1196 |
+
if k in ["atom_positions", "plddt"]:
|
| 1197 |
+
dct[k] = v.astype(np.float32)
|
| 1198 |
+
elif k in ["token_to_atoms", "chain_id"]:
|
| 1199 |
+
dct[k] = (
|
| 1200 |
+
v.astype(np.int32)
|
| 1201 |
+
if k == "token_to_atoms"
|
| 1202 |
+
else v.astype(np.int64)
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
|
| 1206 |
+
|
| 1207 |
+
# Backward compatibility: if chain_id is missing, create default array
|
| 1208 |
+
if "chain_id" not in dct:
|
| 1209 |
+
# Default all tokens to chain 0
|
| 1210 |
+
dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64)
|
| 1211 |
+
|
| 1212 |
+
return cls(**dct)
|
| 1213 |
+
|
| 1214 |
+
@classmethod
|
| 1215 |
+
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
|
| 1216 |
+
match input:
|
| 1217 |
+
case Path() | str():
|
| 1218 |
+
bytes = Path(input).read_bytes()
|
| 1219 |
+
case io.BytesIO():
|
| 1220 |
+
bytes = input.getvalue()
|
| 1221 |
+
case _:
|
| 1222 |
+
bytes = input
|
| 1223 |
+
return cls.from_state_dict(
|
| 1224 |
+
msgpack.loads(brotli.decompress(bytes), strict_map_key=False)
|
| 1225 |
+
)
|
| 1226 |
+
|
esmfold2_msa.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
import string
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from functools import cached_property
|
| 7 |
+
from itertools import islice
|
| 8 |
+
from typing import Sequence
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from Bio import SeqIO
|
| 12 |
+
from scipy.spatial.distance import cdist
|
| 13 |
+
|
| 14 |
+
from .esmfold2_misc import slice_any_object
|
| 15 |
+
from .esmfold2_msa_filter_sequences import greedy_select_indices, hhfilter
|
| 16 |
+
from .esmfold2_parsing import FastaEntry, read_sequences, write_sequences
|
| 17 |
+
from .esmfold2_sequential_dataclass import SequentialDataclass
|
| 18 |
+
from .esmfold2_system import PathOrBuffer
|
| 19 |
+
|
| 20 |
+
REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def remove_insertions_from_sequence(seq: str) -> str:
|
| 24 |
+
return seq.translate(REMOVE_LOWERCASE_TRANSLATION)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass(frozen=True)
|
| 28 |
+
class MSA(SequentialDataclass):
|
| 29 |
+
"""Object-oriented interface to an MSA.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
sequences (list[str]): List of protein sequences
|
| 33 |
+
headers (list[str]): List of headers describing the sequences
|
| 34 |
+
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
entries: list[FastaEntry]
|
| 38 |
+
|
| 39 |
+
@cached_property
|
| 40 |
+
def sequences(self) -> list[str]:
|
| 41 |
+
return [entry.sequence for entry in self.entries]
|
| 42 |
+
|
| 43 |
+
@cached_property
|
| 44 |
+
def headers(self) -> list[str]:
|
| 45 |
+
return [entry.header for entry in self.entries]
|
| 46 |
+
|
| 47 |
+
def __repr__(self):
|
| 48 |
+
return (
|
| 49 |
+
f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def to_fast_msa(self) -> FastMSA:
|
| 53 |
+
return FastMSA(self.array, self.headers)
|
| 54 |
+
|
| 55 |
+
@classmethod
|
| 56 |
+
def from_a3m(
|
| 57 |
+
cls,
|
| 58 |
+
path: PathOrBuffer,
|
| 59 |
+
remove_insertions: bool = True,
|
| 60 |
+
max_sequences: int | None = None,
|
| 61 |
+
) -> MSA:
|
| 62 |
+
entries = []
|
| 63 |
+
for header, seq in islice(read_sequences(path), max_sequences):
|
| 64 |
+
if remove_insertions:
|
| 65 |
+
seq = remove_insertions_from_sequence(seq)
|
| 66 |
+
if entries:
|
| 67 |
+
assert (
|
| 68 |
+
len(seq) == len(entries[0].sequence)
|
| 69 |
+
), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
|
| 70 |
+
entries.append(FastaEntry(header, seq))
|
| 71 |
+
return cls(entries)
|
| 72 |
+
|
| 73 |
+
def to_a3m(self, path: PathOrBuffer) -> None:
|
| 74 |
+
write_sequences(self.entries, path)
|
| 75 |
+
|
| 76 |
+
@classmethod
|
| 77 |
+
def from_stockholm(
|
| 78 |
+
cls,
|
| 79 |
+
path: PathOrBuffer,
|
| 80 |
+
remove_insertions: bool = True,
|
| 81 |
+
max_sequences: int | None = None,
|
| 82 |
+
) -> MSA:
|
| 83 |
+
entries = []
|
| 84 |
+
for record in islice(SeqIO.parse(path, "stockholm"), max_sequences):
|
| 85 |
+
header = f"{record.id} {record.description}"
|
| 86 |
+
seq = str(record.seq)
|
| 87 |
+
if entries:
|
| 88 |
+
assert (
|
| 89 |
+
len(seq) == len(entries[0].sequence)
|
| 90 |
+
), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
|
| 91 |
+
entries.append(FastaEntry(header, seq))
|
| 92 |
+
msa = cls(entries)
|
| 93 |
+
if remove_insertions:
|
| 94 |
+
keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"]
|
| 95 |
+
msa = msa.select_positions(keep_inds)
|
| 96 |
+
return msa
|
| 97 |
+
|
| 98 |
+
def to_bytes(self) -> bytes:
|
| 99 |
+
version = 1
|
| 100 |
+
version_bytes = version.to_bytes(1, "little")
|
| 101 |
+
seqlen_bytes = self.seqlen.to_bytes(4, "little")
|
| 102 |
+
depth_bytes = self.depth.to_bytes(4, "little")
|
| 103 |
+
array_bytes = self.array.tobytes()
|
| 104 |
+
header_bytes = "\n".join(entry.header for entry in self.entries).encode()
|
| 105 |
+
all_bytes = (
|
| 106 |
+
version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes
|
| 107 |
+
)
|
| 108 |
+
return all_bytes
|
| 109 |
+
|
| 110 |
+
@classmethod
|
| 111 |
+
def from_bytes(cls, data: bytes) -> MSA:
|
| 112 |
+
version_bytes, seqlen_bytes, depth_bytes, data = (
|
| 113 |
+
data[:1],
|
| 114 |
+
data[1:5],
|
| 115 |
+
data[5:9],
|
| 116 |
+
data[9:],
|
| 117 |
+
)
|
| 118 |
+
version = int.from_bytes(version_bytes, "little")
|
| 119 |
+
if version != 1:
|
| 120 |
+
raise ValueError(f"Unsupported version: {version}")
|
| 121 |
+
seqlen = int.from_bytes(seqlen_bytes, "little")
|
| 122 |
+
depth = int.from_bytes(depth_bytes, "little")
|
| 123 |
+
array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
|
| 124 |
+
array = np.frombuffer(array_bytes, dtype="|S1")
|
| 125 |
+
array = array.reshape(depth, seqlen)
|
| 126 |
+
headers = header_bytes.decode().split("\n")
|
| 127 |
+
# Sometimes the separation is two newlines, which results in an empty header.
|
| 128 |
+
headers = [header for header in headers if header]
|
| 129 |
+
# If all headers were empty (e.g., saved from from_sequences), use empty headers
|
| 130 |
+
if len(headers) == 0 and depth > 0:
|
| 131 |
+
headers = [""] * depth
|
| 132 |
+
entries = [
|
| 133 |
+
FastaEntry(header, b"".join(row).decode())
|
| 134 |
+
for header, row in zip(headers, array)
|
| 135 |
+
]
|
| 136 |
+
return cls(entries)
|
| 137 |
+
|
| 138 |
+
# TODO(jmaccarl): set remove_insertions to True by default here to match other utils
|
| 139 |
+
@classmethod
|
| 140 |
+
def from_sequences(
|
| 141 |
+
cls, sequences: list[str], remove_insertions: bool = False
|
| 142 |
+
) -> MSA:
|
| 143 |
+
if remove_insertions:
|
| 144 |
+
entries = [
|
| 145 |
+
FastaEntry("", remove_insertions_from_sequence(seq))
|
| 146 |
+
for seq in sequences
|
| 147 |
+
]
|
| 148 |
+
else:
|
| 149 |
+
entries = [FastaEntry("", seq) for seq in sequences]
|
| 150 |
+
return cls(entries)
|
| 151 |
+
|
| 152 |
+
def to_sequence_bytes(self) -> bytes:
|
| 153 |
+
"""Stores ONLY SEQUENCES in array format as bytes. Header information will be lost."""
|
| 154 |
+
seqlen_bytes = self.seqlen.to_bytes(4, "little")
|
| 155 |
+
array_bytes = self.array.tobytes()
|
| 156 |
+
all_bytes = seqlen_bytes + array_bytes
|
| 157 |
+
return all_bytes
|
| 158 |
+
|
| 159 |
+
@classmethod
|
| 160 |
+
def from_sequence_bytes(cls, data: bytes) -> MSA:
|
| 161 |
+
seqlen_bytes, array_bytes = data[:4], data[4:]
|
| 162 |
+
seqlen = int.from_bytes(seqlen_bytes, "little")
|
| 163 |
+
array = np.frombuffer(array_bytes, dtype="|S1")
|
| 164 |
+
array = array.reshape(-1, seqlen)
|
| 165 |
+
entries = [FastaEntry("", b"".join(row).decode()) for row in array]
|
| 166 |
+
return cls(entries)
|
| 167 |
+
|
| 168 |
+
@property
|
| 169 |
+
def depth(self) -> int:
|
| 170 |
+
return len(self.entries)
|
| 171 |
+
|
| 172 |
+
@property
|
| 173 |
+
def seqlen(self) -> int:
|
| 174 |
+
return len(self.entries[0].sequence)
|
| 175 |
+
|
| 176 |
+
@cached_property
|
| 177 |
+
def array(self) -> np.ndarray:
|
| 178 |
+
return np.array([list(seq) for seq in self.sequences], dtype="|S1")
|
| 179 |
+
|
| 180 |
+
@property
|
| 181 |
+
def query(self) -> str:
|
| 182 |
+
return self.entries[0].sequence
|
| 183 |
+
|
| 184 |
+
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA:
|
| 185 |
+
"""Subselect rows of the MSA."""
|
| 186 |
+
entries = [self.entries[idx] for idx in indices]
|
| 187 |
+
return dataclasses.replace(self, entries=entries)
|
| 188 |
+
|
| 189 |
+
def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA:
|
| 190 |
+
"""Subselect columns of the MSA."""
|
| 191 |
+
entries = [
|
| 192 |
+
FastaEntry(header, "".join(seq[idx] for idx in indices))
|
| 193 |
+
for header, seq in self.entries
|
| 194 |
+
]
|
| 195 |
+
return dataclasses.replace(self, entries=entries)
|
| 196 |
+
|
| 197 |
+
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
|
| 198 |
+
if isinstance(indices, int):
|
| 199 |
+
indices = [indices]
|
| 200 |
+
|
| 201 |
+
entries = [
|
| 202 |
+
FastaEntry(header, slice_any_object(seq, indices))
|
| 203 |
+
for header, seq in self.entries
|
| 204 |
+
]
|
| 205 |
+
return dataclasses.replace(self, entries=entries)
|
| 206 |
+
|
| 207 |
+
def __len__(self):
|
| 208 |
+
return self.seqlen
|
| 209 |
+
|
| 210 |
+
def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA:
|
| 211 |
+
"""Greedily select sequences that either maximize or minimize hamming distance.
|
| 212 |
+
|
| 213 |
+
Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
|
| 214 |
+
iteratively add sequences to the list with the maximum (minimum) average Hamming
|
| 215 |
+
distance to the existing set of sequences.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
num_seqs (int): Number of sequences to select.
|
| 219 |
+
mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
|
| 220 |
+
you're doing it to prove a point for a paper.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
MSA object w/ subselected sequences.
|
| 224 |
+
"""
|
| 225 |
+
assert mode in ("max", "min")
|
| 226 |
+
if self.depth <= num_seqs:
|
| 227 |
+
return self
|
| 228 |
+
|
| 229 |
+
indices = greedy_select_indices(self.array, num_seqs, mode)
|
| 230 |
+
return self.select_sequences(indices)
|
| 231 |
+
|
| 232 |
+
def hhfilter(
|
| 233 |
+
self,
|
| 234 |
+
seqid: int = 90,
|
| 235 |
+
diff: int = 0,
|
| 236 |
+
cov: int = 0,
|
| 237 |
+
qid: int = 0,
|
| 238 |
+
qsc: float = -20.0,
|
| 239 |
+
binary: str = "hhfilter",
|
| 240 |
+
) -> MSA:
|
| 241 |
+
"""Apply hhfilter to the sequences in the MSA and return a filtered MSA."""
|
| 242 |
+
|
| 243 |
+
indices = hhfilter(
|
| 244 |
+
self.sequences,
|
| 245 |
+
seqid=seqid,
|
| 246 |
+
diff=diff,
|
| 247 |
+
cov=cov,
|
| 248 |
+
qid=qid,
|
| 249 |
+
qsc=qsc,
|
| 250 |
+
binary=binary,
|
| 251 |
+
)
|
| 252 |
+
return self.select_sequences(indices)
|
| 253 |
+
|
| 254 |
+
def select_random_sequences(self, num_seqs: int) -> MSA:
|
| 255 |
+
"""Uses random sampling to subselect sequences from the MSA. Always
|
| 256 |
+
keeps the query sequence.
|
| 257 |
+
"""
|
| 258 |
+
if num_seqs >= self.depth:
|
| 259 |
+
return self
|
| 260 |
+
|
| 261 |
+
# Subselect random, always keeping the query sequence.
|
| 262 |
+
indices = np.sort(
|
| 263 |
+
np.append(
|
| 264 |
+
0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
|
| 265 |
+
)
|
| 266 |
+
)
|
| 267 |
+
msa = self.select_sequences(indices) # type: ignore
|
| 268 |
+
return msa
|
| 269 |
+
|
| 270 |
+
def select_diverse_sequences(self, num_seqs: int) -> MSA:
|
| 271 |
+
"""Applies hhfilter to select ~num_seqs sequences, then uses random sampling
|
| 272 |
+
to subselect if necessary.
|
| 273 |
+
"""
|
| 274 |
+
if num_seqs >= self.depth:
|
| 275 |
+
return self
|
| 276 |
+
|
| 277 |
+
msa = self.hhfilter(diff=num_seqs)
|
| 278 |
+
if num_seqs < msa.depth:
|
| 279 |
+
msa = msa.select_random_sequences(num_seqs)
|
| 280 |
+
return msa
|
| 281 |
+
|
| 282 |
+
def pad_to_depth(self, depth: int) -> MSA:
|
| 283 |
+
if depth < self.depth:
|
| 284 |
+
raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
|
| 285 |
+
elif depth == self.depth:
|
| 286 |
+
return self
|
| 287 |
+
|
| 288 |
+
num_to_add = depth - self.depth
|
| 289 |
+
extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)]
|
| 290 |
+
return dataclasses.replace(self, entries=self.entries + extra_entries)
|
| 291 |
+
|
| 292 |
+
@classmethod
|
| 293 |
+
def stack(
|
| 294 |
+
cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True
|
| 295 |
+
) -> MSA:
|
| 296 |
+
"""Stack a series of MSAs. Optionally remove the query from msas after the first."""
|
| 297 |
+
all_entries = []
|
| 298 |
+
for i, msa in enumerate(msas):
|
| 299 |
+
entries = msa.entries
|
| 300 |
+
if i > 0 and remove_query_from_later_msas:
|
| 301 |
+
entries = entries[1:]
|
| 302 |
+
all_entries.extend(entries)
|
| 303 |
+
return cls(entries=all_entries)
|
| 304 |
+
|
| 305 |
+
@cached_property
|
| 306 |
+
def seqid(self) -> np.ndarray:
|
| 307 |
+
array = self.array.view(np.uint8)
|
| 308 |
+
seqid = 1 - cdist(array[0][None], array, "hamming")
|
| 309 |
+
return seqid[0]
|
| 310 |
+
|
| 311 |
+
@classmethod
|
| 312 |
+
def concat(
|
| 313 |
+
cls,
|
| 314 |
+
msas: Sequence[MSA],
|
| 315 |
+
join_token: str | None = "|",
|
| 316 |
+
allow_depth_mismatch: bool = False,
|
| 317 |
+
) -> MSA:
|
| 318 |
+
"""Concatenate a series of MSAs horizontally, along the sequence dimension."""
|
| 319 |
+
if not msas:
|
| 320 |
+
raise ValueError("Cannot concatenate an empty list of MSAs")
|
| 321 |
+
msa_depths = [msa.depth for msa in msas]
|
| 322 |
+
if len(set(msa_depths)) != 1:
|
| 323 |
+
if not allow_depth_mismatch:
|
| 324 |
+
raise ValueError("Depth mismatch in concatenating MSAs")
|
| 325 |
+
else:
|
| 326 |
+
max_depth = max(msa_depths)
|
| 327 |
+
msas = [msa.pad_to_depth(max_depth) for msa in msas]
|
| 328 |
+
headers = [
|
| 329 |
+
"|".join([str(h) for h in headers])
|
| 330 |
+
for headers in zip(*(msa.headers for msa in msas))
|
| 331 |
+
]
|
| 332 |
+
|
| 333 |
+
if join_token is None:
|
| 334 |
+
join_token = ""
|
| 335 |
+
|
| 336 |
+
seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))]
|
| 337 |
+
entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)]
|
| 338 |
+
return cls(entries)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@dataclass(frozen=True)
|
| 342 |
+
class FastMSA(SequentialDataclass):
|
| 343 |
+
"""Object-oriented interface to an MSA stored as a numpy uint8 array."""
|
| 344 |
+
|
| 345 |
+
array: np.ndarray
|
| 346 |
+
headers: list[str] | None = None
|
| 347 |
+
|
| 348 |
+
def __post_init__(self):
|
| 349 |
+
if self.headers is not None:
|
| 350 |
+
assert (
|
| 351 |
+
len(self.headers) == self.depth
|
| 352 |
+
), "Number of headers must match depth."
|
| 353 |
+
|
| 354 |
+
@classmethod
|
| 355 |
+
def from_bytes(cls, data: bytes) -> FastMSA:
|
| 356 |
+
version_bytes, seqlen_bytes, depth_bytes, data = (
|
| 357 |
+
data[:1],
|
| 358 |
+
data[1:5],
|
| 359 |
+
data[5:9],
|
| 360 |
+
data[9:],
|
| 361 |
+
)
|
| 362 |
+
version = int.from_bytes(version_bytes, "little")
|
| 363 |
+
if version != 1:
|
| 364 |
+
raise ValueError(f"Unsupported version: {version}")
|
| 365 |
+
seqlen = int.from_bytes(seqlen_bytes, "little")
|
| 366 |
+
depth = int.from_bytes(depth_bytes, "little")
|
| 367 |
+
array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
|
| 368 |
+
array = np.frombuffer(array_bytes, dtype="|S1")
|
| 369 |
+
array = array.reshape(depth, seqlen)
|
| 370 |
+
headers = header_bytes.decode().split("\n")
|
| 371 |
+
# Sometimes the separation is two newlines, which results in an empty header.
|
| 372 |
+
headers = [header for header in headers if header]
|
| 373 |
+
# If all headers were empty (e.g., saved from from_sequences), use empty headers
|
| 374 |
+
if len(headers) == 0 and depth > 0:
|
| 375 |
+
headers = [""] * depth
|
| 376 |
+
return cls(array, headers)
|
| 377 |
+
|
| 378 |
+
@classmethod
|
| 379 |
+
def from_sequence_bytes(cls, data: bytes) -> FastMSA:
|
| 380 |
+
seqlen_bytes, array_bytes = data[:4], data[4:]
|
| 381 |
+
seqlen = int.from_bytes(seqlen_bytes, "little")
|
| 382 |
+
array = np.frombuffer(array_bytes, dtype="|S1")
|
| 383 |
+
array = array.reshape(-1, seqlen)
|
| 384 |
+
return cls(array)
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def depth(self) -> int:
|
| 388 |
+
return self.array.shape[0]
|
| 389 |
+
|
| 390 |
+
@property
|
| 391 |
+
def seqlen(self) -> int:
|
| 392 |
+
return self.array.shape[1]
|
| 393 |
+
|
| 394 |
+
def __len__(self):
|
| 395 |
+
return self.seqlen
|
| 396 |
+
|
| 397 |
+
def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
|
| 398 |
+
if isinstance(indices, int):
|
| 399 |
+
indices = [indices]
|
| 400 |
+
|
| 401 |
+
return dataclasses.replace(self, array=self.array[:, indices])
|
| 402 |
+
|
| 403 |
+
def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA:
|
| 404 |
+
"""Subselect rows of the MSA."""
|
| 405 |
+
array = self.array[indices]
|
| 406 |
+
headers = (
|
| 407 |
+
[self.headers[idx] for idx in indices] if self.headers is not None else None
|
| 408 |
+
)
|
| 409 |
+
return dataclasses.replace(self, array=array, headers=headers)
|
| 410 |
+
|
| 411 |
+
def select_random_sequences(self, num_seqs: int) -> FastMSA:
|
| 412 |
+
"""Uses random sampling to subselect sequences from the MSA. Always
|
| 413 |
+
keeps the query sequence.
|
| 414 |
+
"""
|
| 415 |
+
if num_seqs >= self.depth:
|
| 416 |
+
return self
|
| 417 |
+
|
| 418 |
+
# Subselect random, always keeping the query sequence.
|
| 419 |
+
indices = np.sort(
|
| 420 |
+
np.append(
|
| 421 |
+
0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
|
| 422 |
+
)
|
| 423 |
+
)
|
| 424 |
+
msa = self.select_sequences(indices) # type: ignore
|
| 425 |
+
return msa
|
| 426 |
+
|
| 427 |
+
def pad_to_depth(self, depth: int) -> FastMSA:
|
| 428 |
+
if depth < self.depth:
|
| 429 |
+
raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
|
| 430 |
+
elif depth == self.depth:
|
| 431 |
+
return self
|
| 432 |
+
|
| 433 |
+
num_to_add = depth - self.depth
|
| 434 |
+
array = np.pad(
|
| 435 |
+
self.array,
|
| 436 |
+
[(0, num_to_add), (0, 0)],
|
| 437 |
+
constant_values=ord("-") if self.array.dtype == np.uint8 else b"-",
|
| 438 |
+
)
|
| 439 |
+
headers = self.headers
|
| 440 |
+
if headers is not None:
|
| 441 |
+
headers = headers + [""] * num_to_add
|
| 442 |
+
return dataclasses.replace(self, array=array, headers=headers)
|
| 443 |
+
|
| 444 |
+
@classmethod
|
| 445 |
+
def concat(
|
| 446 |
+
cls,
|
| 447 |
+
msas: Sequence[FastMSA],
|
| 448 |
+
join_token: str | None = None,
|
| 449 |
+
allow_depth_mismatch: bool = False,
|
| 450 |
+
) -> FastMSA:
|
| 451 |
+
"""Concatenate a series of MSAs horizontally, along the sequence dimension."""
|
| 452 |
+
if not msas:
|
| 453 |
+
raise ValueError("Cannot concatenate an empty list of MSAs")
|
| 454 |
+
if join_token is not None and join_token != "":
|
| 455 |
+
raise NotImplementedError("join_token is not supported for FastMSA")
|
| 456 |
+
|
| 457 |
+
msa_depths = [msa.depth for msa in msas]
|
| 458 |
+
if len(set(msa_depths)) != 1:
|
| 459 |
+
if not allow_depth_mismatch:
|
| 460 |
+
raise ValueError("Depth mismatch in concatenating MSAs")
|
| 461 |
+
else:
|
| 462 |
+
max_depth = max(msa_depths)
|
| 463 |
+
msas = [msa.pad_to_depth(max_depth) for msa in msas]
|
| 464 |
+
headers = [
|
| 465 |
+
"|".join([str(h) for h in headers])
|
| 466 |
+
for headers in zip(
|
| 467 |
+
*(
|
| 468 |
+
msa.headers if msa.headers is not None else [""] * msa.depth
|
| 469 |
+
for msa in msas
|
| 470 |
+
)
|
| 471 |
+
)
|
| 472 |
+
]
|
| 473 |
+
|
| 474 |
+
array = np.concatenate([msa.array for msa in msas], axis=1)
|
| 475 |
+
return cls(array, headers)
|
| 476 |
+
|
| 477 |
+
def to_msa(self) -> MSA:
|
| 478 |
+
headers = (
|
| 479 |
+
self.headers
|
| 480 |
+
if self.headers is not None
|
| 481 |
+
else [f"seq{i}" for i in range(self.depth)]
|
| 482 |
+
)
|
| 483 |
+
entries = [
|
| 484 |
+
FastaEntry(header, b"".join(row).decode())
|
| 485 |
+
for header, row in zip(headers, self.array)
|
| 486 |
+
]
|
| 487 |
+
return MSA(entries)
|
| 488 |
+
|
| 489 |
+
@classmethod
|
| 490 |
+
def stack(
|
| 491 |
+
cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True
|
| 492 |
+
) -> FastMSA:
|
| 493 |
+
"""Stack a series of MSAs. Optionally remove the query from msas after the first."""
|
| 494 |
+
arrays = []
|
| 495 |
+
all_headers = []
|
| 496 |
+
for i, msa in enumerate(msas):
|
| 497 |
+
array = msa.array
|
| 498 |
+
headers = msa.headers
|
| 499 |
+
if i > 0 and remove_query_from_later_msas:
|
| 500 |
+
array = array[1:]
|
| 501 |
+
if headers is not None:
|
| 502 |
+
headers = headers[1:]
|
| 503 |
+
arrays.append(array)
|
| 504 |
+
if headers is not None:
|
| 505 |
+
all_headers.extend(headers)
|
| 506 |
+
return cls(np.concatenate(arrays, axis=0), all_headers)
|
| 507 |
+
|
esmfold2_msa_filter_sequences.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.spatial.distance import cdist
|
| 7 |
+
|
| 8 |
+
from .esmfold2_system import run_subprocess_with_errorcheck
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
|
| 12 |
+
"""Greedily select sequences that either maximize or minimize hamming distance.
|
| 13 |
+
|
| 14 |
+
Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
|
| 15 |
+
iteratively add sequences to the list with the maximum (minimum) average Hamming
|
| 16 |
+
distance to the existing set of sequences.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
array (np.ndarray): Character array representing the sequences in the MSA
|
| 20 |
+
num_seqs (int): Number of sequences to select.
|
| 21 |
+
mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
|
| 22 |
+
you're doing it to prove a point for a paper.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
list[int]: List of indices to select from the array
|
| 26 |
+
"""
|
| 27 |
+
assert mode in ("max", "min")
|
| 28 |
+
depth = array.shape[0]
|
| 29 |
+
if depth <= num_seqs:
|
| 30 |
+
return list(range(depth))
|
| 31 |
+
array = array.view(np.uint8)
|
| 32 |
+
|
| 33 |
+
optfunc = np.argmax if mode == "max" else np.argmin
|
| 34 |
+
all_indices = np.arange(depth)
|
| 35 |
+
indices = [0]
|
| 36 |
+
pairwise_distances = np.zeros((0, depth))
|
| 37 |
+
for _ in range(num_seqs - 1):
|
| 38 |
+
dist = cdist(array[indices[-1:]], array, "hamming")
|
| 39 |
+
pairwise_distances = np.concatenate([pairwise_distances, dist])
|
| 40 |
+
shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
|
| 41 |
+
shifted_index = optfunc(shifted_distance)
|
| 42 |
+
index = np.delete(all_indices, indices)[shifted_index]
|
| 43 |
+
indices.append(index)
|
| 44 |
+
indices = sorted(indices)
|
| 45 |
+
return indices
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def hhfilter(
|
| 49 |
+
sequences: list[str],
|
| 50 |
+
seqid: int = 90,
|
| 51 |
+
diff: int = 0,
|
| 52 |
+
cov: int = 0,
|
| 53 |
+
qid: int = 0,
|
| 54 |
+
qsc: float = -20.0,
|
| 55 |
+
binary: str = "hhfilter",
|
| 56 |
+
) -> list[int]:
|
| 57 |
+
with tempfile.TemporaryDirectory(
|
| 58 |
+
dir="/dev/shm" if os.path.exists("/dev/shm") else None
|
| 59 |
+
) as tempdirname:
|
| 60 |
+
tempdir = Path(tempdirname)
|
| 61 |
+
fasta_file = tempdir / "input.fasta"
|
| 62 |
+
fasta_file.write_text(
|
| 63 |
+
"\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
|
| 64 |
+
)
|
| 65 |
+
output_file = tempdir / "output.fasta"
|
| 66 |
+
command = " ".join(
|
| 67 |
+
[
|
| 68 |
+
f"{binary}",
|
| 69 |
+
f"-i {fasta_file}",
|
| 70 |
+
"-M a3m",
|
| 71 |
+
f"-o {output_file}",
|
| 72 |
+
f"-id {seqid}",
|
| 73 |
+
f"-diff {diff}",
|
| 74 |
+
f"-cov {cov}",
|
| 75 |
+
f"-qid {qid}",
|
| 76 |
+
f"-qsc {qsc}",
|
| 77 |
+
]
|
| 78 |
+
).split(" ")
|
| 79 |
+
run_subprocess_with_errorcheck(command, capture_output=True)
|
| 80 |
+
with output_file.open() as f:
|
| 81 |
+
indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
|
| 82 |
+
return indices
|
| 83 |
+
|
esmfold2_normalize_coordinates.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TypeVar
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
from . import esmfold2_residue_constants as RC
|
| 8 |
+
from .esmfold2_affine3d import Affine3D
|
| 9 |
+
|
| 10 |
+
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
|
| 14 |
+
N, CA, C = bb_positions.unbind(dim=-2)
|
| 15 |
+
return Affine3D.from_graham_schmidt(C, CA, N)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def index_by_atom_name(
|
| 19 |
+
atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
|
| 20 |
+
) -> ArrayOrTensor:
|
| 21 |
+
squeeze = False
|
| 22 |
+
if isinstance(atom_names, str):
|
| 23 |
+
atom_names = [atom_names]
|
| 24 |
+
squeeze = True
|
| 25 |
+
indices = [RC.atom_order[atom_name] for atom_name in atom_names]
|
| 26 |
+
dim = dim % atom37.ndim
|
| 27 |
+
index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
|
| 28 |
+
result = atom37[index] # type: ignore
|
| 29 |
+
if squeeze:
|
| 30 |
+
result = result.squeeze(dim)
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
|
| 35 |
+
"""Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates.
|
| 36 |
+
Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame
|
| 37 |
+
using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
Affine3D: tensor of Affine3D frame
|
| 44 |
+
"""
|
| 45 |
+
bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
|
| 46 |
+
coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1)
|
| 47 |
+
|
| 48 |
+
average_position_per_n_ca_c = bb_coords.masked_fill(
|
| 49 |
+
~coord_mask[..., None, None], 0
|
| 50 |
+
).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8)
|
| 51 |
+
frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float())
|
| 52 |
+
|
| 53 |
+
return frame
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
|
| 57 |
+
"""Given a set of coordinates and a single frame, apply the frame to the coordinates.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
|
| 61 |
+
frame (Affine3D): Affine3D frame
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates
|
| 65 |
+
"""
|
| 66 |
+
coords_trans_rot = frame[..., None, None].invert().apply(coords)
|
| 67 |
+
|
| 68 |
+
# only transform coordinates with frame that have a valid rotation
|
| 69 |
+
valid_frame = frame.trans.norm(dim=-1) > 0
|
| 70 |
+
|
| 71 |
+
is_inf = torch.isinf(coords)
|
| 72 |
+
coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords)
|
| 73 |
+
coords.masked_fill_(is_inf, torch.inf)
|
| 74 |
+
|
| 75 |
+
return coords
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def normalize_coordinates(coords: Tensor) -> Tensor:
|
| 79 |
+
return apply_frame_to_coords(coords, get_protein_normalization_frame(coords))
|
| 80 |
+
|
esmfold2_output.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from itertools import groupby
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL, MOL_TYPE_NONPOLYMER
|
| 8 |
+
from .esmfold2_molecular_complex import (
|
| 9 |
+
MolecularComplex,
|
| 10 |
+
MolecularComplexMetadata,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_element_symbol(atomic_num: int) -> str:
|
| 15 |
+
return ELEMENT_NUMBER_TO_SYMBOL.get(atomic_num, "X")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def build_molecular_complex_from_features(
|
| 19 |
+
coords: torch.Tensor,
|
| 20 |
+
plddt: torch.Tensor,
|
| 21 |
+
atom_mask: torch.Tensor,
|
| 22 |
+
ref_element: torch.Tensor,
|
| 23 |
+
ref_atom_name_chars: torch.Tensor,
|
| 24 |
+
chain_infos: list,
|
| 25 |
+
complex_id: str,
|
| 26 |
+
) -> MolecularComplex:
|
| 27 |
+
"""Construct a MolecularComplex from feature-dict tensors and chain metadata.
|
| 28 |
+
|
| 29 |
+
Non-polymer chains (ligands) collapse all per-atom tokens into a single
|
| 30 |
+
residue token whose pLDDT is the per-token average and whose hetero flag
|
| 31 |
+
is True.
|
| 32 |
+
"""
|
| 33 |
+
mask_np = atom_mask.bool().cpu().numpy()
|
| 34 |
+
coords_np = coords.float().cpu().numpy()
|
| 35 |
+
name_chars_np = ref_atom_name_chars.cpu().numpy()
|
| 36 |
+
elements_np = ref_element.cpu().numpy()
|
| 37 |
+
plddt_np = plddt.float().cpu().numpy()
|
| 38 |
+
|
| 39 |
+
sequence_tokens: list[str] = []
|
| 40 |
+
chain_ids_per_token: list[int] = []
|
| 41 |
+
token_to_atoms: list[list[int]] = []
|
| 42 |
+
confidence: list[float] = []
|
| 43 |
+
flat_positions: list[list[float]] = []
|
| 44 |
+
flat_elements: list[str] = []
|
| 45 |
+
flat_names: list[str] = []
|
| 46 |
+
flat_hetero: list[bool] = []
|
| 47 |
+
|
| 48 |
+
chain_lookup: dict[int, str] = {}
|
| 49 |
+
entity_info: dict[int, str] = {}
|
| 50 |
+
out_atom_cursor = 0
|
| 51 |
+
|
| 52 |
+
for ci in chain_infos:
|
| 53 |
+
chain_lookup[ci.asym_id] = ci.chain_id
|
| 54 |
+
is_nonpolymer = ci.mol_type == MOL_TYPE_NONPOLYMER
|
| 55 |
+
entity_info[ci.entity_id] = "non-polymer" if is_nonpolymer else "polymer"
|
| 56 |
+
|
| 57 |
+
if is_nonpolymer:
|
| 58 |
+
residue_name = ci.tokens[0].residue_name if ci.tokens else "LIG"
|
| 59 |
+
sequence_tokens.append(residue_name)
|
| 60 |
+
chain_ids_per_token.append(ci.asym_id)
|
| 61 |
+
avg_plddt = (
|
| 62 |
+
float(np.mean([plddt_np[ti.token_index] for ti in ci.tokens]))
|
| 63 |
+
if ci.tokens
|
| 64 |
+
else 0.0
|
| 65 |
+
)
|
| 66 |
+
confidence.append(avg_plddt)
|
| 67 |
+
token_atom_start = out_atom_cursor
|
| 68 |
+
for ti in ci.tokens:
|
| 69 |
+
for atom_idx in range(ti.atom_start, ti.atom_start + ti.atom_count):
|
| 70 |
+
if not mask_np[atom_idx]:
|
| 71 |
+
continue
|
| 72 |
+
flat_positions.append(coords_np[atom_idx].tolist())
|
| 73 |
+
flat_elements.append(get_element_symbol(int(elements_np[atom_idx])))
|
| 74 |
+
chars = name_chars_np[atom_idx]
|
| 75 |
+
name = "".join(
|
| 76 |
+
chr(int(c) + 32) for c in chars if int(c) != 0
|
| 77 |
+
).strip()
|
| 78 |
+
flat_names.append(name)
|
| 79 |
+
flat_hetero.append(True)
|
| 80 |
+
out_atom_cursor += 1
|
| 81 |
+
token_to_atoms.append([token_atom_start, out_atom_cursor])
|
| 82 |
+
continue
|
| 83 |
+
|
| 84 |
+
# Atom-tokenized modified residues (HYP, MSE, ...) span multiple
|
| 85 |
+
# tokens per residue; collapse them back to one mmCIF residue.
|
| 86 |
+
for _residue_index, ti_iter in groupby(
|
| 87 |
+
ci.tokens, key=lambda t: t.residue_index
|
| 88 |
+
):
|
| 89 |
+
ti_group = list(ti_iter)
|
| 90 |
+
sequence_tokens.append(ti_group[0].residue_name)
|
| 91 |
+
chain_ids_per_token.append(ci.asym_id)
|
| 92 |
+
confidence.append(
|
| 93 |
+
float(np.mean([plddt_np[ti.token_index] for ti in ti_group]))
|
| 94 |
+
)
|
| 95 |
+
token_atom_start = out_atom_cursor
|
| 96 |
+
for ti in ti_group:
|
| 97 |
+
for atom_idx in range(ti.atom_start, ti.atom_start + ti.atom_count):
|
| 98 |
+
if not mask_np[atom_idx]:
|
| 99 |
+
continue
|
| 100 |
+
flat_positions.append(coords_np[atom_idx].tolist())
|
| 101 |
+
flat_elements.append(get_element_symbol(int(elements_np[atom_idx])))
|
| 102 |
+
chars = name_chars_np[atom_idx]
|
| 103 |
+
name = "".join(
|
| 104 |
+
chr(int(c) + 32) for c in chars if int(c) != 0
|
| 105 |
+
).strip()
|
| 106 |
+
flat_names.append(name)
|
| 107 |
+
flat_hetero.append(False)
|
| 108 |
+
out_atom_cursor += 1
|
| 109 |
+
token_to_atoms.append([token_atom_start, out_atom_cursor])
|
| 110 |
+
|
| 111 |
+
return MolecularComplex(
|
| 112 |
+
id=complex_id,
|
| 113 |
+
sequence=sequence_tokens,
|
| 114 |
+
atom_positions=np.array(flat_positions, dtype=np.float32).reshape(-1, 3),
|
| 115 |
+
atom_elements=np.array(flat_elements, dtype=object),
|
| 116 |
+
token_to_atoms=np.array(token_to_atoms, dtype=np.int32).reshape(-1, 2),
|
| 117 |
+
chain_id=np.array(chain_ids_per_token, dtype=np.int64),
|
| 118 |
+
plddt=np.array(confidence, dtype=np.float32),
|
| 119 |
+
atom_names=np.array(flat_names, dtype=object),
|
| 120 |
+
atom_hetero=np.array(flat_hetero, dtype=bool),
|
| 121 |
+
metadata=MolecularComplexMetadata(
|
| 122 |
+
entity_lookup=entity_info,
|
| 123 |
+
chain_lookup=chain_lookup,
|
| 124 |
+
assembly_composition=None,
|
| 125 |
+
),
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def build_molecular_complex(
|
| 130 |
+
structure: Any, coords: torch.Tensor, plddt: torch.Tensor, complex_id: str
|
| 131 |
+
) -> MolecularComplex:
|
| 132 |
+
"""Directly constructs a MolecularComplex from model outputs without intermediate files.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
structure: Object with .chains, .residues, .atoms numpy structured arrays.
|
| 136 |
+
coords: [N_atoms, 3] predicted atom coordinates.
|
| 137 |
+
plddt: [N_residues] per-residue confidence scores.
|
| 138 |
+
complex_id: Identifier string for the resulting complex.
|
| 139 |
+
"""
|
| 140 |
+
flat_positions = []
|
| 141 |
+
flat_elements = []
|
| 142 |
+
flat_names = []
|
| 143 |
+
flat_hetero = []
|
| 144 |
+
|
| 145 |
+
sequence_tokens = []
|
| 146 |
+
token_to_atoms = []
|
| 147 |
+
chain_ids_per_token = []
|
| 148 |
+
confidence_scores = []
|
| 149 |
+
|
| 150 |
+
chain_lookup = {}
|
| 151 |
+
entity_info = {}
|
| 152 |
+
|
| 153 |
+
global_atom_cursor = 0
|
| 154 |
+
global_res_cursor = 0
|
| 155 |
+
atom_array_idx = 0
|
| 156 |
+
|
| 157 |
+
for chain in structure.chains:
|
| 158 |
+
chain_idx_numeric = chain["asym_id"]
|
| 159 |
+
chain_name_str = str(chain["name"])
|
| 160 |
+
mol_type = chain["mol_type"]
|
| 161 |
+
|
| 162 |
+
chain_lookup[chain_idx_numeric] = chain_name_str
|
| 163 |
+
entity_info[chain["entity_id"]] = (
|
| 164 |
+
"polymer" if mol_type != MOL_TYPE_NONPOLYMER else "non-polymer"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
res_start = chain["res_idx"]
|
| 168 |
+
res_end = chain["res_idx"] + chain["res_num"]
|
| 169 |
+
residues = structure.residues[res_start:res_end]
|
| 170 |
+
|
| 171 |
+
for residue in residues:
|
| 172 |
+
res_name = str(residue["name"])
|
| 173 |
+
|
| 174 |
+
sequence_tokens.append(res_name)
|
| 175 |
+
chain_ids_per_token.append(chain_idx_numeric)
|
| 176 |
+
|
| 177 |
+
score = plddt[global_res_cursor].item()
|
| 178 |
+
confidence_scores.append(score)
|
| 179 |
+
token_start_idx = atom_array_idx
|
| 180 |
+
|
| 181 |
+
atom_start = residue["atom_idx"]
|
| 182 |
+
atom_end = residue["atom_idx"] + residue["atom_num"]
|
| 183 |
+
atoms = structure.atoms[atom_start:atom_end]
|
| 184 |
+
|
| 185 |
+
for atom in atoms:
|
| 186 |
+
if not atom["is_present"]:
|
| 187 |
+
continue
|
| 188 |
+
|
| 189 |
+
pos = coords[global_atom_cursor].tolist()
|
| 190 |
+
flat_positions.append(pos)
|
| 191 |
+
|
| 192 |
+
elem = get_element_symbol(atom["element"].item())
|
| 193 |
+
flat_elements.append(elem)
|
| 194 |
+
|
| 195 |
+
raw_name = atom["name"]
|
| 196 |
+
if hasattr(raw_name, "tolist"):
|
| 197 |
+
raw_name = raw_name.tolist()
|
| 198 |
+
name_str = "".join([chr(c + 32) for c in raw_name if c != 0])
|
| 199 |
+
flat_names.append(name_str)
|
| 200 |
+
|
| 201 |
+
flat_hetero.append(mol_type == MOL_TYPE_NONPOLYMER)
|
| 202 |
+
|
| 203 |
+
global_atom_cursor += 1
|
| 204 |
+
atom_array_idx += 1
|
| 205 |
+
|
| 206 |
+
token_to_atoms.append([token_start_idx, atom_array_idx])
|
| 207 |
+
global_res_cursor += 1
|
| 208 |
+
|
| 209 |
+
return MolecularComplex(
|
| 210 |
+
id=complex_id,
|
| 211 |
+
sequence=sequence_tokens,
|
| 212 |
+
atom_positions=np.array(flat_positions, dtype=np.float32),
|
| 213 |
+
atom_elements=np.array(flat_elements, dtype=object),
|
| 214 |
+
token_to_atoms=np.array(token_to_atoms, dtype=np.int32),
|
| 215 |
+
chain_id=np.array(chain_ids_per_token, dtype=np.int64),
|
| 216 |
+
plddt=np.array(confidence_scores, dtype=np.float32),
|
| 217 |
+
atom_names=np.array(flat_names, dtype=object),
|
| 218 |
+
atom_hetero=np.array(flat_hetero, dtype=bool),
|
| 219 |
+
metadata=MolecularComplexMetadata(
|
| 220 |
+
entity_lookup=entity_info,
|
| 221 |
+
chain_lookup=chain_lookup,
|
| 222 |
+
assembly_composition=None,
|
| 223 |
+
),
|
| 224 |
+
)
|
| 225 |
+
|
esmfold2_paired_msa.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Taxonomy-paired MSA construction for ESMFold2 inference.
|
| 2 |
+
|
| 3 |
+
Taxonomy IDs are read from FASTA headers as ``key=N`` tokens. Rows
|
| 4 |
+
where any chain has ``key=-1`` (or no ``key=`` at all) are treated as
|
| 5 |
+
unpaired and assigned to that chain's block-diagonal section after
|
| 6 |
+
the paired rows.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
from .esmfold2_constants import (
|
| 14 |
+
MSA_GAP_TOKEN_ID,
|
| 15 |
+
PROTEIN_3TO1,
|
| 16 |
+
PROTEIN_RESIDUE_TO_RES_TYPE,
|
| 17 |
+
PROTEIN_UNK_RES_TYPE,
|
| 18 |
+
)
|
| 19 |
+
from .esmfold2_msa import MSA
|
| 20 |
+
|
| 21 |
+
_KEY_RE = re.compile(r"key=(-?\d+)")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def protein_letter_to_res_type() -> dict[str, int]:
|
| 25 |
+
"""Return the protein 1-letter → res_type mapping used by the MSA encoder."""
|
| 26 |
+
mapping: dict[str, int] = {}
|
| 27 |
+
for three, one in PROTEIN_3TO1.items():
|
| 28 |
+
if three in PROTEIN_RESIDUE_TO_RES_TYPE:
|
| 29 |
+
mapping[one] = PROTEIN_RESIDUE_TO_RES_TYPE[three]
|
| 30 |
+
mapping["-"] = MSA_GAP_TOKEN_ID
|
| 31 |
+
mapping["X"] = PROTEIN_UNK_RES_TYPE
|
| 32 |
+
return mapping
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _taxonomy_from_header(header: str) -> int:
|
| 36 |
+
if not header:
|
| 37 |
+
return -1
|
| 38 |
+
m = _KEY_RE.search(header)
|
| 39 |
+
return int(m.group(1)) if m else -1
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def msa_to_res_type_and_deletions(
|
| 43 |
+
msa: MSA, letter_to_res_type: dict[str, int]
|
| 44 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 45 |
+
"""Convert an :class:`MSA` to ``(res_type[M, L], deletion_count[M, L])``.
|
| 46 |
+
|
| 47 |
+
Handles a3m insertion convention: lowercase letters and ``.`` are
|
| 48 |
+
insertions and are not emitted; their count is accumulated into the
|
| 49 |
+
next non-insertion position's deletion value. ``L`` is the query
|
| 50 |
+
length after stripping insertions from row 0.
|
| 51 |
+
"""
|
| 52 |
+
query = msa.entries[0].sequence
|
| 53 |
+
L = sum(1 for ch in query if not (ch.islower() or ch == "."))
|
| 54 |
+
M = msa.depth
|
| 55 |
+
|
| 56 |
+
res_type = np.full((M, L), MSA_GAP_TOKEN_ID, dtype=np.int64)
|
| 57 |
+
deletions = np.zeros((M, L), dtype=np.float32)
|
| 58 |
+
|
| 59 |
+
for r, entry in enumerate(msa.entries):
|
| 60 |
+
col = 0
|
| 61 |
+
ins = 0
|
| 62 |
+
for ch in entry.sequence:
|
| 63 |
+
if ch == "." or (ch.islower() and ch != "-"):
|
| 64 |
+
ins += 1
|
| 65 |
+
continue
|
| 66 |
+
if col >= L:
|
| 67 |
+
break
|
| 68 |
+
if ch == "-":
|
| 69 |
+
res_type[r, col] = MSA_GAP_TOKEN_ID
|
| 70 |
+
else:
|
| 71 |
+
res_type[r, col] = letter_to_res_type.get(
|
| 72 |
+
ch.upper(), PROTEIN_UNK_RES_TYPE
|
| 73 |
+
)
|
| 74 |
+
if ins > 0:
|
| 75 |
+
deletions[r, col] = float(ins)
|
| 76 |
+
ins = 0
|
| 77 |
+
col += 1
|
| 78 |
+
return res_type, deletions
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _dummy_msa_residues(query_res_types: np.ndarray) -> np.ndarray:
|
| 82 |
+
"""Single-row 'MSA' for chains without one — just the query."""
|
| 83 |
+
return query_res_types[None, :] # [1, L]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def construct_paired_msa(
|
| 87 |
+
chain_msas: dict[int, MSA | None],
|
| 88 |
+
chain_query_res_types: dict[int, np.ndarray],
|
| 89 |
+
token_asym_ids: np.ndarray,
|
| 90 |
+
token_res_ids: np.ndarray,
|
| 91 |
+
letter_to_res_type: dict[str, int] | None = None,
|
| 92 |
+
*,
|
| 93 |
+
max_pairs: int = 8192,
|
| 94 |
+
max_total: int = 16384,
|
| 95 |
+
max_seqs: int = 16384,
|
| 96 |
+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 97 |
+
"""Build paired MSA features.
|
| 98 |
+
|
| 99 |
+
Parameters
|
| 100 |
+
----------
|
| 101 |
+
chain_msas
|
| 102 |
+
``asym_id -> MSA`` (or ``None`` for chains without an MSA).
|
| 103 |
+
chain_query_res_types
|
| 104 |
+
``asym_id -> np.ndarray[L_c]`` of res-type ids for the chain's
|
| 105 |
+
query. Used to build dummy MSAs when a chain has no MSA.
|
| 106 |
+
token_asym_ids
|
| 107 |
+
Per-token asym_id, length ``T``. Must be non-decreasing.
|
| 108 |
+
token_res_ids
|
| 109 |
+
Per-token residue index within chain, length ``T``.
|
| 110 |
+
letter_to_res_type
|
| 111 |
+
1-letter → res-type mapping. Defaults to
|
| 112 |
+
:func:`protein_letter_to_res_type`.
|
| 113 |
+
|
| 114 |
+
Returns
|
| 115 |
+
-------
|
| 116 |
+
msa_residues : ``np.ndarray[M, T]`` int64
|
| 117 |
+
deletion_value : ``np.ndarray[M, T]`` float32 (raw deletion counts; the
|
| 118 |
+
``arctan(/3) * pi/2`` transform is applied by the caller)
|
| 119 |
+
is_paired : ``np.ndarray[M, T]`` float32 broadcast of per-row,
|
| 120 |
+
per-chain paired flags.
|
| 121 |
+
"""
|
| 122 |
+
if letter_to_res_type is None:
|
| 123 |
+
letter_to_res_type = protein_letter_to_res_type()
|
| 124 |
+
|
| 125 |
+
chain_ids: list[int] = sorted(chain_msas.keys())
|
| 126 |
+
|
| 127 |
+
# Build per-chain (res_type, deletions, taxonomy) tables.
|
| 128 |
+
chain_res_type: dict[int, np.ndarray] = {}
|
| 129 |
+
chain_deletions: dict[int, np.ndarray] = {}
|
| 130 |
+
chain_taxonomies: dict[int, list[int]] = {}
|
| 131 |
+
for c in chain_ids:
|
| 132 |
+
m = chain_msas.get(c)
|
| 133 |
+
if m is None or m.depth == 0:
|
| 134 |
+
qres = chain_query_res_types[c]
|
| 135 |
+
chain_res_type[c] = _dummy_msa_residues(qres)
|
| 136 |
+
chain_deletions[c] = np.zeros((1, qres.shape[0]), dtype=np.float32)
|
| 137 |
+
chain_taxonomies[c] = [-1]
|
| 138 |
+
continue
|
| 139 |
+
rt, dl = msa_to_res_type_and_deletions(m, letter_to_res_type)
|
| 140 |
+
chain_res_type[c] = rt
|
| 141 |
+
chain_deletions[c] = dl
|
| 142 |
+
chain_taxonomies[c] = [_taxonomy_from_header(e.header) for e in m.entries]
|
| 143 |
+
|
| 144 |
+
# Group by taxonomy, skip query row and unpaired (-1) entries.
|
| 145 |
+
taxonomy_map: dict[int, list[tuple[int, int]]] = {}
|
| 146 |
+
for c in chain_ids:
|
| 147 |
+
for seq_idx, taxon in enumerate(chain_taxonomies[c]):
|
| 148 |
+
if seq_idx == 0 or taxon == -1:
|
| 149 |
+
continue
|
| 150 |
+
taxonomy_map.setdefault(taxon, []).append((c, seq_idx))
|
| 151 |
+
taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
|
| 152 |
+
# Order taxonomies by number of distinct chains, descending.
|
| 153 |
+
sorted_taxa = sorted(
|
| 154 |
+
taxonomy_map.items(), key=lambda kv: len({c for c, _ in kv[1]}), reverse=True
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
visited = {s for _, items in taxonomy_map.items() for s in items}
|
| 158 |
+
available: dict[int, list[int]] = {
|
| 159 |
+
c: [i for i in range(1, len(chain_taxonomies[c])) if (c, i) not in visited]
|
| 160 |
+
for c in chain_ids
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
pairing: list[dict[int, int]] = [{c: 0 for c in chain_ids}]
|
| 164 |
+
is_paired: list[dict[int, int]] = [{c: 1 for c in chain_ids}]
|
| 165 |
+
|
| 166 |
+
for _, pairs in sorted_taxa:
|
| 167 |
+
per_chain: dict[int, list[int]] = {}
|
| 168 |
+
for c, seq_idx in pairs:
|
| 169 |
+
per_chain.setdefault(c, []).append(seq_idx)
|
| 170 |
+
max_occ = max(len(v) for v in per_chain.values())
|
| 171 |
+
for i in range(max_occ):
|
| 172 |
+
row_pairing: dict[int, int] = {}
|
| 173 |
+
row_is_paired: dict[int, int] = {}
|
| 174 |
+
for c, seq_idxs in per_chain.items():
|
| 175 |
+
row_pairing[c] = seq_idxs[i % len(seq_idxs)]
|
| 176 |
+
row_is_paired[c] = 1
|
| 177 |
+
for c in chain_ids:
|
| 178 |
+
if c in row_pairing:
|
| 179 |
+
continue
|
| 180 |
+
row_is_paired[c] = 0
|
| 181 |
+
if available[c]:
|
| 182 |
+
row_pairing[c] = available[c].pop(0)
|
| 183 |
+
else:
|
| 184 |
+
row_pairing[c] = -1
|
| 185 |
+
pairing.append(row_pairing)
|
| 186 |
+
is_paired.append(row_is_paired)
|
| 187 |
+
if len(pairing) >= max_pairs:
|
| 188 |
+
break
|
| 189 |
+
if len(pairing) >= max_pairs:
|
| 190 |
+
break
|
| 191 |
+
|
| 192 |
+
max_left = max((len(v) for v in available.values()), default=0)
|
| 193 |
+
for _ in range(min(max_total - len(pairing), max_left)):
|
| 194 |
+
row_pairing = {}
|
| 195 |
+
row_is_paired = {}
|
| 196 |
+
for c in chain_ids:
|
| 197 |
+
row_is_paired[c] = 0
|
| 198 |
+
if available[c]:
|
| 199 |
+
row_pairing[c] = available[c].pop(0)
|
| 200 |
+
else:
|
| 201 |
+
row_pairing[c] = -1
|
| 202 |
+
pairing.append(row_pairing)
|
| 203 |
+
is_paired.append(row_is_paired)
|
| 204 |
+
if len(pairing) >= max_total:
|
| 205 |
+
break
|
| 206 |
+
|
| 207 |
+
pairing = pairing[:max_seqs]
|
| 208 |
+
is_paired = is_paired[:max_seqs]
|
| 209 |
+
M = len(pairing)
|
| 210 |
+
T = len(token_asym_ids)
|
| 211 |
+
|
| 212 |
+
msa_residues = np.full((M, T), MSA_GAP_TOKEN_ID, dtype=np.int64)
|
| 213 |
+
deletion_value = np.zeros((M, T), dtype=np.float32)
|
| 214 |
+
paired_mask = np.zeros((M, T), dtype=np.float32)
|
| 215 |
+
|
| 216 |
+
# Vectorize per chain: gather chain rows according to pairing[c], then
|
| 217 |
+
# index into them by the chain's token residue ids.
|
| 218 |
+
for c in chain_ids:
|
| 219 |
+
rt = chain_res_type[c]
|
| 220 |
+
dl = chain_deletions[c]
|
| 221 |
+
Lc = rt.shape[1]
|
| 222 |
+
chain_pairing = np.array([row[c] for row in pairing], dtype=np.int64)
|
| 223 |
+
chain_paired = np.array([row[c] for row in is_paired], dtype=np.float32)
|
| 224 |
+
|
| 225 |
+
token_mask = token_asym_ids == c
|
| 226 |
+
if not token_mask.any():
|
| 227 |
+
continue
|
| 228 |
+
token_res_in_chain = token_res_ids[token_mask]
|
| 229 |
+
# Clamp residue indices to the MSA's column range. Modified-residue
|
| 230 |
+
# tokens that exceed the query length fall back to the last column.
|
| 231 |
+
cols = np.minimum(token_res_in_chain, Lc - 1)
|
| 232 |
+
|
| 233 |
+
# Rows where pairing == -1 fall back to gap (already initialized).
|
| 234 |
+
valid_rows = chain_pairing >= 0
|
| 235 |
+
if valid_rows.any():
|
| 236 |
+
gathered_rt = rt[chain_pairing[valid_rows]][:, cols]
|
| 237 |
+
gathered_dl = dl[chain_pairing[valid_rows]][:, cols]
|
| 238 |
+
valid_idx = np.where(valid_rows)[0]
|
| 239 |
+
token_idx = np.where(token_mask)[0]
|
| 240 |
+
msa_residues[np.ix_(valid_idx, token_idx)] = gathered_rt
|
| 241 |
+
deletion_value[np.ix_(valid_idx, token_idx)] = gathered_dl
|
| 242 |
+
|
| 243 |
+
paired_mask[:, token_mask] = chain_paired[:, None]
|
| 244 |
+
|
| 245 |
+
return msa_residues, deletion_value, paired_mask
|
| 246 |
+
|
esmfold2_parsing.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Generator, Iterable, NamedTuple
|
| 4 |
+
|
| 5 |
+
PathOrBuffer = str | Path | io.TextIOBase
|
| 6 |
+
FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
|
| 10 |
+
"""
|
| 11 |
+
Parses a fasta file and yields FastaEntry objects
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
fasta_string: The fasta file as a string
|
| 15 |
+
Returns:
|
| 16 |
+
A generator of FastaEntry objects
|
| 17 |
+
"""
|
| 18 |
+
header = None
|
| 19 |
+
seq = []
|
| 20 |
+
num_sequences = 0
|
| 21 |
+
for line in fasta_string.splitlines():
|
| 22 |
+
if not line or line[0] == "#":
|
| 23 |
+
continue
|
| 24 |
+
if line.startswith(">"):
|
| 25 |
+
if header is not None:
|
| 26 |
+
yield FastaEntry(header, "".join(seq))
|
| 27 |
+
seq = []
|
| 28 |
+
header = line[1:].strip()
|
| 29 |
+
else:
|
| 30 |
+
seq.append(line)
|
| 31 |
+
if header is not None:
|
| 32 |
+
num_sequences += 1
|
| 33 |
+
yield FastaEntry(header, "".join(seq))
|
| 34 |
+
|
| 35 |
+
if num_sequences == 0:
|
| 36 |
+
raise ValueError("Found no sequences in input")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
|
| 40 |
+
# Uses duck typing to try and call the right method
|
| 41 |
+
# Doesn't use explicit isinstance check to support
|
| 42 |
+
# inputs that are not explicitly str/Path/TextIOBase but
|
| 43 |
+
# may support similar functionality
|
| 44 |
+
data = None # type: ignore
|
| 45 |
+
try:
|
| 46 |
+
if str(path).endswith(".gz"):
|
| 47 |
+
import gzip
|
| 48 |
+
|
| 49 |
+
data = gzip.open(path, "rt") # type: ignore
|
| 50 |
+
else:
|
| 51 |
+
try:
|
| 52 |
+
data = open(path) # type: ignore
|
| 53 |
+
except TypeError:
|
| 54 |
+
data: io.TextIOBase = path # type: ignore
|
| 55 |
+
|
| 56 |
+
yield from parse_fasta(data.read())
|
| 57 |
+
finally:
|
| 58 |
+
if data is not None:
|
| 59 |
+
data.close()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
|
| 63 |
+
return next(iter(read_sequences(path)))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def count_fasta_sequences(path: str | Path) -> int:
|
| 67 |
+
"""Count sequences in a FASTA file by counting header lines.
|
| 68 |
+
|
| 69 |
+
Faster than parsing the full file — only scans for '>' prefixes.
|
| 70 |
+
Returns 0 if the file does not exist.
|
| 71 |
+
"""
|
| 72 |
+
path = Path(path)
|
| 73 |
+
if not path.exists():
|
| 74 |
+
return 0
|
| 75 |
+
with open(path) as f:
|
| 76 |
+
return sum(1 for line in f if line.startswith(">"))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def append_fasta_sequence(header: str, sequence: str, path: str | Path) -> None:
|
| 80 |
+
"""Append a single sequence to a FASTA file (creating it if needed)."""
|
| 81 |
+
path = Path(path)
|
| 82 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 83 |
+
# The existing file may not end with a newline (e.g., write_sequences()
|
| 84 |
+
# explicitly avoids writing a newline at the end), so we insert one before
|
| 85 |
+
# appending to avoid merging with the last line.
|
| 86 |
+
needs_newline = (
|
| 87 |
+
path.exists() and path.stat().st_size > 0 and path.read_bytes()[-1:] != b"\n"
|
| 88 |
+
)
|
| 89 |
+
with open(path, "a") as f:
|
| 90 |
+
if needs_newline:
|
| 91 |
+
f.write("\n")
|
| 92 |
+
f.write(f">{header}\n{sequence}\n")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
|
| 96 |
+
needs_closing = False
|
| 97 |
+
handle = None
|
| 98 |
+
try:
|
| 99 |
+
try:
|
| 100 |
+
handle = open(path, "w") # type: ignore
|
| 101 |
+
needs_closing = True
|
| 102 |
+
except TypeError:
|
| 103 |
+
handle = path
|
| 104 |
+
has_prev = False
|
| 105 |
+
for header, seq in sequences:
|
| 106 |
+
if has_prev:
|
| 107 |
+
handle.write("\n") # type: ignore
|
| 108 |
+
handle.write(f">{header}\n{seq}") # type: ignore
|
| 109 |
+
has_prev = True
|
| 110 |
+
finally:
|
| 111 |
+
if needs_closing:
|
| 112 |
+
handle.close() # type: ignore
|
| 113 |
+
|
esmfold2_predicted_aligned_error.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from .esmfold2_affine3d import Affine3D
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def masked_mean(
|
| 8 |
+
mask: torch.Tensor,
|
| 9 |
+
value: torch.Tensor,
|
| 10 |
+
dim: int | None | tuple[int, ...] = None,
|
| 11 |
+
eps=1e-10,
|
| 12 |
+
) -> torch.Tensor:
|
| 13 |
+
"""Compute the mean of `value` where only positions where `mask == true` are
|
| 14 |
+
counted.
|
| 15 |
+
"""
|
| 16 |
+
mask = mask.expand(*value.shape)
|
| 17 |
+
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _pae_bins(
|
| 21 |
+
max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu")
|
| 22 |
+
):
|
| 23 |
+
bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device)
|
| 24 |
+
step = max_bin / (num_bins - 2)
|
| 25 |
+
bin_centers = bins + step / 2
|
| 26 |
+
bin_centers = torch.cat(
|
| 27 |
+
[bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
|
| 28 |
+
)
|
| 29 |
+
return bin_centers
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _compute_pae_masks(mask: torch.Tensor):
|
| 33 |
+
square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool()
|
| 34 |
+
return square_mask
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def compute_predicted_aligned_error(
|
| 38 |
+
logits: torch.Tensor,
|
| 39 |
+
aa_mask: torch.Tensor,
|
| 40 |
+
sequence_id: torch.Tensor | None = None,
|
| 41 |
+
max_bin: float = 31,
|
| 42 |
+
) -> torch.Tensor:
|
| 43 |
+
bins = _pae_bins(max_bin, logits.shape[-1], logits.device)
|
| 44 |
+
square_mask = _compute_pae_masks(aa_mask)
|
| 45 |
+
min_v = torch.finfo(logits.dtype).min
|
| 46 |
+
probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1)
|
| 47 |
+
|
| 48 |
+
return (probs * bins).sum(dim=-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.no_grad
|
| 52 |
+
def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0):
|
| 53 |
+
square_mask = _compute_pae_masks(aa_mask)
|
| 54 |
+
seqlens = aa_mask.sum(-1, keepdim=True)
|
| 55 |
+
bins = _pae_bins(max_bin, logits.shape[-1], logits.device)
|
| 56 |
+
d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8
|
| 57 |
+
f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2)
|
| 58 |
+
|
| 59 |
+
min_v = torch.finfo(logits.dtype).min
|
| 60 |
+
probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1)
|
| 61 |
+
# This is the sum over bins
|
| 62 |
+
ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1)
|
| 63 |
+
# This is the mean over residues j
|
| 64 |
+
ptm = masked_mean(square_mask, ptm, dim=-1)
|
| 65 |
+
# The we do a max over residues i
|
| 66 |
+
return ptm.max(dim=-1).values
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def tm_loss(
|
| 70 |
+
logits: torch.Tensor,
|
| 71 |
+
pred_affine: torch.Tensor,
|
| 72 |
+
targ_affine: torch.Tensor,
|
| 73 |
+
targ_mask: torch.Tensor,
|
| 74 |
+
tm_mask: torch.Tensor | None = None,
|
| 75 |
+
sequence_id: torch.Tensor | None = None,
|
| 76 |
+
max_bin: float = 31,
|
| 77 |
+
):
|
| 78 |
+
pred = Affine3D.from_tensor(pred_affine)
|
| 79 |
+
targ = Affine3D.from_tensor(targ_affine)
|
| 80 |
+
|
| 81 |
+
def transform(affine: Affine3D):
|
| 82 |
+
pts = affine.trans[..., None, :, :]
|
| 83 |
+
return affine.invert()[..., None].apply(pts)
|
| 84 |
+
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1)
|
| 87 |
+
|
| 88 |
+
num_bins = logits.shape[-1]
|
| 89 |
+
sq_bins = torch.linspace(
|
| 90 |
+
0, max_bin, num_bins - 1, device=logits.device
|
| 91 |
+
).square()
|
| 92 |
+
# Gets the bin id by using a sum.
|
| 93 |
+
true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long()
|
| 94 |
+
|
| 95 |
+
errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none")
|
| 96 |
+
square_mask = _compute_pae_masks(targ_mask)
|
| 97 |
+
loss = masked_mean(square_mask, errors, dim=(-1, -2))
|
| 98 |
+
|
| 99 |
+
if tm_mask is not None:
|
| 100 |
+
loss = masked_mean(tm_mask, loss, dim=None)
|
| 101 |
+
else:
|
| 102 |
+
loss = loss.mean()
|
| 103 |
+
|
| 104 |
+
return loss
|
| 105 |
+
|
esmfold2_prepare_input.py
ADDED
|
@@ -0,0 +1,1464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Prepare ESMFold2 model inputs from sequence-level StructurePredictionInput.
|
| 2 |
+
|
| 3 |
+
This module converts StructurePredictionInput (protein/DNA/RNA/ligand sequences)
|
| 4 |
+
into the tensor dict expected by the ESMFold2 model forward pass.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import warnings
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
from dataclasses import dataclass, field
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from .esmfold2_conformers import (
|
| 18 |
+
get_ccd_leaving_atoms,
|
| 19 |
+
get_idealized_atom_pos,
|
| 20 |
+
get_ligand_ccd_atoms_with_charges,
|
| 21 |
+
get_ligand_ccd_bonds,
|
| 22 |
+
get_ligand_idealized_atom_pos,
|
| 23 |
+
)
|
| 24 |
+
from .esmfold2_constants import (
|
| 25 |
+
CHARGED_ATOMS,
|
| 26 |
+
DNA_1TO3,
|
| 27 |
+
DNA_BACKBONE_ATOMS,
|
| 28 |
+
DNA_HEAVY_ATOMS,
|
| 29 |
+
DNA_RESIDUE_TO_RES_TYPE,
|
| 30 |
+
DNA_RNA_LIGAND_INPUT_ID,
|
| 31 |
+
DNA_UNK_RES_TYPE,
|
| 32 |
+
ELEMENT_TO_ATOMIC_NUM,
|
| 33 |
+
ESM_PROTEIN_VOCAB,
|
| 34 |
+
MOL_TYPE_DNA,
|
| 35 |
+
MOL_TYPE_NONPOLYMER,
|
| 36 |
+
MOL_TYPE_PROTEIN,
|
| 37 |
+
MOL_TYPE_RNA,
|
| 38 |
+
MSA_GAP_TOKEN_ID,
|
| 39 |
+
PROTEIN_1TO3,
|
| 40 |
+
PROTEIN_3TO1,
|
| 41 |
+
PROTEIN_HEAVY_ATOMS,
|
| 42 |
+
PROTEIN_RESIDUE_TO_RES_TYPE,
|
| 43 |
+
PROTEIN_UNK_RES_TYPE,
|
| 44 |
+
RNA_1TO3,
|
| 45 |
+
RNA_BACKBONE_ATOMS,
|
| 46 |
+
RNA_HEAVY_ATOMS,
|
| 47 |
+
RNA_RESIDUE_TO_RES_TYPE,
|
| 48 |
+
RNA_UNK_RES_TYPE,
|
| 49 |
+
)
|
| 50 |
+
from .esmfold2_types import (
|
| 51 |
+
MSA,
|
| 52 |
+
DNAInput,
|
| 53 |
+
LigandInput,
|
| 54 |
+
Modification,
|
| 55 |
+
ProteinInput,
|
| 56 |
+
RNAInput,
|
| 57 |
+
StructurePredictionInput,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# =============================================================================
|
| 61 |
+
# Lightweight data model
|
| 62 |
+
# =============================================================================
|
| 63 |
+
|
| 64 |
+
_ZERO_POS = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass
|
| 68 |
+
class AtomInfo:
|
| 69 |
+
name: str
|
| 70 |
+
element: str
|
| 71 |
+
charge: int
|
| 72 |
+
ref_pos: np.ndarray # Idealized position from CCD [3]
|
| 73 |
+
pos: np.ndarray # Experimental position [3] (zeros for inference)
|
| 74 |
+
token_index: int = -1
|
| 75 |
+
atom_index: int = -1
|
| 76 |
+
space_uid: int = -1
|
| 77 |
+
is_valid: bool = True
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@dataclass
|
| 81 |
+
class TokenInfo:
|
| 82 |
+
token_index: int
|
| 83 |
+
residue_index: int # Within chain (0-based)
|
| 84 |
+
residue_name: str # 3-letter code
|
| 85 |
+
mol_type: int # 0=protein, 1=DNA, 2=RNA, 3=nonpolymer
|
| 86 |
+
res_type: int # Residue type index (2-32)
|
| 87 |
+
input_id: int # ESM vocab ID
|
| 88 |
+
asym_id: int
|
| 89 |
+
sym_id: int
|
| 90 |
+
entity_id: int
|
| 91 |
+
atom_start: int # Index into atoms list
|
| 92 |
+
atom_count: int
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
@dataclass
|
| 96 |
+
class ChainInfo:
|
| 97 |
+
chain_id: str
|
| 98 |
+
asym_id: int
|
| 99 |
+
entity_id: int
|
| 100 |
+
sym_id: int
|
| 101 |
+
mol_type: int
|
| 102 |
+
tokens: list[TokenInfo] = field(default_factory=list)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# =============================================================================
|
| 106 |
+
# Helper functions
|
| 107 |
+
# =============================================================================
|
| 108 |
+
|
| 109 |
+
# Caches for hot-path functions
|
| 110 |
+
_ENCODE_ATOM_NAME_CACHE: dict[str, list[int]] = {}
|
| 111 |
+
_ELEMENT_ATOMIC_NUM_CACHE: dict[str, int] = {}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def encode_atom_name(name: str) -> list[int]:
|
| 115 |
+
"""Encode atom name as 4 character indices (offset by 32 from ASCII)."""
|
| 116 |
+
if name in _ENCODE_ATOM_NAME_CACHE:
|
| 117 |
+
return _ENCODE_ATOM_NAME_CACHE[name]
|
| 118 |
+
padded = name.ljust(4)[:4]
|
| 119 |
+
result = [ord(c) - 32 if c != " " else 0 for c in padded]
|
| 120 |
+
_ENCODE_ATOM_NAME_CACHE[name] = result
|
| 121 |
+
return result
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_element_atomic_num(element: str) -> int:
|
| 125 |
+
"""Get atomic number for an element symbol."""
|
| 126 |
+
if element in _ELEMENT_ATOMIC_NUM_CACHE:
|
| 127 |
+
return _ELEMENT_ATOMIC_NUM_CACHE[element]
|
| 128 |
+
result = ELEMENT_TO_ATOMIC_NUM.get(element.upper(), 0)
|
| 129 |
+
_ELEMENT_ATOMIC_NUM_CACHE[element] = result
|
| 130 |
+
return result
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _infer_element(atom_name: str) -> str:
|
| 134 |
+
"""Infer element from atom name."""
|
| 135 |
+
name = atom_name.strip()
|
| 136 |
+
if not name:
|
| 137 |
+
return "C"
|
| 138 |
+
if name[0].isdigit():
|
| 139 |
+
return name[1] if len(name) > 1 else "H"
|
| 140 |
+
if len(name) == 2 and name in (
|
| 141 |
+
"FE",
|
| 142 |
+
"ZN",
|
| 143 |
+
"MG",
|
| 144 |
+
"MN",
|
| 145 |
+
"CO",
|
| 146 |
+
"NI",
|
| 147 |
+
"CU",
|
| 148 |
+
"SE",
|
| 149 |
+
"BR",
|
| 150 |
+
):
|
| 151 |
+
return name
|
| 152 |
+
return name[0]
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _compute_res_type(name: str, mol_type: int) -> int:
|
| 156 |
+
"""Compute residue type index from residue name and mol_type."""
|
| 157 |
+
if mol_type == MOL_TYPE_PROTEIN:
|
| 158 |
+
return PROTEIN_RESIDUE_TO_RES_TYPE.get(name, PROTEIN_UNK_RES_TYPE)
|
| 159 |
+
elif mol_type == MOL_TYPE_DNA:
|
| 160 |
+
if name in DNA_RESIDUE_TO_RES_TYPE:
|
| 161 |
+
return DNA_RESIDUE_TO_RES_TYPE[name]
|
| 162 |
+
if name in RNA_RESIDUE_TO_RES_TYPE:
|
| 163 |
+
return RNA_RESIDUE_TO_RES_TYPE[name]
|
| 164 |
+
return DNA_UNK_RES_TYPE
|
| 165 |
+
elif mol_type == MOL_TYPE_RNA:
|
| 166 |
+
if name in RNA_RESIDUE_TO_RES_TYPE:
|
| 167 |
+
return RNA_RESIDUE_TO_RES_TYPE[name]
|
| 168 |
+
if name in DNA_RESIDUE_TO_RES_TYPE:
|
| 169 |
+
return DNA_RESIDUE_TO_RES_TYPE[name]
|
| 170 |
+
return RNA_UNK_RES_TYPE
|
| 171 |
+
return PROTEIN_UNK_RES_TYPE
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _compute_esm_input_id(name: str, mol_type: int) -> int:
|
| 175 |
+
"""Compute ESM vocabulary input ID."""
|
| 176 |
+
if mol_type == MOL_TYPE_PROTEIN:
|
| 177 |
+
letter = PROTEIN_3TO1.get(name)
|
| 178 |
+
if letter is None:
|
| 179 |
+
return DNA_RNA_LIGAND_INPUT_ID
|
| 180 |
+
return ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"])
|
| 181 |
+
return DNA_RNA_LIGAND_INPUT_ID
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# =============================================================================
|
| 185 |
+
# Tokenization functions — build tokens and atoms from sequences
|
| 186 |
+
# =============================================================================
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def tokenize_protein(
|
| 190 |
+
sequence: str,
|
| 191 |
+
modifications: list[Modification] | None,
|
| 192 |
+
entity_id: int,
|
| 193 |
+
asym_id: int,
|
| 194 |
+
sym_id: int,
|
| 195 |
+
token_offset: int,
|
| 196 |
+
atom_offset: int,
|
| 197 |
+
space_uid_offset: int,
|
| 198 |
+
) -> tuple[list[TokenInfo], list[AtomInfo]]:
|
| 199 |
+
"""Tokenize a protein sequence into tokens and atoms.
|
| 200 |
+
|
| 201 |
+
Standard residues produce 1 token with all heavy atoms.
|
| 202 |
+
Modified residues (from modifications) are atom-tokenized (1 token per atom).
|
| 203 |
+
"""
|
| 204 |
+
tokens: list[TokenInfo] = []
|
| 205 |
+
atoms: list[AtomInfo] = []
|
| 206 |
+
|
| 207 |
+
# Build 3-letter sequence, applying modifications
|
| 208 |
+
seq_3letter = [PROTEIN_1TO3.get(c, "UNK") for c in sequence]
|
| 209 |
+
modified_positions: set[int] = set()
|
| 210 |
+
if modifications:
|
| 211 |
+
for mod in modifications:
|
| 212 |
+
seq_3letter[mod.position] = mod.ccd
|
| 213 |
+
modified_positions.add(mod.position)
|
| 214 |
+
|
| 215 |
+
token_idx = token_offset
|
| 216 |
+
atom_idx = atom_offset
|
| 217 |
+
space_uid = space_uid_offset
|
| 218 |
+
|
| 219 |
+
for res_idx, res_name in enumerate(seq_3letter):
|
| 220 |
+
# MSE → MET for atom lookup
|
| 221 |
+
res_corrected = "MET" if res_name == "MSE" else res_name
|
| 222 |
+
is_modified = res_idx in modified_positions
|
| 223 |
+
|
| 224 |
+
# Check if standard residue (has predefined atom list)
|
| 225 |
+
if not is_modified and res_corrected in PROTEIN_HEAVY_ATOMS:
|
| 226 |
+
# Standard residue: 1 token, multiple atoms
|
| 227 |
+
atom_names = PROTEIN_HEAVY_ATOMS[res_corrected]
|
| 228 |
+
res_type = _compute_res_type(res_corrected, MOL_TYPE_PROTEIN)
|
| 229 |
+
input_id = _compute_esm_input_id(res_corrected, MOL_TYPE_PROTEIN)
|
| 230 |
+
|
| 231 |
+
atom_start = atom_idx
|
| 232 |
+
for a_name in atom_names:
|
| 233 |
+
ref_pos = get_idealized_atom_pos(res_type, a_name)
|
| 234 |
+
atoms.append(
|
| 235 |
+
AtomInfo(
|
| 236 |
+
name=a_name,
|
| 237 |
+
element=_infer_element(a_name),
|
| 238 |
+
charge=CHARGED_ATOMS.get((res_corrected, a_name), 0),
|
| 239 |
+
ref_pos=ref_pos.copy()
|
| 240 |
+
if ref_pos is not None
|
| 241 |
+
else _ZERO_POS.copy(),
|
| 242 |
+
pos=_ZERO_POS.copy(),
|
| 243 |
+
token_index=token_idx,
|
| 244 |
+
atom_index=atom_idx,
|
| 245 |
+
space_uid=space_uid,
|
| 246 |
+
)
|
| 247 |
+
)
|
| 248 |
+
atom_idx += 1
|
| 249 |
+
|
| 250 |
+
tokens.append(
|
| 251 |
+
TokenInfo(
|
| 252 |
+
token_index=token_idx,
|
| 253 |
+
residue_index=res_idx,
|
| 254 |
+
residue_name=res_corrected,
|
| 255 |
+
mol_type=MOL_TYPE_PROTEIN,
|
| 256 |
+
res_type=res_type,
|
| 257 |
+
input_id=input_id,
|
| 258 |
+
asym_id=asym_id,
|
| 259 |
+
sym_id=sym_id,
|
| 260 |
+
entity_id=entity_id,
|
| 261 |
+
atom_start=atom_start,
|
| 262 |
+
atom_count=len(atom_names),
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
token_idx += 1
|
| 266 |
+
space_uid += 1
|
| 267 |
+
|
| 268 |
+
else:
|
| 269 |
+
# Modified or unknown residue: atom-tokenized
|
| 270 |
+
ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name)
|
| 271 |
+
if ccd_atoms is None:
|
| 272 |
+
# Fallback: backbone only
|
| 273 |
+
ccd_atoms = [
|
| 274 |
+
(_infer_element(n), _infer_element(n), 0)
|
| 275 |
+
for n in ["N", "CA", "C", "O"]
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
# Filter leaving atoms if not terminal
|
| 279 |
+
is_terminal = res_idx == len(seq_3letter) - 1
|
| 280 |
+
leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name)
|
| 281 |
+
kept_atoms = [a for a in ccd_atoms if a[0] not in leaving_atoms]
|
| 282 |
+
# Single-atom residues (e.g. NH2 cap): the local frame is
|
| 283 |
+
# ill-defined with one atom; place at origin.
|
| 284 |
+
single_atom_residue = len(kept_atoms) == 1
|
| 285 |
+
|
| 286 |
+
for a_name, a_element, a_charge in kept_atoms:
|
| 287 |
+
ref_pos = get_ligand_idealized_atom_pos(res_name, a_name)
|
| 288 |
+
atoms.append(
|
| 289 |
+
AtomInfo(
|
| 290 |
+
name=a_name,
|
| 291 |
+
element=a_element,
|
| 292 |
+
charge=a_charge,
|
| 293 |
+
ref_pos=_ZERO_POS.copy()
|
| 294 |
+
if single_atom_residue
|
| 295 |
+
else (
|
| 296 |
+
ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy()
|
| 297 |
+
),
|
| 298 |
+
pos=_ZERO_POS.copy(),
|
| 299 |
+
token_index=token_idx,
|
| 300 |
+
atom_index=atom_idx,
|
| 301 |
+
space_uid=space_uid,
|
| 302 |
+
)
|
| 303 |
+
)
|
| 304 |
+
tokens.append(
|
| 305 |
+
TokenInfo(
|
| 306 |
+
token_index=token_idx,
|
| 307 |
+
residue_index=res_idx,
|
| 308 |
+
residue_name=res_name,
|
| 309 |
+
mol_type=MOL_TYPE_PROTEIN,
|
| 310 |
+
res_type=PROTEIN_UNK_RES_TYPE,
|
| 311 |
+
input_id=DNA_RNA_LIGAND_INPUT_ID,
|
| 312 |
+
asym_id=asym_id,
|
| 313 |
+
sym_id=sym_id,
|
| 314 |
+
entity_id=entity_id,
|
| 315 |
+
atom_start=atom_idx,
|
| 316 |
+
atom_count=1,
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
token_idx += 1
|
| 320 |
+
atom_idx += 1
|
| 321 |
+
|
| 322 |
+
space_uid += 1
|
| 323 |
+
|
| 324 |
+
return tokens, atoms
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def tokenize_nucleotide(
|
| 328 |
+
sequence: str,
|
| 329 |
+
modifications: list[Modification] | None,
|
| 330 |
+
mol_type: int,
|
| 331 |
+
entity_id: int,
|
| 332 |
+
asym_id: int,
|
| 333 |
+
sym_id: int,
|
| 334 |
+
token_offset: int,
|
| 335 |
+
atom_offset: int,
|
| 336 |
+
space_uid_offset: int,
|
| 337 |
+
) -> tuple[list[TokenInfo], list[AtomInfo]]:
|
| 338 |
+
"""Tokenize a DNA or RNA sequence into tokens and atoms."""
|
| 339 |
+
tokens: list[TokenInfo] = []
|
| 340 |
+
atoms: list[AtomInfo] = []
|
| 341 |
+
|
| 342 |
+
letter_to_3 = DNA_1TO3 if mol_type == MOL_TYPE_DNA else RNA_1TO3
|
| 343 |
+
heavy_atoms = DNA_HEAVY_ATOMS if mol_type == MOL_TYPE_DNA else RNA_HEAVY_ATOMS
|
| 344 |
+
backbone_atoms = (
|
| 345 |
+
DNA_BACKBONE_ATOMS if mol_type == MOL_TYPE_DNA else RNA_BACKBONE_ATOMS
|
| 346 |
+
)
|
| 347 |
+
unk_res_type = DNA_UNK_RES_TYPE if mol_type == MOL_TYPE_DNA else RNA_UNK_RES_TYPE
|
| 348 |
+
|
| 349 |
+
seq_3letter = [letter_to_3.get(c, "UNK") for c in sequence]
|
| 350 |
+
modified_positions: set[int] = set()
|
| 351 |
+
if modifications:
|
| 352 |
+
for mod in modifications:
|
| 353 |
+
seq_3letter[mod.position] = mod.ccd
|
| 354 |
+
modified_positions.add(mod.position)
|
| 355 |
+
|
| 356 |
+
token_idx = token_offset
|
| 357 |
+
atom_idx = atom_offset
|
| 358 |
+
space_uid = space_uid_offset
|
| 359 |
+
|
| 360 |
+
for res_idx, res_name in enumerate(seq_3letter):
|
| 361 |
+
is_modified = res_idx in modified_positions
|
| 362 |
+
|
| 363 |
+
if not is_modified and res_name in heavy_atoms:
|
| 364 |
+
# Standard nucleotide
|
| 365 |
+
atom_names = heavy_atoms[res_name]
|
| 366 |
+
res_type = _compute_res_type(res_name, mol_type)
|
| 367 |
+
input_id = DNA_RNA_LIGAND_INPUT_ID
|
| 368 |
+
|
| 369 |
+
atom_start = atom_idx
|
| 370 |
+
for a_name in atom_names:
|
| 371 |
+
ref_pos = get_idealized_atom_pos(res_type, a_name)
|
| 372 |
+
atoms.append(
|
| 373 |
+
AtomInfo(
|
| 374 |
+
name=a_name,
|
| 375 |
+
element=_infer_element(a_name),
|
| 376 |
+
charge=CHARGED_ATOMS.get((res_name, a_name), 0),
|
| 377 |
+
ref_pos=ref_pos.copy()
|
| 378 |
+
if ref_pos is not None
|
| 379 |
+
else _ZERO_POS.copy(),
|
| 380 |
+
pos=_ZERO_POS.copy(),
|
| 381 |
+
token_index=token_idx,
|
| 382 |
+
atom_index=atom_idx,
|
| 383 |
+
space_uid=space_uid,
|
| 384 |
+
)
|
| 385 |
+
)
|
| 386 |
+
atom_idx += 1
|
| 387 |
+
|
| 388 |
+
tokens.append(
|
| 389 |
+
TokenInfo(
|
| 390 |
+
token_index=token_idx,
|
| 391 |
+
residue_index=res_idx,
|
| 392 |
+
residue_name=res_name,
|
| 393 |
+
mol_type=mol_type,
|
| 394 |
+
res_type=res_type,
|
| 395 |
+
input_id=input_id,
|
| 396 |
+
asym_id=asym_id,
|
| 397 |
+
sym_id=sym_id,
|
| 398 |
+
entity_id=entity_id,
|
| 399 |
+
atom_start=atom_start,
|
| 400 |
+
atom_count=len(atom_names),
|
| 401 |
+
)
|
| 402 |
+
)
|
| 403 |
+
token_idx += 1
|
| 404 |
+
space_uid += 1
|
| 405 |
+
|
| 406 |
+
elif not is_modified and res_name == "UNK":
|
| 407 |
+
# Unknown nucleotide: backbone only
|
| 408 |
+
atom_names = backbone_atoms
|
| 409 |
+
atom_start = atom_idx
|
| 410 |
+
for a_name in atom_names:
|
| 411 |
+
ref_pos = None # No idealized positions for UNK
|
| 412 |
+
atoms.append(
|
| 413 |
+
AtomInfo(
|
| 414 |
+
name=a_name,
|
| 415 |
+
element=_infer_element(a_name),
|
| 416 |
+
charge=0,
|
| 417 |
+
ref_pos=_ZERO_POS.copy(),
|
| 418 |
+
pos=_ZERO_POS.copy(),
|
| 419 |
+
token_index=token_idx,
|
| 420 |
+
atom_index=atom_idx,
|
| 421 |
+
space_uid=space_uid,
|
| 422 |
+
)
|
| 423 |
+
)
|
| 424 |
+
atom_idx += 1
|
| 425 |
+
|
| 426 |
+
tokens.append(
|
| 427 |
+
TokenInfo(
|
| 428 |
+
token_index=token_idx,
|
| 429 |
+
residue_index=res_idx,
|
| 430 |
+
residue_name=res_name,
|
| 431 |
+
mol_type=mol_type,
|
| 432 |
+
res_type=unk_res_type,
|
| 433 |
+
input_id=DNA_RNA_LIGAND_INPUT_ID,
|
| 434 |
+
asym_id=asym_id,
|
| 435 |
+
sym_id=sym_id,
|
| 436 |
+
entity_id=entity_id,
|
| 437 |
+
atom_start=atom_start,
|
| 438 |
+
atom_count=len(atom_names),
|
| 439 |
+
)
|
| 440 |
+
)
|
| 441 |
+
token_idx += 1
|
| 442 |
+
space_uid += 1
|
| 443 |
+
|
| 444 |
+
else:
|
| 445 |
+
# Modified nucleotide: atom-tokenized
|
| 446 |
+
ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name)
|
| 447 |
+
if ccd_atoms is None:
|
| 448 |
+
ccd_atoms = [
|
| 449 |
+
(_infer_element(n), _infer_element(n), 0) for n in backbone_atoms
|
| 450 |
+
]
|
| 451 |
+
|
| 452 |
+
is_terminal = res_idx == len(seq_3letter) - 1
|
| 453 |
+
leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name)
|
| 454 |
+
|
| 455 |
+
for a_name, a_element, a_charge in ccd_atoms:
|
| 456 |
+
if a_name in leaving_atoms:
|
| 457 |
+
continue
|
| 458 |
+
ref_pos = get_ligand_idealized_atom_pos(res_name, a_name)
|
| 459 |
+
atoms.append(
|
| 460 |
+
AtomInfo(
|
| 461 |
+
name=a_name,
|
| 462 |
+
element=a_element,
|
| 463 |
+
charge=a_charge,
|
| 464 |
+
ref_pos=ref_pos.copy()
|
| 465 |
+
if ref_pos is not None
|
| 466 |
+
else _ZERO_POS.copy(),
|
| 467 |
+
pos=_ZERO_POS.copy(),
|
| 468 |
+
token_index=token_idx,
|
| 469 |
+
atom_index=atom_idx,
|
| 470 |
+
space_uid=space_uid,
|
| 471 |
+
)
|
| 472 |
+
)
|
| 473 |
+
tokens.append(
|
| 474 |
+
TokenInfo(
|
| 475 |
+
token_index=token_idx,
|
| 476 |
+
residue_index=res_idx,
|
| 477 |
+
residue_name=res_name,
|
| 478 |
+
mol_type=mol_type,
|
| 479 |
+
res_type=PROTEIN_UNK_RES_TYPE,
|
| 480 |
+
input_id=DNA_RNA_LIGAND_INPUT_ID,
|
| 481 |
+
asym_id=asym_id,
|
| 482 |
+
sym_id=sym_id,
|
| 483 |
+
entity_id=entity_id,
|
| 484 |
+
atom_start=atom_idx,
|
| 485 |
+
atom_count=1,
|
| 486 |
+
)
|
| 487 |
+
)
|
| 488 |
+
token_idx += 1
|
| 489 |
+
atom_idx += 1
|
| 490 |
+
|
| 491 |
+
space_uid += 1
|
| 492 |
+
|
| 493 |
+
return tokens, atoms
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def tokenize_ligand_ccd(
|
| 497 |
+
ccd_codes: list[str],
|
| 498 |
+
entity_id: int,
|
| 499 |
+
asym_id: int,
|
| 500 |
+
sym_id: int,
|
| 501 |
+
token_offset: int,
|
| 502 |
+
atom_offset: int,
|
| 503 |
+
space_uid_offset: int,
|
| 504 |
+
has_covalent_bond: bool,
|
| 505 |
+
) -> tuple[list[TokenInfo], list[AtomInfo]]:
|
| 506 |
+
"""Tokenize a ligand from CCD codes (1 token per atom)."""
|
| 507 |
+
tokens: list[TokenInfo] = []
|
| 508 |
+
atoms: list[AtomInfo] = []
|
| 509 |
+
|
| 510 |
+
token_idx = token_offset
|
| 511 |
+
atom_idx = atom_offset
|
| 512 |
+
space_uid = space_uid_offset
|
| 513 |
+
|
| 514 |
+
for res_idx, code in enumerate(ccd_codes):
|
| 515 |
+
ccd_atoms = get_ligand_ccd_atoms_with_charges(code)
|
| 516 |
+
if ccd_atoms is None:
|
| 517 |
+
raise ValueError(f"CCD component {code} not found")
|
| 518 |
+
|
| 519 |
+
leaving_atoms = get_ccd_leaving_atoms(code) if has_covalent_bond else set()
|
| 520 |
+
|
| 521 |
+
for a_name, a_element, a_charge in ccd_atoms:
|
| 522 |
+
if a_name in leaving_atoms:
|
| 523 |
+
continue
|
| 524 |
+
ref_pos = get_ligand_idealized_atom_pos(code, a_name)
|
| 525 |
+
atoms.append(
|
| 526 |
+
AtomInfo(
|
| 527 |
+
name=a_name,
|
| 528 |
+
element=a_element,
|
| 529 |
+
charge=a_charge,
|
| 530 |
+
ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(),
|
| 531 |
+
pos=_ZERO_POS.copy(),
|
| 532 |
+
token_index=token_idx,
|
| 533 |
+
atom_index=atom_idx,
|
| 534 |
+
space_uid=space_uid,
|
| 535 |
+
)
|
| 536 |
+
)
|
| 537 |
+
tokens.append(
|
| 538 |
+
TokenInfo(
|
| 539 |
+
token_index=token_idx,
|
| 540 |
+
residue_index=res_idx,
|
| 541 |
+
residue_name=code,
|
| 542 |
+
mol_type=MOL_TYPE_NONPOLYMER,
|
| 543 |
+
res_type=PROTEIN_UNK_RES_TYPE,
|
| 544 |
+
input_id=DNA_RNA_LIGAND_INPUT_ID,
|
| 545 |
+
asym_id=asym_id,
|
| 546 |
+
sym_id=sym_id,
|
| 547 |
+
entity_id=entity_id,
|
| 548 |
+
atom_start=atom_idx,
|
| 549 |
+
atom_count=1,
|
| 550 |
+
)
|
| 551 |
+
)
|
| 552 |
+
token_idx += 1
|
| 553 |
+
atom_idx += 1
|
| 554 |
+
|
| 555 |
+
space_uid += 1
|
| 556 |
+
|
| 557 |
+
return tokens, atoms
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def tokenize_ligand_smiles(
|
| 561 |
+
smiles: str,
|
| 562 |
+
entity_id: int,
|
| 563 |
+
asym_id: int,
|
| 564 |
+
sym_id: int,
|
| 565 |
+
token_offset: int,
|
| 566 |
+
atom_offset: int,
|
| 567 |
+
space_uid_offset: int,
|
| 568 |
+
seed: int | None = None,
|
| 569 |
+
) -> tuple[list[TokenInfo], list[AtomInfo]]:
|
| 570 |
+
"""Tokenize a ligand from SMILES (1 token per heavy atom)."""
|
| 571 |
+
from rdkit import Chem
|
| 572 |
+
from rdkit.Chem import AllChem
|
| 573 |
+
|
| 574 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 575 |
+
if mol is None:
|
| 576 |
+
raise ValueError(f"Failed to parse SMILES: {smiles}")
|
| 577 |
+
mol = Chem.AddHs(mol)
|
| 578 |
+
|
| 579 |
+
# Assign atom names using canonical ranking
|
| 580 |
+
canonical_order = AllChem.CanonicalRankAtoms(mol) # type: ignore[attr-defined]
|
| 581 |
+
for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
|
| 582 |
+
atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
|
| 583 |
+
if len(atom_name) > 4:
|
| 584 |
+
raise ValueError(
|
| 585 |
+
f"SMILES {smiles} has atom name longer than 4 chars: {atom_name}"
|
| 586 |
+
)
|
| 587 |
+
atom.SetProp("name", atom_name)
|
| 588 |
+
|
| 589 |
+
# Generate 3D conformer
|
| 590 |
+
options = AllChem.ETKDGv3() # type: ignore[attr-defined]
|
| 591 |
+
options.clearConfs = False
|
| 592 |
+
if seed is not None:
|
| 593 |
+
options.randomSeed = seed
|
| 594 |
+
conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined]
|
| 595 |
+
if conf_id == -1:
|
| 596 |
+
options.useRandomCoords = True
|
| 597 |
+
conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined]
|
| 598 |
+
if conf_id != -1:
|
| 599 |
+
try:
|
| 600 |
+
AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000) # type: ignore[attr-defined]
|
| 601 |
+
except (RuntimeError, ValueError):
|
| 602 |
+
pass
|
| 603 |
+
|
| 604 |
+
# Remove hydrogens
|
| 605 |
+
mol_no_h = Chem.RemoveHs(mol)
|
| 606 |
+
if mol_no_h.GetNumConformers() == 0:
|
| 607 |
+
raise ValueError(f"Failed to generate conformer for SMILES: {smiles}")
|
| 608 |
+
|
| 609 |
+
conformer = mol_no_h.GetConformer(0)
|
| 610 |
+
|
| 611 |
+
tokens: list[TokenInfo] = []
|
| 612 |
+
atoms_list: list[AtomInfo] = []
|
| 613 |
+
token_idx = token_offset
|
| 614 |
+
atom_idx = atom_offset
|
| 615 |
+
space_uid = space_uid_offset
|
| 616 |
+
|
| 617 |
+
for atom in mol_no_h.GetAtoms():
|
| 618 |
+
a_name = atom.GetProp("name")
|
| 619 |
+
a_element = atom.GetSymbol()
|
| 620 |
+
a_charge = atom.GetFormalCharge()
|
| 621 |
+
pos_3d = conformer.GetAtomPosition(atom.GetIdx())
|
| 622 |
+
ref_pos = np.array([pos_3d.x, pos_3d.y, pos_3d.z], dtype=np.float32)
|
| 623 |
+
|
| 624 |
+
atoms_list.append(
|
| 625 |
+
AtomInfo(
|
| 626 |
+
name=a_name,
|
| 627 |
+
element=a_element,
|
| 628 |
+
charge=a_charge,
|
| 629 |
+
ref_pos=ref_pos,
|
| 630 |
+
pos=_ZERO_POS.copy(),
|
| 631 |
+
token_index=token_idx,
|
| 632 |
+
atom_index=atom_idx,
|
| 633 |
+
space_uid=space_uid,
|
| 634 |
+
)
|
| 635 |
+
)
|
| 636 |
+
tokens.append(
|
| 637 |
+
TokenInfo(
|
| 638 |
+
token_index=token_idx,
|
| 639 |
+
residue_index=0,
|
| 640 |
+
residue_name="LIG",
|
| 641 |
+
mol_type=MOL_TYPE_NONPOLYMER,
|
| 642 |
+
res_type=PROTEIN_UNK_RES_TYPE,
|
| 643 |
+
input_id=DNA_RNA_LIGAND_INPUT_ID,
|
| 644 |
+
asym_id=asym_id,
|
| 645 |
+
sym_id=sym_id,
|
| 646 |
+
entity_id=entity_id,
|
| 647 |
+
atom_start=atom_idx,
|
| 648 |
+
atom_count=1,
|
| 649 |
+
)
|
| 650 |
+
)
|
| 651 |
+
token_idx += 1
|
| 652 |
+
atom_idx += 1
|
| 653 |
+
|
| 654 |
+
return tokens, atoms_list
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
# =============================================================================
|
| 658 |
+
# Build chains from StructurePredictionInput
|
| 659 |
+
# =============================================================================
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def _get_sequence_key(item) -> str:
|
| 663 |
+
"""Get a hashable key for entity deduplication."""
|
| 664 |
+
if isinstance(item, ProteinInput):
|
| 665 |
+
return f"PROTEIN:{item.sequence}"
|
| 666 |
+
elif isinstance(item, DNAInput):
|
| 667 |
+
return f"DNA:{item.sequence}"
|
| 668 |
+
elif isinstance(item, RNAInput):
|
| 669 |
+
return f"RNA:{item.sequence}"
|
| 670 |
+
elif isinstance(item, LigandInput):
|
| 671 |
+
if item.ccd:
|
| 672 |
+
return f"LIGAND_CCD:{','.join(item.ccd)}"
|
| 673 |
+
return f"LIGAND_SMILES:{item.smiles}"
|
| 674 |
+
raise ValueError(f"Unknown input type: {type(item)}")
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def build_chains_from_input(
|
| 678 |
+
input: StructurePredictionInput, seed: int | None = None
|
| 679 |
+
) -> tuple[list[ChainInfo], list[TokenInfo], list[AtomInfo]]:
|
| 680 |
+
"""Build chains, tokens, and atoms from StructurePredictionInput.
|
| 681 |
+
|
| 682 |
+
Handles entity deduplication (identical sequences get same entity_id),
|
| 683 |
+
sym_id assignment, and delegates to type-specific tokenization functions.
|
| 684 |
+
"""
|
| 685 |
+
chains: list[ChainInfo] = []
|
| 686 |
+
all_tokens: list[TokenInfo] = []
|
| 687 |
+
all_atoms: list[AtomInfo] = []
|
| 688 |
+
|
| 689 |
+
# Entity deduplication
|
| 690 |
+
sequence_to_entity: dict[str, int] = {}
|
| 691 |
+
entity_sym_count: dict[int, int] = {}
|
| 692 |
+
next_entity_id = 0
|
| 693 |
+
|
| 694 |
+
# Gather chain IDs involved in covalent bonds
|
| 695 |
+
covalent_chain_ids: set[str] = set()
|
| 696 |
+
if input.covalent_bonds:
|
| 697 |
+
for cb in input.covalent_bonds:
|
| 698 |
+
covalent_chain_ids.update([cb.chain_id1, cb.chain_id2])
|
| 699 |
+
|
| 700 |
+
token_offset = 0
|
| 701 |
+
atom_offset = 0
|
| 702 |
+
space_uid_offset = 0
|
| 703 |
+
asym_id = 0
|
| 704 |
+
|
| 705 |
+
for item in input.sequences:
|
| 706 |
+
# Entity deduplication
|
| 707 |
+
seq_key = _get_sequence_key(item)
|
| 708 |
+
if seq_key in sequence_to_entity:
|
| 709 |
+
entity_id = sequence_to_entity[seq_key]
|
| 710 |
+
else:
|
| 711 |
+
entity_id = next_entity_id
|
| 712 |
+
sequence_to_entity[seq_key] = entity_id
|
| 713 |
+
next_entity_id += 1
|
| 714 |
+
|
| 715 |
+
# Get all chain IDs for this item
|
| 716 |
+
ids = [item.id] if isinstance(item.id, str) else item.id
|
| 717 |
+
|
| 718 |
+
for chain_id_str in ids:
|
| 719 |
+
# sym_id is the per-entity copy index; increment per chain so
|
| 720 |
+
# ProteinInput(id=['A','B']) gives chain A sym_id=0, chain B sym_id=1.
|
| 721 |
+
sym_id = entity_sym_count.get(entity_id, 0)
|
| 722 |
+
entity_sym_count[entity_id] = sym_id + 1
|
| 723 |
+
if isinstance(item, ProteinInput):
|
| 724 |
+
if item.msa is None:
|
| 725 |
+
warnings.warn(
|
| 726 |
+
f"No MSA provided for {item.id}, using single sequence mode"
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
new_tokens, new_atoms = tokenize_protein(
|
| 730 |
+
sequence=item.sequence,
|
| 731 |
+
modifications=item.modifications,
|
| 732 |
+
entity_id=entity_id,
|
| 733 |
+
asym_id=asym_id,
|
| 734 |
+
sym_id=sym_id,
|
| 735 |
+
token_offset=token_offset,
|
| 736 |
+
atom_offset=atom_offset,
|
| 737 |
+
space_uid_offset=space_uid_offset,
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
elif isinstance(item, (DNAInput, RNAInput)):
|
| 741 |
+
mol_type = MOL_TYPE_DNA if isinstance(item, DNAInput) else MOL_TYPE_RNA
|
| 742 |
+
new_tokens, new_atoms = tokenize_nucleotide(
|
| 743 |
+
sequence=item.sequence,
|
| 744 |
+
modifications=item.modifications,
|
| 745 |
+
mol_type=mol_type,
|
| 746 |
+
entity_id=entity_id,
|
| 747 |
+
asym_id=asym_id,
|
| 748 |
+
sym_id=sym_id,
|
| 749 |
+
token_offset=token_offset,
|
| 750 |
+
atom_offset=atom_offset,
|
| 751 |
+
space_uid_offset=space_uid_offset,
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
elif isinstance(item, LigandInput):
|
| 755 |
+
has_cov = chain_id_str in covalent_chain_ids
|
| 756 |
+
if item.ccd is not None:
|
| 757 |
+
if item.smiles is not None:
|
| 758 |
+
warnings.warn("Both ccd and smiles provided, using ccd")
|
| 759 |
+
new_tokens, new_atoms = tokenize_ligand_ccd(
|
| 760 |
+
ccd_codes=item.ccd,
|
| 761 |
+
entity_id=entity_id,
|
| 762 |
+
asym_id=asym_id,
|
| 763 |
+
sym_id=sym_id,
|
| 764 |
+
token_offset=token_offset,
|
| 765 |
+
atom_offset=atom_offset,
|
| 766 |
+
space_uid_offset=space_uid_offset,
|
| 767 |
+
has_covalent_bond=has_cov,
|
| 768 |
+
)
|
| 769 |
+
elif item.smiles is not None:
|
| 770 |
+
new_tokens, new_atoms = tokenize_ligand_smiles(
|
| 771 |
+
smiles=item.smiles,
|
| 772 |
+
entity_id=entity_id,
|
| 773 |
+
asym_id=asym_id,
|
| 774 |
+
sym_id=sym_id,
|
| 775 |
+
token_offset=token_offset,
|
| 776 |
+
atom_offset=atom_offset,
|
| 777 |
+
space_uid_offset=space_uid_offset,
|
| 778 |
+
seed=seed,
|
| 779 |
+
)
|
| 780 |
+
else:
|
| 781 |
+
raise ValueError("LigandInput must have either ccd or smiles")
|
| 782 |
+
else:
|
| 783 |
+
raise ValueError(f"Unknown input type: {type(item)}")
|
| 784 |
+
|
| 785 |
+
chain = ChainInfo(
|
| 786 |
+
chain_id=chain_id_str,
|
| 787 |
+
asym_id=asym_id,
|
| 788 |
+
entity_id=entity_id,
|
| 789 |
+
sym_id=sym_id,
|
| 790 |
+
mol_type=new_tokens[0].mol_type if new_tokens else MOL_TYPE_PROTEIN,
|
| 791 |
+
tokens=new_tokens,
|
| 792 |
+
)
|
| 793 |
+
chains.append(chain)
|
| 794 |
+
all_tokens.extend(new_tokens)
|
| 795 |
+
all_atoms.extend(new_atoms)
|
| 796 |
+
|
| 797 |
+
token_offset += len(new_tokens)
|
| 798 |
+
atom_offset += len(new_atoms)
|
| 799 |
+
space_uid_offset += len(set(a.space_uid for a in new_atoms))
|
| 800 |
+
asym_id += 1
|
| 801 |
+
|
| 802 |
+
return chains, all_tokens, all_atoms
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
# =============================================================================
|
| 806 |
+
# Feature tensor building
|
| 807 |
+
# =============================================================================
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def compute_frame_indices(
|
| 811 |
+
tokens: list[TokenInfo], atoms: list[AtomInfo]
|
| 812 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 813 |
+
"""Compute backbone frame indices for each token.
|
| 814 |
+
|
| 815 |
+
Protein: [N, CA, C]; DNA/RNA: [C1', C3', C4']; Ligand: distance-based.
|
| 816 |
+
"""
|
| 817 |
+
# Build atom name -> atom_index lookup per token
|
| 818 |
+
token_atoms: dict[int, dict[str, int]] = defaultdict(dict)
|
| 819 |
+
for atom in atoms:
|
| 820 |
+
if atom.is_valid:
|
| 821 |
+
token_atoms[atom.token_index][atom.name] = atom.atom_index
|
| 822 |
+
|
| 823 |
+
# Ligand-token frames come from CCD reference-conformer geometry,
|
| 824 |
+
# grouped per residue. For each token, the frame is the 3 atoms nearest
|
| 825 |
+
# to its own atom in the residue's ref-pos space, ordered
|
| 826 |
+
# (1st-nearest, self, 2nd-nearest).
|
| 827 |
+
ligand_token_to_atom: dict[int, int] = {}
|
| 828 |
+
ligand_tokens_by_res: dict[tuple[int, int], list[int]] = defaultdict(list)
|
| 829 |
+
for t in tokens:
|
| 830 |
+
if t.mol_type == MOL_TYPE_NONPOLYMER:
|
| 831 |
+
ad = token_atoms.get(t.token_index)
|
| 832 |
+
if ad:
|
| 833 |
+
ligand_token_to_atom[t.token_index] = next(iter(ad.values()))
|
| 834 |
+
ligand_tokens_by_res[(t.asym_id, t.residue_index)].append(t.token_index)
|
| 835 |
+
|
| 836 |
+
ligand_token_frames: dict[int, tuple[int, int, int]] = {}
|
| 837 |
+
for tok_indices in ligand_tokens_by_res.values():
|
| 838 |
+
atom_indices = [
|
| 839 |
+
ligand_token_to_atom[ti] for ti in tok_indices if ti in ligand_token_to_atom
|
| 840 |
+
]
|
| 841 |
+
if len(atom_indices) < 3:
|
| 842 |
+
for ti in tok_indices:
|
| 843 |
+
if ti in ligand_token_to_atom:
|
| 844 |
+
ai = ligand_token_to_atom[ti]
|
| 845 |
+
ligand_token_frames[ti] = (ai, ai, ai)
|
| 846 |
+
continue
|
| 847 |
+
|
| 848 |
+
ref_pos_chain = np.array([atoms[ai].ref_pos for ai in atom_indices])
|
| 849 |
+
dist_mat = np.sqrt(
|
| 850 |
+
((ref_pos_chain[:, None] - ref_pos_chain[None]) ** 2).sum(-1)
|
| 851 |
+
)
|
| 852 |
+
sort_indices = np.argsort(dist_mat, axis=1)
|
| 853 |
+
local_frames = np.column_stack(
|
| 854 |
+
[sort_indices[:, 1], sort_indices[:, 0], sort_indices[:, 2]]
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
for ti in tok_indices:
|
| 858 |
+
if ti not in ligand_token_to_atom:
|
| 859 |
+
continue
|
| 860 |
+
ai = ligand_token_to_atom[ti]
|
| 861 |
+
local_idx = atom_indices.index(ai)
|
| 862 |
+
fl = local_frames[local_idx]
|
| 863 |
+
ligand_token_frames[ti] = (
|
| 864 |
+
atom_indices[fl[0]],
|
| 865 |
+
atom_indices[fl[1]],
|
| 866 |
+
atom_indices[fl[2]],
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# Build frames for all tokens
|
| 870 |
+
frames_list: list[tuple[int, int, int]] = []
|
| 871 |
+
for t in tokens:
|
| 872 |
+
ad = token_atoms.get(t.token_index, {})
|
| 873 |
+
fallback = list(ad.values())[0] if ad else 0
|
| 874 |
+
|
| 875 |
+
if t.mol_type == MOL_TYPE_PROTEIN:
|
| 876 |
+
if t.res_type == PROTEIN_UNK_RES_TYPE:
|
| 877 |
+
frames_list.append((fallback, fallback, fallback))
|
| 878 |
+
else:
|
| 879 |
+
frames_list.append((ad.get("N", 0), ad.get("CA", 0), ad.get("C", 0)))
|
| 880 |
+
elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA):
|
| 881 |
+
if t.res_type == PROTEIN_UNK_RES_TYPE:
|
| 882 |
+
frames_list.append((fallback, fallback, fallback))
|
| 883 |
+
else:
|
| 884 |
+
frames_list.append(
|
| 885 |
+
(ad.get("C1'", 0), ad.get("C3'", 0), ad.get("C4'", 0))
|
| 886 |
+
)
|
| 887 |
+
elif t.mol_type == MOL_TYPE_NONPOLYMER:
|
| 888 |
+
if t.token_index in ligand_token_frames:
|
| 889 |
+
frames_list.append(ligand_token_frames[t.token_index])
|
| 890 |
+
else:
|
| 891 |
+
frames_list.append((fallback, fallback, fallback))
|
| 892 |
+
else:
|
| 893 |
+
frames_list.append((fallback, fallback, fallback))
|
| 894 |
+
|
| 895 |
+
frames = np.array(frames_list, dtype=np.int64)
|
| 896 |
+
|
| 897 |
+
# Compute resolved mask (vectorized)
|
| 898 |
+
n_atoms = len(atoms)
|
| 899 |
+
atom_positions = (
|
| 900 |
+
np.array([a.pos for a in atoms], dtype=np.float32)
|
| 901 |
+
if atoms
|
| 902 |
+
else np.zeros((0, 3), dtype=np.float32)
|
| 903 |
+
)
|
| 904 |
+
atom_is_valid = (
|
| 905 |
+
np.array([a.is_valid for a in atoms], dtype=bool)
|
| 906 |
+
if atoms
|
| 907 |
+
else np.zeros(0, dtype=bool)
|
| 908 |
+
)
|
| 909 |
+
atom_is_resolved = (
|
| 910 |
+
atom_is_valid & np.any(atom_positions != 0, axis=1)
|
| 911 |
+
if n_atoms > 0
|
| 912 |
+
else np.zeros(0, dtype=bool)
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
n_tokens = len(tokens)
|
| 916 |
+
if n_tokens == 0:
|
| 917 |
+
return frames, np.zeros(0, dtype=bool)
|
| 918 |
+
|
| 919 |
+
pos1 = atom_positions[frames[:, 0]]
|
| 920 |
+
pos2 = atom_positions[frames[:, 1]]
|
| 921 |
+
pos3 = atom_positions[frames[:, 2]]
|
| 922 |
+
|
| 923 |
+
all_resolved = (
|
| 924 |
+
atom_is_resolved[frames[:, 0]]
|
| 925 |
+
& atom_is_resolved[frames[:, 1]]
|
| 926 |
+
& atom_is_resolved[frames[:, 2]]
|
| 927 |
+
)
|
| 928 |
+
all_same = (frames[:, 0] == frames[:, 1]) & (frames[:, 1] == frames[:, 2])
|
| 929 |
+
|
| 930 |
+
v1 = pos1 - pos2
|
| 931 |
+
v2 = pos3 - pos2
|
| 932 |
+
norm1 = np.linalg.norm(v1, axis=1)
|
| 933 |
+
norm2 = np.linalg.norm(v2, axis=1)
|
| 934 |
+
valid_norms = (norm1 >= 1e-6) & (norm2 >= 1e-6)
|
| 935 |
+
|
| 936 |
+
cos_angle = np.zeros(n_tokens, dtype=np.float32)
|
| 937 |
+
mask = valid_norms
|
| 938 |
+
if np.any(mask):
|
| 939 |
+
cos_angle[mask] = np.sum(v1[mask] * v2[mask], axis=1) / (
|
| 940 |
+
norm1[mask] * norm2[mask]
|
| 941 |
+
)
|
| 942 |
+
cos_angle = np.clip(cos_angle, -1, 1)
|
| 943 |
+
angle_deg = np.degrees(np.arccos(np.abs(cos_angle)))
|
| 944 |
+
not_colinear = angle_deg >= 25
|
| 945 |
+
|
| 946 |
+
resolved_mask = all_resolved & ~all_same & valid_norms & not_colinear
|
| 947 |
+
return frames, resolved_mask
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def compute_token_bonds(
|
| 951 |
+
tokens: list[TokenInfo],
|
| 952 |
+
atoms: list[AtomInfo],
|
| 953 |
+
input: StructurePredictionInput,
|
| 954 |
+
chains: list[ChainInfo],
|
| 955 |
+
) -> torch.Tensor:
|
| 956 |
+
"""Compute dense token bond matrix [L, L, 1].
|
| 957 |
+
|
| 958 |
+
Includes ligand intra-residue bonds (from CCD) and covalent bonds.
|
| 959 |
+
"""
|
| 960 |
+
n_tokens = len(tokens)
|
| 961 |
+
edge_set: set[tuple[int, int]] = set()
|
| 962 |
+
|
| 963 |
+
def add_bond(i: int, j: int) -> None:
|
| 964 |
+
if i != j:
|
| 965 |
+
edge_set.add((min(i, j), max(i, j)))
|
| 966 |
+
|
| 967 |
+
# Build per-residue atom name -> token_index mapping for ligands and modified residues
|
| 968 |
+
# Key: (asym_id, residue_index, atom_name) -> token_index
|
| 969 |
+
atom_name_to_token: dict[tuple[int, int, str], int] = {}
|
| 970 |
+
for atom in atoms:
|
| 971 |
+
if atom.is_valid:
|
| 972 |
+
t = tokens[atom.token_index] if atom.token_index < len(tokens) else None
|
| 973 |
+
if t and (
|
| 974 |
+
t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE
|
| 975 |
+
):
|
| 976 |
+
atom_name_to_token[(t.asym_id, t.residue_index, atom.name)] = (
|
| 977 |
+
atom.token_index
|
| 978 |
+
)
|
| 979 |
+
|
| 980 |
+
# Group atom-tokenized tokens by (asym_id, residue_index)
|
| 981 |
+
residue_tokens: dict[tuple[int, int], list[tuple[str, int]]] = defaultdict(list)
|
| 982 |
+
for atom in atoms:
|
| 983 |
+
if not atom.is_valid:
|
| 984 |
+
continue
|
| 985 |
+
t = tokens[atom.token_index] if atom.token_index < len(tokens) else None
|
| 986 |
+
if t and (
|
| 987 |
+
t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE
|
| 988 |
+
):
|
| 989 |
+
residue_tokens[(t.asym_id, t.residue_index)].append(
|
| 990 |
+
(atom.name, atom.token_index)
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
# Add intra-residue bonds from CCD
|
| 994 |
+
for (asym_id_val, res_idx), atom_list in residue_tokens.items():
|
| 995 |
+
if not atom_list:
|
| 996 |
+
continue
|
| 997 |
+
res_name = tokens[atom_list[0][1]].residue_name
|
| 998 |
+
ccd_bonds = get_ligand_ccd_bonds(res_name)
|
| 999 |
+
atom_to_tok = {name: ti for name, ti in atom_list}
|
| 1000 |
+
|
| 1001 |
+
if ccd_bonds:
|
| 1002 |
+
for a1, a2 in ccd_bonds:
|
| 1003 |
+
if a1 in atom_to_tok and a2 in atom_to_tok:
|
| 1004 |
+
add_bond(atom_to_tok[a1], atom_to_tok[a2])
|
| 1005 |
+
else:
|
| 1006 |
+
# Fallback: fully connected within residue
|
| 1007 |
+
tok_indices = [ti for _, ti in atom_list]
|
| 1008 |
+
for i_idx in tok_indices:
|
| 1009 |
+
for j_idx in tok_indices:
|
| 1010 |
+
add_bond(i_idx, j_idx)
|
| 1011 |
+
|
| 1012 |
+
# Add covalent bonds from input
|
| 1013 |
+
if input.covalent_bonds:
|
| 1014 |
+
# Build chain_id -> chain mapping
|
| 1015 |
+
chain_by_id: dict[str, ChainInfo] = {c.chain_id: c for c in chains}
|
| 1016 |
+
# Build (asym_id, residue_index) -> list of tokens for atom index lookup
|
| 1017 |
+
chain_res_atoms: dict[tuple[int, int], list[AtomInfo]] = defaultdict(list)
|
| 1018 |
+
for atom in atoms:
|
| 1019 |
+
if atom.is_valid and atom.token_index < len(tokens):
|
| 1020 |
+
t = tokens[atom.token_index]
|
| 1021 |
+
chain_res_atoms[(t.asym_id, t.residue_index)].append(atom)
|
| 1022 |
+
|
| 1023 |
+
for cb in input.covalent_bonds:
|
| 1024 |
+
c1 = chain_by_id.get(cb.chain_id1)
|
| 1025 |
+
c2 = chain_by_id.get(cb.chain_id2)
|
| 1026 |
+
if c1 is None or c2 is None:
|
| 1027 |
+
continue
|
| 1028 |
+
|
| 1029 |
+
atoms_1 = chain_res_atoms.get((c1.asym_id, cb.res_idx1), [])
|
| 1030 |
+
atoms_2 = chain_res_atoms.get((c2.asym_id, cb.res_idx2), [])
|
| 1031 |
+
|
| 1032 |
+
if cb.atom_idx1 < len(atoms_1) and cb.atom_idx2 < len(atoms_2):
|
| 1033 |
+
add_bond(
|
| 1034 |
+
atoms_1[cb.atom_idx1].token_index, atoms_2[cb.atom_idx2].token_index
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
# Add peptide bonds at modified-residue boundaries: an atom-tokenized
|
| 1038 |
+
# residue's N atom connects to the prev residue's C atom (and same for
|
| 1039 |
+
# the C side to the next residue's N).
|
| 1040 |
+
tokens_by_chain_res: dict[tuple[int, int], list[TokenInfo]] = defaultdict(list)
|
| 1041 |
+
for t in tokens:
|
| 1042 |
+
if t.mol_type == MOL_TYPE_PROTEIN:
|
| 1043 |
+
tokens_by_chain_res[(t.asym_id, t.residue_index)].append(t)
|
| 1044 |
+
|
| 1045 |
+
def _backbone_token(res_tokens: list[TokenInfo], atom_name: str) -> int | None:
|
| 1046 |
+
# Standard residue (single token wrapping all atoms): return that token.
|
| 1047 |
+
if len(res_tokens) == 1 and res_tokens[0].res_type != PROTEIN_UNK_RES_TYPE:
|
| 1048 |
+
return res_tokens[0].token_index
|
| 1049 |
+
for t in res_tokens:
|
| 1050 |
+
for a_idx in range(t.atom_start, t.atom_start + t.atom_count):
|
| 1051 |
+
if a_idx < len(atoms) and atoms[a_idx].name == atom_name:
|
| 1052 |
+
return t.token_index
|
| 1053 |
+
# Atom-tokenized residue without an atom of that name (e.g. ACE has
|
| 1054 |
+
# no N, NH2 has no C). Fall back to the first atom-tokenized token.
|
| 1055 |
+
return res_tokens[0].token_index if res_tokens else None
|
| 1056 |
+
|
| 1057 |
+
for (asym_id_val, res_idx), res_tokens in tokens_by_chain_res.items():
|
| 1058 |
+
is_atom_tokenized = any(t.res_type == PROTEIN_UNK_RES_TYPE for t in res_tokens)
|
| 1059 |
+
if not is_atom_tokenized:
|
| 1060 |
+
continue # Standard residue — no peptide bond added here.
|
| 1061 |
+
n_tok = _backbone_token(res_tokens, "N")
|
| 1062 |
+
c_tok = _backbone_token(res_tokens, "C")
|
| 1063 |
+
prev_tokens = tokens_by_chain_res.get((asym_id_val, res_idx - 1))
|
| 1064 |
+
if prev_tokens and n_tok is not None:
|
| 1065 |
+
prev_c = _backbone_token(prev_tokens, "C")
|
| 1066 |
+
if prev_c is not None:
|
| 1067 |
+
add_bond(prev_c, n_tok)
|
| 1068 |
+
next_tokens = tokens_by_chain_res.get((asym_id_val, res_idx + 1))
|
| 1069 |
+
if next_tokens and c_tok is not None:
|
| 1070 |
+
next_n = _backbone_token(next_tokens, "N")
|
| 1071 |
+
if next_n is not None:
|
| 1072 |
+
add_bond(c_tok, next_n)
|
| 1073 |
+
|
| 1074 |
+
# Expand to dense matrix
|
| 1075 |
+
bonds = torch.zeros(n_tokens, n_tokens, 1, dtype=torch.float32)
|
| 1076 |
+
for i, j in edge_set:
|
| 1077 |
+
bonds[i, j, 0] = 1.0
|
| 1078 |
+
bonds[j, i, 0] = 1.0
|
| 1079 |
+
return bonds
|
| 1080 |
+
|
| 1081 |
+
|
| 1082 |
+
def compute_representative_atoms(
|
| 1083 |
+
tokens: list[TokenInfo], atoms: list[AtomInfo]
|
| 1084 |
+
) -> torch.Tensor:
|
| 1085 |
+
"""Compute representative atom index per token (for token_to_rep_atom).
|
| 1086 |
+
|
| 1087 |
+
Returns:
|
| 1088 |
+
distogram_atom_idx: [L] — representative atom per token
|
| 1089 |
+
Protein: CB (or CA for GLY), DNA/RNA: C4/C2/C1', Ligand: first atom.
|
| 1090 |
+
"""
|
| 1091 |
+
n_tokens = len(tokens)
|
| 1092 |
+
|
| 1093 |
+
# Build atom name -> index lookup per token
|
| 1094 |
+
token_atoms: dict[int, dict[str, int]] = defaultdict(dict)
|
| 1095 |
+
for atom in atoms:
|
| 1096 |
+
if atom.is_valid:
|
| 1097 |
+
token_atoms[atom.token_index][atom.name] = atom.atom_index
|
| 1098 |
+
|
| 1099 |
+
distogram_atom_idx = torch.zeros(n_tokens, dtype=torch.int64)
|
| 1100 |
+
|
| 1101 |
+
for t in tokens:
|
| 1102 |
+
ad = token_atoms.get(t.token_index, {})
|
| 1103 |
+
fallback_idx = list(ad.values())[0] if ad else 0
|
| 1104 |
+
|
| 1105 |
+
if t.mol_type == MOL_TYPE_PROTEIN:
|
| 1106 |
+
rep_idx = ad.get("CB", ad.get("CA", fallback_idx))
|
| 1107 |
+
elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA):
|
| 1108 |
+
if t.res_type in (27, 32): # Unknown nucleotides
|
| 1109 |
+
rep_idx = ad.get("C1'", fallback_idx)
|
| 1110 |
+
elif t.res_type in (23, 24, 28, 29): # Purines (A, G)
|
| 1111 |
+
rep_idx = ad.get("C4", ad.get("C1'", fallback_idx))
|
| 1112 |
+
else: # Pyrimidines (C, U, T)
|
| 1113 |
+
rep_idx = ad.get("C2", ad.get("C1'", fallback_idx))
|
| 1114 |
+
else:
|
| 1115 |
+
rep_idx = fallback_idx
|
| 1116 |
+
|
| 1117 |
+
distogram_atom_idx[t.token_index] = rep_idx
|
| 1118 |
+
|
| 1119 |
+
return distogram_atom_idx
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
def compute_msa_features(
|
| 1123 |
+
input: StructurePredictionInput,
|
| 1124 |
+
chains: list[ChainInfo],
|
| 1125 |
+
tokens: list[TokenInfo],
|
| 1126 |
+
max_seqs: int = 16384,
|
| 1127 |
+
) -> dict[str, torch.Tensor]:
|
| 1128 |
+
"""Compute MSA features from protein MSAs.
|
| 1129 |
+
|
| 1130 |
+
Uses taxonomy-based pairing across chains
|
| 1131 |
+
(:func:`paired_msa.construct_paired_msa`): rows whose FASTA header
|
| 1132 |
+
contains ``key=N`` get paired across chains sharing the same ``N``.
|
| 1133 |
+
|
| 1134 |
+
Output: msa [M, L], deletion_value [M, L], has_deletion [M, L],
|
| 1135 |
+
deletion_mean [L], msa_mask [M, L]
|
| 1136 |
+
"""
|
| 1137 |
+
from .esmfold2_paired_msa import (
|
| 1138 |
+
construct_paired_msa,
|
| 1139 |
+
protein_letter_to_res_type,
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
n_tokens = len(tokens)
|
| 1143 |
+
|
| 1144 |
+
# A single ProteinInput with id=['A','B','C',...] yields one item but
|
| 1145 |
+
# multiple chains (one per id); broadcast the MSA across all of them.
|
| 1146 |
+
chain_msas: dict[int, MSA | None] = {}
|
| 1147 |
+
item_idx = 0
|
| 1148 |
+
for item in input.sequences:
|
| 1149 |
+
ids = [item.id] if isinstance(item.id, str) else list(item.id)
|
| 1150 |
+
for _ in ids:
|
| 1151 |
+
chain = chains[item_idx]
|
| 1152 |
+
if isinstance(item, ProteinInput):
|
| 1153 |
+
msa = item.msa
|
| 1154 |
+
if msa is None:
|
| 1155 |
+
msa = MSA.from_sequences([item.sequence])
|
| 1156 |
+
chain_msas[chain.asym_id] = msa
|
| 1157 |
+
else:
|
| 1158 |
+
chain_msas[chain.asym_id] = None
|
| 1159 |
+
item_idx += 1
|
| 1160 |
+
|
| 1161 |
+
letter_to_res_type = protein_letter_to_res_type()
|
| 1162 |
+
|
| 1163 |
+
# Build per-chain query res_types (used for chains without an MSA).
|
| 1164 |
+
chain_query_res_types: dict[int, np.ndarray] = {}
|
| 1165 |
+
for chain in chains:
|
| 1166 |
+
chain_tokens = [t for t in tokens if t.asym_id == chain.asym_id]
|
| 1167 |
+
chain_query_res_types[chain.asym_id] = np.array(
|
| 1168 |
+
[t.res_type for t in chain_tokens], dtype=np.int64
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
token_asym_ids = np.array([t.asym_id for t in tokens], dtype=np.int64)
|
| 1172 |
+
token_res_ids = np.array([t.residue_index for t in tokens], dtype=np.int64)
|
| 1173 |
+
|
| 1174 |
+
msa_res, del_counts, paired = construct_paired_msa(
|
| 1175 |
+
chain_msas,
|
| 1176 |
+
chain_query_res_types,
|
| 1177 |
+
token_asym_ids,
|
| 1178 |
+
token_res_ids,
|
| 1179 |
+
letter_to_res_type=letter_to_res_type,
|
| 1180 |
+
max_seqs=max_seqs,
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
# Tokens for chains without an MSA get their res_type at row 0 and gap
|
| 1184 |
+
# elsewhere; this mirrors the prior non-protein-token branch.
|
| 1185 |
+
for t in tokens:
|
| 1186 |
+
if chain_msas.get(t.asym_id) is None:
|
| 1187 |
+
msa_res[:, t.token_index] = MSA_GAP_TOKEN_ID
|
| 1188 |
+
msa_res[0, t.token_index] = t.res_type
|
| 1189 |
+
|
| 1190 |
+
if msa_res.shape[0] == 0:
|
| 1191 |
+
msa_res = np.full((1, n_tokens), MSA_GAP_TOKEN_ID, dtype=np.int64)
|
| 1192 |
+
del_counts = np.zeros((1, n_tokens), dtype=np.float32)
|
| 1193 |
+
|
| 1194 |
+
msa_data = torch.from_numpy(msa_res)
|
| 1195 |
+
del_data = torch.from_numpy(del_counts)
|
| 1196 |
+
|
| 1197 |
+
has_deletion = del_data > 0
|
| 1198 |
+
deletion_value = (np.pi / 2) * torch.arctan(del_data / 3)
|
| 1199 |
+
deletion_mean = deletion_value.mean(dim=0)
|
| 1200 |
+
|
| 1201 |
+
msa_mask = torch.ones_like(msa_data, dtype=torch.bool)
|
| 1202 |
+
|
| 1203 |
+
return {
|
| 1204 |
+
"msa": msa_data,
|
| 1205 |
+
"deletion_value": deletion_value,
|
| 1206 |
+
"has_deletion": has_deletion,
|
| 1207 |
+
"deletion_mean": deletion_mean,
|
| 1208 |
+
"msa_attention_mask": msa_mask,
|
| 1209 |
+
}
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def compute_distogram_conditioning(
|
| 1213 |
+
input: StructurePredictionInput,
|
| 1214 |
+
chains: list[ChainInfo],
|
| 1215 |
+
tokens: list[TokenInfo],
|
| 1216 |
+
disto_center: torch.Tensor,
|
| 1217 |
+
min_dist: float = 2.0,
|
| 1218 |
+
max_dist: float = 22.0,
|
| 1219 |
+
num_bins: int = 64,
|
| 1220 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1221 |
+
"""Compute distogram conditioning from user-provided distograms.
|
| 1222 |
+
|
| 1223 |
+
Returns:
|
| 1224 |
+
disto_cond: [L, L] int64 (bin indices)
|
| 1225 |
+
disto_cond_mask: [L, L] bool
|
| 1226 |
+
"""
|
| 1227 |
+
n_tokens = len(tokens)
|
| 1228 |
+
disto_cond = torch.zeros(n_tokens, n_tokens, dtype=torch.long)
|
| 1229 |
+
disto_cond_mask = torch.zeros(n_tokens, n_tokens, dtype=torch.bool)
|
| 1230 |
+
|
| 1231 |
+
if not input.distogram_conditioning:
|
| 1232 |
+
return disto_cond, disto_cond_mask
|
| 1233 |
+
|
| 1234 |
+
# Build chain_id -> asym_id mapping
|
| 1235 |
+
chain_id_to_asym: dict[str, int] = {c.chain_id: c.asym_id for c in chains}
|
| 1236 |
+
|
| 1237 |
+
# Build asym_id -> token indices mapping
|
| 1238 |
+
asym_to_tokens: dict[int, list[int]] = defaultdict(list)
|
| 1239 |
+
for t in tokens:
|
| 1240 |
+
asym_to_tokens[t.asym_id].append(t.token_index)
|
| 1241 |
+
|
| 1242 |
+
boundaries = torch.linspace(min_dist, max_dist, num_bins + 1)
|
| 1243 |
+
|
| 1244 |
+
for dc in input.distogram_conditioning:
|
| 1245 |
+
asym_id_val = chain_id_to_asym.get(dc.chain_id)
|
| 1246 |
+
if asym_id_val is None:
|
| 1247 |
+
continue
|
| 1248 |
+
tok_indices = asym_to_tokens[asym_id_val]
|
| 1249 |
+
n_chain = len(tok_indices)
|
| 1250 |
+
distogram = torch.tensor(dc.distogram, dtype=torch.float32)
|
| 1251 |
+
|
| 1252 |
+
if distogram.shape != (n_chain, n_chain):
|
| 1253 |
+
raise ValueError(
|
| 1254 |
+
f"Distogram shape {distogram.shape} doesn't match chain length {n_chain}"
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
# Bin the distogram
|
| 1258 |
+
binned = torch.bucketize(distogram, boundaries[:-1]) - 1
|
| 1259 |
+
binned = binned.clamp(0, num_bins - 1)
|
| 1260 |
+
|
| 1261 |
+
for i, ti in enumerate(tok_indices):
|
| 1262 |
+
for j, tj in enumerate(tok_indices):
|
| 1263 |
+
disto_cond[ti, tj] = binned[i, j]
|
| 1264 |
+
disto_cond_mask[ti, tj] = True
|
| 1265 |
+
|
| 1266 |
+
return disto_cond, disto_cond_mask
|
| 1267 |
+
|
| 1268 |
+
|
| 1269 |
+
def build_feature_tensors(
|
| 1270 |
+
chains: list[ChainInfo],
|
| 1271 |
+
tokens: list[TokenInfo],
|
| 1272 |
+
atoms: list[AtomInfo],
|
| 1273 |
+
input: StructurePredictionInput,
|
| 1274 |
+
) -> dict[str, torch.Tensor]:
|
| 1275 |
+
"""Build all model input tensors from tokens and atoms."""
|
| 1276 |
+
n_tokens = len(tokens)
|
| 1277 |
+
n_real_atoms = len(atoms)
|
| 1278 |
+
|
| 1279 |
+
# Pad atoms to nearest multiple of 32
|
| 1280 |
+
target_atoms = math.ceil(n_real_atoms / 32) * 32 if n_real_atoms > 0 else 32
|
| 1281 |
+
n_padding = target_atoms - n_real_atoms
|
| 1282 |
+
padding_atoms = [
|
| 1283 |
+
AtomInfo(
|
| 1284 |
+
name="",
|
| 1285 |
+
element="",
|
| 1286 |
+
charge=0,
|
| 1287 |
+
ref_pos=_ZERO_POS.copy(),
|
| 1288 |
+
pos=_ZERO_POS.copy(),
|
| 1289 |
+
token_index=0,
|
| 1290 |
+
atom_index=n_real_atoms + i,
|
| 1291 |
+
space_uid=0,
|
| 1292 |
+
is_valid=False,
|
| 1293 |
+
)
|
| 1294 |
+
for i in range(n_padding)
|
| 1295 |
+
]
|
| 1296 |
+
all_atoms = atoms + padding_atoms
|
| 1297 |
+
n_atoms = len(all_atoms)
|
| 1298 |
+
|
| 1299 |
+
# --- Token-level tensors ---
|
| 1300 |
+
token_index_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1301 |
+
residue_index_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1302 |
+
asym_id_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1303 |
+
sym_id_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1304 |
+
entity_id_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1305 |
+
mol_type_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1306 |
+
res_type_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1307 |
+
input_ids_arr = np.empty(n_tokens, dtype=np.int64)
|
| 1308 |
+
|
| 1309 |
+
for i, t in enumerate(tokens):
|
| 1310 |
+
token_index_arr[i] = t.token_index
|
| 1311 |
+
residue_index_arr[i] = t.residue_index
|
| 1312 |
+
asym_id_arr[i] = t.asym_id
|
| 1313 |
+
sym_id_arr[i] = t.sym_id
|
| 1314 |
+
entity_id_arr[i] = t.entity_id
|
| 1315 |
+
mol_type_arr[i] = t.mol_type
|
| 1316 |
+
res_type_arr[i] = t.res_type
|
| 1317 |
+
input_ids_arr[i] = t.input_id
|
| 1318 |
+
|
| 1319 |
+
token_index = torch.from_numpy(token_index_arr)
|
| 1320 |
+
residue_index = torch.from_numpy(residue_index_arr)
|
| 1321 |
+
asym_id = torch.from_numpy(asym_id_arr)
|
| 1322 |
+
sym_id = torch.from_numpy(sym_id_arr)
|
| 1323 |
+
entity_id = torch.from_numpy(entity_id_arr)
|
| 1324 |
+
mol_type = torch.from_numpy(mol_type_arr)
|
| 1325 |
+
res_type = torch.from_numpy(res_type_arr)
|
| 1326 |
+
input_ids = torch.from_numpy(input_ids_arr)
|
| 1327 |
+
token_pad_mask = torch.ones(n_tokens, dtype=torch.bool)
|
| 1328 |
+
|
| 1329 |
+
# --- Atom-level tensors ---
|
| 1330 |
+
ref_pos_arr = np.zeros((n_atoms, 3), dtype=np.float32)
|
| 1331 |
+
ref_element_arr = np.zeros(n_atoms, dtype=np.int64)
|
| 1332 |
+
ref_charge_arr = np.zeros(n_atoms, dtype=np.int8)
|
| 1333 |
+
ref_atom_name_chars_arr = np.zeros((n_atoms, 4), dtype=np.int64)
|
| 1334 |
+
ref_space_uid_arr = np.zeros(n_atoms, dtype=np.int64)
|
| 1335 |
+
atom_pad_mask_arr = np.zeros(n_atoms, dtype=np.bool_)
|
| 1336 |
+
atom_to_token_arr = np.zeros(n_atoms, dtype=np.int64)
|
| 1337 |
+
all_positions = np.zeros((n_atoms, 3), dtype=np.float64)
|
| 1338 |
+
is_valid_arr = np.zeros(n_atoms, dtype=np.bool_)
|
| 1339 |
+
|
| 1340 |
+
for i, atom in enumerate(all_atoms):
|
| 1341 |
+
if atom.ref_pos is not None:
|
| 1342 |
+
ref_pos_arr[i] = atom.ref_pos
|
| 1343 |
+
ref_charge_arr[i] = atom.charge
|
| 1344 |
+
ref_space_uid_arr[i] = (
|
| 1345 |
+
atom.space_uid if atom.space_uid >= 0 else atom.token_index
|
| 1346 |
+
)
|
| 1347 |
+
atom_pad_mask_arr[i] = atom.is_valid
|
| 1348 |
+
is_valid_arr[i] = atom.is_valid
|
| 1349 |
+
all_positions[i] = atom.pos
|
| 1350 |
+
|
| 1351 |
+
if atom.is_valid:
|
| 1352 |
+
ref_element_arr[i] = get_element_atomic_num(atom.element)
|
| 1353 |
+
name_indices = encode_atom_name(atom.name)
|
| 1354 |
+
ref_atom_name_chars_arr[i] = name_indices
|
| 1355 |
+
atom_to_token_arr[i] = atom.token_index
|
| 1356 |
+
|
| 1357 |
+
ref_pos = torch.from_numpy(ref_pos_arr)
|
| 1358 |
+
ref_element = torch.from_numpy(ref_element_arr)
|
| 1359 |
+
ref_charge = torch.from_numpy(ref_charge_arr)
|
| 1360 |
+
ref_atom_name_chars = torch.from_numpy(ref_atom_name_chars_arr)
|
| 1361 |
+
ref_space_uid = torch.from_numpy(ref_space_uid_arr)
|
| 1362 |
+
atom_pad_mask = torch.from_numpy(atom_pad_mask_arr)
|
| 1363 |
+
atom_to_token = torch.from_numpy(atom_to_token_arr)
|
| 1364 |
+
|
| 1365 |
+
# Coordinates — center on resolved atoms
|
| 1366 |
+
raw_coords = torch.from_numpy(all_positions)
|
| 1367 |
+
is_nonzero = np.any(all_positions != 0, axis=1)
|
| 1368 |
+
atom_resolved_arr = is_valid_arr & is_nonzero
|
| 1369 |
+
resolved_mask = torch.from_numpy(atom_resolved_arr)
|
| 1370 |
+
valid_mask = torch.from_numpy(is_valid_arr)
|
| 1371 |
+
|
| 1372 |
+
if resolved_mask.any():
|
| 1373 |
+
centroid = raw_coords[resolved_mask].mean(dim=0, keepdim=True)
|
| 1374 |
+
raw_coords = raw_coords - centroid
|
| 1375 |
+
raw_coords[~valid_mask] = 0.0
|
| 1376 |
+
|
| 1377 |
+
coords = raw_coords.float().unsqueeze(0) # [1, A, 3]
|
| 1378 |
+
atom_resolved_mask = torch.tensor(atom_resolved_arr, dtype=torch.bool)
|
| 1379 |
+
|
| 1380 |
+
# --- Frames ---
|
| 1381 |
+
frames, _ = compute_frame_indices(tokens, atoms)
|
| 1382 |
+
frames_idx = torch.from_numpy(frames).to(torch.int64)
|
| 1383 |
+
|
| 1384 |
+
# --- Token bonds ---
|
| 1385 |
+
token_bonds = compute_token_bonds(tokens, atoms, input, chains)
|
| 1386 |
+
|
| 1387 |
+
# --- Representative atoms ---
|
| 1388 |
+
distogram_atom_idx = compute_representative_atoms(tokens, atoms)
|
| 1389 |
+
|
| 1390 |
+
# --- MSA features ---
|
| 1391 |
+
msa_features = compute_msa_features(input, chains, tokens)
|
| 1392 |
+
|
| 1393 |
+
# --- Distogram conditioning ---
|
| 1394 |
+
# disto_center is not needed for inference (no experimental coords)
|
| 1395 |
+
disto_center = torch.zeros(n_tokens, 3, dtype=torch.float32)
|
| 1396 |
+
disto_cond, disto_cond_mask = compute_distogram_conditioning(
|
| 1397 |
+
input, chains, tokens, disto_center
|
| 1398 |
+
)
|
| 1399 |
+
|
| 1400 |
+
# ref_pos: CCD conformer positions, used as-is for inference.
|
| 1401 |
+
# No random rotation or masking — at inference there are no resolved
|
| 1402 |
+
# experimental coordinates, so atom_resolved_mask is all False.
|
| 1403 |
+
# The model uses ref_pos for atom feature embedding.
|
| 1404 |
+
|
| 1405 |
+
# --- Pocket (dropped) ---
|
| 1406 |
+
pocket_feature = torch.zeros(n_tokens, dtype=torch.long)
|
| 1407 |
+
|
| 1408 |
+
return {
|
| 1409 |
+
# Token-level
|
| 1410 |
+
"token_index": token_index,
|
| 1411 |
+
"residue_index": residue_index,
|
| 1412 |
+
"asym_id": asym_id,
|
| 1413 |
+
"entity_id": entity_id,
|
| 1414 |
+
"sym_id": sym_id,
|
| 1415 |
+
"mol_type": mol_type,
|
| 1416 |
+
"res_type": res_type,
|
| 1417 |
+
"input_ids": input_ids,
|
| 1418 |
+
"token_bonds": token_bonds,
|
| 1419 |
+
"token_attention_mask": token_pad_mask,
|
| 1420 |
+
"pocket_feature": pocket_feature,
|
| 1421 |
+
# Atom-level
|
| 1422 |
+
"ref_pos": ref_pos,
|
| 1423 |
+
"ref_element": ref_element,
|
| 1424 |
+
"ref_charge": ref_charge,
|
| 1425 |
+
"ref_atom_name_chars": ref_atom_name_chars,
|
| 1426 |
+
"ref_space_uid": ref_space_uid,
|
| 1427 |
+
"gt_coords": coords,
|
| 1428 |
+
"atom_attention_mask": atom_pad_mask,
|
| 1429 |
+
"atom_to_token": atom_to_token,
|
| 1430 |
+
"is_resolved": atom_resolved_mask,
|
| 1431 |
+
"distogram_atom_idx": distogram_atom_idx,
|
| 1432 |
+
# Frames
|
| 1433 |
+
"frames_idx": frames_idx,
|
| 1434 |
+
# Distogram
|
| 1435 |
+
"disto_cond": disto_cond,
|
| 1436 |
+
"disto_cond_mask": disto_cond_mask,
|
| 1437 |
+
# MSA
|
| 1438 |
+
**msa_features,
|
| 1439 |
+
}
|
| 1440 |
+
|
| 1441 |
+
|
| 1442 |
+
# =============================================================================
|
| 1443 |
+
# Top-level entry point
|
| 1444 |
+
# =============================================================================
|
| 1445 |
+
|
| 1446 |
+
|
| 1447 |
+
def prepare_esmfold2_input(
|
| 1448 |
+
input: StructurePredictionInput, seed: int | None = None
|
| 1449 |
+
) -> tuple[dict[str, torch.Tensor], list[ChainInfo]]:
|
| 1450 |
+
"""Prepare ESMFold2 model inputs from StructurePredictionInput.
|
| 1451 |
+
|
| 1452 |
+
Args:
|
| 1453 |
+
input: The structure prediction input (sequences, conditioning, etc.)
|
| 1454 |
+
seed: Random seed for SMILES conformer generation and augmentation.
|
| 1455 |
+
|
| 1456 |
+
Returns:
|
| 1457 |
+
Tuple of (feature_dict, chain_infos) where feature_dict contains
|
| 1458 |
+
all tensors for the model forward pass, and chain_infos contains
|
| 1459 |
+
metadata for output processing.
|
| 1460 |
+
"""
|
| 1461 |
+
chains, tokens, atoms = build_chains_from_input(input, seed)
|
| 1462 |
+
features = build_feature_tensors(chains, tokens, atoms, input)
|
| 1463 |
+
return features, chains
|
| 1464 |
+
|
esmfold2_processor.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from contextlib import contextmanager, nullcontext
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .esmfold2_conformers import load_ccd
|
| 10 |
+
from .esmfold2_output import build_molecular_complex_from_features
|
| 11 |
+
from .esmfold2_prepare_input import ChainInfo, prepare_esmfold2_input
|
| 12 |
+
from .esmfold2_types import (
|
| 13 |
+
MSA,
|
| 14 |
+
Modification,
|
| 15 |
+
ProteinInput,
|
| 16 |
+
StructurePredictionInput,
|
| 17 |
+
)
|
| 18 |
+
from .esmfold2_molecular_complex import MolecularComplexResult
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@contextmanager
|
| 22 |
+
def _seed_context(seed: int | None):
|
| 23 |
+
if seed is None:
|
| 24 |
+
yield
|
| 25 |
+
return
|
| 26 |
+
py_state = random.getstate()
|
| 27 |
+
np_state = np.random.get_state()
|
| 28 |
+
torch_state = torch.random.get_rng_state()
|
| 29 |
+
cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
|
| 30 |
+
random.seed(seed)
|
| 31 |
+
np.random.seed(seed)
|
| 32 |
+
torch.manual_seed(seed)
|
| 33 |
+
if torch.cuda.is_available():
|
| 34 |
+
torch.cuda.manual_seed_all(seed)
|
| 35 |
+
try:
|
| 36 |
+
yield
|
| 37 |
+
finally:
|
| 38 |
+
random.setstate(py_state)
|
| 39 |
+
np.random.set_state(np_state)
|
| 40 |
+
torch.random.set_rng_state(torch_state)
|
| 41 |
+
if cuda_state is not None:
|
| 42 |
+
torch.cuda.set_rng_state_all(cuda_state)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def clean_esmfold2_input(input: StructurePredictionInput) -> StructurePredictionInput:
|
| 46 |
+
"""Group identical protein sequences into the same ProteinInput with multiple ids.
|
| 47 |
+
|
| 48 |
+
Example: Passing a tetramer like [ProteinInput(id=["0"], seq="AAA|AAA|BBB|BBB")]
|
| 49 |
+
gets converted into [ProteinInput(id=["0_0", "0_1"], seq="AAA"),
|
| 50 |
+
ProteinInput(id=["0_2", "0_3"], seq="BBB")]
|
| 51 |
+
|
| 52 |
+
Preserves the original order of unique sequences. Also converts "|" chainbreak
|
| 53 |
+
tokens to ":" in the sequence.
|
| 54 |
+
"""
|
| 55 |
+
cleaned_sequences: list = []
|
| 56 |
+
chain_to_ids: dict[str, list[str]] = {}
|
| 57 |
+
chain_to_modifications: dict[str, list] = {}
|
| 58 |
+
chain_to_msa: dict[str, MSA | None] = {}
|
| 59 |
+
|
| 60 |
+
for item in input.sequences:
|
| 61 |
+
if isinstance(item, ProteinInput):
|
| 62 |
+
sequence = ":".join(item.sequence.split("|"))
|
| 63 |
+
if ":" not in sequence:
|
| 64 |
+
cleaned_sequences.append(item)
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
if ":" in sequence and input.covalent_bonds is not None:
|
| 68 |
+
raise ValueError(
|
| 69 |
+
"Covalent bonds are not supported when using chainbreaks. "
|
| 70 |
+
"Chains must be separated into multiple ProteinInput objects."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
base_id = item.id[0] if isinstance(item.id, list) else item.id
|
| 74 |
+
chain_to_ids = {}
|
| 75 |
+
chain_to_modifications = {}
|
| 76 |
+
chain_to_msa = {}
|
| 77 |
+
chains = sequence.split(":")
|
| 78 |
+
|
| 79 |
+
chain_start_positions = []
|
| 80 |
+
pos = 0
|
| 81 |
+
for chain in chains:
|
| 82 |
+
chain_start_positions.append(pos)
|
| 83 |
+
pos += len(chain) + 1
|
| 84 |
+
|
| 85 |
+
if item.modifications is not None:
|
| 86 |
+
for chain_idx, chain in enumerate(chains):
|
| 87 |
+
chain_start = chain_start_positions[chain_idx]
|
| 88 |
+
chain_end = chain_start + len(chain)
|
| 89 |
+
chain_modifications = []
|
| 90 |
+
for mod in item.modifications:
|
| 91 |
+
if chain_start <= mod.position < chain_end:
|
| 92 |
+
adjusted_mod = Modification(
|
| 93 |
+
position=mod.position - chain_start, ccd=mod.ccd
|
| 94 |
+
)
|
| 95 |
+
chain_modifications.append(adjusted_mod)
|
| 96 |
+
if chain not in chain_to_modifications:
|
| 97 |
+
chain_to_modifications[chain] = chain_modifications
|
| 98 |
+
else:
|
| 99 |
+
chain_to_modifications[chain].extend(chain_modifications)
|
| 100 |
+
|
| 101 |
+
if item.msa is not None:
|
| 102 |
+
for chain_idx, chain in enumerate(chains):
|
| 103 |
+
if chain not in chain_to_msa:
|
| 104 |
+
chain_start = chain_start_positions[chain_idx]
|
| 105 |
+
chain_end = chain_start + len(chain)
|
| 106 |
+
chain_msa = item.msa.select_positions( # type: ignore
|
| 107 |
+
np.arange(chain_start, chain_end)
|
| 108 |
+
)
|
| 109 |
+
chain_to_msa[chain] = chain_msa
|
| 110 |
+
|
| 111 |
+
for i, chain in enumerate(chains):
|
| 112 |
+
chain_id = base_id + "_" + str(i)
|
| 113 |
+
if chain in chain_to_ids:
|
| 114 |
+
chain_to_ids[chain].append(chain_id)
|
| 115 |
+
else:
|
| 116 |
+
chain_to_ids[chain] = [chain_id]
|
| 117 |
+
cleaned_sequences.append((item, chain))
|
| 118 |
+
else:
|
| 119 |
+
cleaned_sequences.append(item)
|
| 120 |
+
|
| 121 |
+
for i in range(len(cleaned_sequences)):
|
| 122 |
+
if isinstance(cleaned_sequences[i], tuple):
|
| 123 |
+
item, chain = cleaned_sequences[i]
|
| 124 |
+
chain_ids = chain_to_ids[chain]
|
| 125 |
+
chain_modifications = (
|
| 126 |
+
chain_to_modifications.get(chain) if item.modifications else None
|
| 127 |
+
)
|
| 128 |
+
chain_msa = chain_to_msa.get(chain) if item.msa else None
|
| 129 |
+
cleaned_sequences[i] = ProteinInput(
|
| 130 |
+
id=chain_ids,
|
| 131 |
+
sequence=chain,
|
| 132 |
+
msa=chain_msa,
|
| 133 |
+
modifications=chain_modifications,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
return StructurePredictionInput(
|
| 137 |
+
sequences=cleaned_sequences,
|
| 138 |
+
distogram_conditioning=input.distogram_conditioning,
|
| 139 |
+
covalent_bonds=input.covalent_bonds,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ESMFold2InputBuilder:
|
| 144 |
+
def __init__(self, ccd_cache: Path | None = None):
|
| 145 |
+
load_ccd(ccd_cache)
|
| 146 |
+
|
| 147 |
+
def prepare_input(
|
| 148 |
+
self,
|
| 149 |
+
input: StructurePredictionInput,
|
| 150 |
+
seed: int | None = None,
|
| 151 |
+
device: torch.device | str | None = None,
|
| 152 |
+
) -> tuple[dict, list[ChainInfo]]:
|
| 153 |
+
"""Prepare raw input for the folding model.
|
| 154 |
+
|
| 155 |
+
Converts user-provided StructurePredictionInput into batched tensors
|
| 156 |
+
ready for model inference.
|
| 157 |
+
|
| 158 |
+
Parameters
|
| 159 |
+
----------
|
| 160 |
+
input : StructurePredictionInput
|
| 161 |
+
Input specification (sequences, structures, constraints, etc.).
|
| 162 |
+
seed : int, optional
|
| 163 |
+
Random seed for reproducibility.
|
| 164 |
+
device : torch.device or str, optional
|
| 165 |
+
Target device for the returned tensors. Defaults to CPU; pass
|
| 166 |
+
``model.device`` to skip a separate ``.to(...)`` step. ``fold()``
|
| 167 |
+
forwards ``model.device`` automatically.
|
| 168 |
+
|
| 169 |
+
Returns
|
| 170 |
+
-------
|
| 171 |
+
tuple[dict, list[ChainInfo]]
|
| 172 |
+
Batched input tensors and chain metadata for output processing.
|
| 173 |
+
"""
|
| 174 |
+
structure_prediction_input = clean_esmfold2_input(input)
|
| 175 |
+
with _seed_context(seed) if seed is not None else nullcontext():
|
| 176 |
+
features, chain_infos = prepare_esmfold2_input(
|
| 177 |
+
structure_prediction_input, seed=seed
|
| 178 |
+
)
|
| 179 |
+
features = {
|
| 180 |
+
k: (v[None].to(device) if device is not None else v[None])
|
| 181 |
+
if isinstance(v, torch.Tensor)
|
| 182 |
+
else v
|
| 183 |
+
for k, v in features.items()
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
return features, chain_infos
|
| 187 |
+
|
| 188 |
+
def __call__(
|
| 189 |
+
self,
|
| 190 |
+
input: StructurePredictionInput,
|
| 191 |
+
seed: int | None = None,
|
| 192 |
+
device: torch.device | str | None = None,
|
| 193 |
+
) -> tuple[dict, list[ChainInfo]]:
|
| 194 |
+
return self.prepare_input(input, seed=seed, device=device)
|
| 195 |
+
|
| 196 |
+
def decode(
|
| 197 |
+
self,
|
| 198 |
+
output: dict[str, torch.Tensor],
|
| 199 |
+
features: dict[str, torch.Tensor],
|
| 200 |
+
chain_infos: list[ChainInfo],
|
| 201 |
+
*,
|
| 202 |
+
num_diffusion_samples: int = 1,
|
| 203 |
+
complex_id: str = "pred",
|
| 204 |
+
) -> MolecularComplexResult | list[MolecularComplexResult]:
|
| 205 |
+
"""Convert raw model outputs into one MolecularComplexResult per sample.
|
| 206 |
+
|
| 207 |
+
Parameters
|
| 208 |
+
----------
|
| 209 |
+
output : dict[str, Tensor]
|
| 210 |
+
Output dict returned by ESMFold2Model.forward.
|
| 211 |
+
features : dict[str, Tensor]
|
| 212 |
+
Feature dict from :meth:`prepare_input` (batched, on the model device).
|
| 213 |
+
chain_infos : list[ChainInfo]
|
| 214 |
+
Chain metadata returned alongside `features`.
|
| 215 |
+
num_diffusion_samples : int
|
| 216 |
+
Number of diffusion samples present in the output (Bm = B * num_diffusion_samples).
|
| 217 |
+
complex_id : str
|
| 218 |
+
Identifier assigned to each MolecularComplex.
|
| 219 |
+
|
| 220 |
+
Returns
|
| 221 |
+
-------
|
| 222 |
+
MolecularComplexResult or list[MolecularComplexResult]
|
| 223 |
+
A single result when num_diffusion_samples == 1, otherwise a list of length Bm.
|
| 224 |
+
"""
|
| 225 |
+
atom_mask = features["atom_attention_mask"][0]
|
| 226 |
+
ref_element = features["ref_element"][0]
|
| 227 |
+
ref_atom_name_chars = features["ref_atom_name_chars"][0]
|
| 228 |
+
|
| 229 |
+
sample_coords = output["sample_atom_coords"]
|
| 230 |
+
plddts = output["plddt"]
|
| 231 |
+
Bm = sample_coords.shape[0]
|
| 232 |
+
|
| 233 |
+
ptm_t = output.get("ptm")
|
| 234 |
+
iptm_t = output.get("iptm")
|
| 235 |
+
pae_t = output.get("pae")
|
| 236 |
+
distogram_t = output.get("distogram_logits")
|
| 237 |
+
pair_chains_t = output.get("pair_chains_iptm")
|
| 238 |
+
residue_index_t = output.get("residue_index")
|
| 239 |
+
entity_id_t = output.get("entity_id")
|
| 240 |
+
|
| 241 |
+
results: list[MolecularComplexResult] = []
|
| 242 |
+
for i in range(Bm):
|
| 243 |
+
mc = build_molecular_complex_from_features(
|
| 244 |
+
coords=sample_coords[i],
|
| 245 |
+
plddt=plddts[i],
|
| 246 |
+
atom_mask=atom_mask,
|
| 247 |
+
ref_element=ref_element,
|
| 248 |
+
ref_atom_name_chars=ref_atom_name_chars,
|
| 249 |
+
chain_infos=chain_infos,
|
| 250 |
+
complex_id=complex_id,
|
| 251 |
+
)
|
| 252 |
+
results.append(
|
| 253 |
+
MolecularComplexResult(
|
| 254 |
+
complex=mc,
|
| 255 |
+
plddt=plddts[i].detach().cpu(),
|
| 256 |
+
ptm=float(ptm_t[i].item()) if ptm_t is not None else None,
|
| 257 |
+
iptm=float(iptm_t[i].item()) if iptm_t is not None else None,
|
| 258 |
+
pae=pae_t[i].detach().cpu() if pae_t is not None else None,
|
| 259 |
+
distogram=(
|
| 260 |
+
distogram_t[0].detach().cpu()
|
| 261 |
+
if distogram_t is not None
|
| 262 |
+
else None
|
| 263 |
+
),
|
| 264 |
+
pair_chains_iptm=(
|
| 265 |
+
pair_chains_t[i].detach().cpu()
|
| 266 |
+
if pair_chains_t is not None
|
| 267 |
+
else None
|
| 268 |
+
),
|
| 269 |
+
residue_index=(
|
| 270 |
+
residue_index_t[0].detach().cpu()
|
| 271 |
+
if residue_index_t is not None
|
| 272 |
+
else None
|
| 273 |
+
),
|
| 274 |
+
entity_id=(
|
| 275 |
+
entity_id_t[0].detach().cpu()
|
| 276 |
+
if entity_id_t is not None
|
| 277 |
+
else None
|
| 278 |
+
),
|
| 279 |
+
)
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
if num_diffusion_samples == 1 and len(results) == 1:
|
| 283 |
+
return results[0]
|
| 284 |
+
return results
|
| 285 |
+
|
| 286 |
+
def fold(
|
| 287 |
+
self,
|
| 288 |
+
model: Any,
|
| 289 |
+
input: StructurePredictionInput,
|
| 290 |
+
*,
|
| 291 |
+
num_loops: int = 3,
|
| 292 |
+
num_sampling_steps: int = 200,
|
| 293 |
+
num_diffusion_samples: int = 1,
|
| 294 |
+
seed: int | None = None,
|
| 295 |
+
noise_scale: float | None = None,
|
| 296 |
+
step_scale: float | None = None,
|
| 297 |
+
max_inference_sigma: int | None = None,
|
| 298 |
+
early_exit: bool = False,
|
| 299 |
+
complex_id: str = "pred",
|
| 300 |
+
) -> MolecularComplexResult | list[MolecularComplexResult]:
|
| 301 |
+
"""Fold a structure end-to-end: encode → model → decode.
|
| 302 |
+
|
| 303 |
+
Parameters
|
| 304 |
+
----------
|
| 305 |
+
model : ESMFold2Model
|
| 306 |
+
The folding model. Must already be on the target device and in eval mode.
|
| 307 |
+
input : StructurePredictionInput
|
| 308 |
+
User-facing input specification.
|
| 309 |
+
num_loops, num_sampling_steps, num_diffusion_samples : int
|
| 310 |
+
Inference knobs forwarded to the model.
|
| 311 |
+
seed : int, optional
|
| 312 |
+
Seeds both input prep (SMILES conformer generation) and diffusion sampling.
|
| 313 |
+
noise_scale, step_scale, max_inference_sigma, early_exit
|
| 314 |
+
Optional sampler overrides forwarded to the model when not None.
|
| 315 |
+
complex_id : str
|
| 316 |
+
Identifier assigned to the predicted MolecularComplex(es).
|
| 317 |
+
|
| 318 |
+
Returns
|
| 319 |
+
-------
|
| 320 |
+
MolecularComplexResult or list[MolecularComplexResult]
|
| 321 |
+
A single result when num_diffusion_samples == 1, otherwise a list.
|
| 322 |
+
"""
|
| 323 |
+
features, chain_infos = self.prepare_input(
|
| 324 |
+
input, seed=seed, device=model.device
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
sampler_kwargs: dict[str, Any] = {}
|
| 328 |
+
if noise_scale is not None:
|
| 329 |
+
sampler_kwargs["noise_scale"] = noise_scale
|
| 330 |
+
if step_scale is not None:
|
| 331 |
+
sampler_kwargs["step_scale"] = step_scale
|
| 332 |
+
if max_inference_sigma is not None:
|
| 333 |
+
sampler_kwargs["max_inference_sigma"] = max_inference_sigma
|
| 334 |
+
|
| 335 |
+
with torch.no_grad():
|
| 336 |
+
with _seed_context(seed) if seed is not None else nullcontext():
|
| 337 |
+
output = model(
|
| 338 |
+
**features,
|
| 339 |
+
num_loops=num_loops,
|
| 340 |
+
num_sampling_steps=num_sampling_steps,
|
| 341 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 342 |
+
early_exit=early_exit,
|
| 343 |
+
**sampler_kwargs,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
return self.decode(
|
| 347 |
+
output,
|
| 348 |
+
features,
|
| 349 |
+
chain_infos,
|
| 350 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 351 |
+
complex_id=complex_id,
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
__all__ = ["ESMFold2InputBuilder", "clean_esmfold2_input"]
|
| 356 |
+
|
esmfold2_protein_chain.py
ADDED
|
@@ -0,0 +1,1376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import warnings
|
| 5 |
+
from dataclasses import asdict, dataclass, replace
|
| 6 |
+
from functools import cached_property
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Mapping, Sequence
|
| 9 |
+
|
| 10 |
+
import biotite.structure as bs
|
| 11 |
+
import brotli
|
| 12 |
+
import msgpack
|
| 13 |
+
import msgpack_numpy
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from biotite.database import rcsb
|
| 17 |
+
from biotite.structure.io.pdb import PDBFile
|
| 18 |
+
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
|
| 19 |
+
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
|
| 20 |
+
from scipy.spatial import ConvexHull, KDTree
|
| 21 |
+
from scipy.spatial.distance import cdist, pdist, squareform
|
| 22 |
+
|
| 23 |
+
from . import esmfold2_residue_constants
|
| 24 |
+
from .esmfold2_misc import slice_python_object_as_numpy
|
| 25 |
+
from .esmfold2_affine3d import Affine3D
|
| 26 |
+
from .esmfold2_aligner import Aligner
|
| 27 |
+
from .esmfold2_atom_indexer import AtomIndexer
|
| 28 |
+
from .esmfold2_metrics import compute_gdt_ts, compute_lddt_ca
|
| 29 |
+
from .esmfold2_mmcif_parsing import MmcifWrapper, Residue
|
| 30 |
+
from .esmfold2_normalize_coordinates import (
|
| 31 |
+
apply_frame_to_coords,
|
| 32 |
+
get_protein_normalization_frame,
|
| 33 |
+
)
|
| 34 |
+
from .esmfold2_protein_structure import index_by_atom_name
|
| 35 |
+
from .esmfold2_utils_types import PathOrBuffer
|
| 36 |
+
|
| 37 |
+
msgpack_numpy.patch()
|
| 38 |
+
CHAIN_ID_CONST = "A"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict:
|
| 42 |
+
new_dict = {}
|
| 43 |
+
for k, v in dct.items():
|
| 44 |
+
v_new = v
|
| 45 |
+
if k not in ignore_keys and isinstance(v, dict):
|
| 46 |
+
v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys)
|
| 47 |
+
# Note assembly_composition is *supposed* to have string keys.
|
| 48 |
+
if isinstance(k, str) and k.isdigit():
|
| 49 |
+
new_dict[int(k)] = v_new
|
| 50 |
+
else:
|
| 51 |
+
new_dict[k] = v_new
|
| 52 |
+
return new_dict
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
|
| 56 |
+
return sum(
|
| 57 |
+
residue.residue_number is not None
|
| 58 |
+
for residue in seqres_to_structure_chain.values()
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
|
| 63 |
+
"""
|
| 64 |
+
Inspired by a util in trDesign:
|
| 65 |
+
https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92
|
| 66 |
+
|
| 67 |
+
input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
|
| 68 |
+
output: 4th coord
|
| 69 |
+
"""
|
| 70 |
+
norm = lambda x: x / np.sqrt(np.square(x).sum(-1, keepdims=True) + 1e-8)
|
| 71 |
+
with np.errstate(invalid="ignore"): # inf - inf = nan is ok here
|
| 72 |
+
vec_bc = N - Ca
|
| 73 |
+
vec_ba = N - C
|
| 74 |
+
bc = norm(vec_bc)
|
| 75 |
+
n = norm(np.cross(vec_ba, bc))
|
| 76 |
+
m = [bc, np.cross(n, bc), n]
|
| 77 |
+
d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
|
| 78 |
+
return Ca + sum([m * d for m, d in zip(m, d)])
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def chain_to_ndarray(
|
| 82 |
+
atom_array: bs.AtomArray, mmcif: MmcifWrapper, chain_id: str, is_predicted=False
|
| 83 |
+
):
|
| 84 |
+
entity_id = None
|
| 85 |
+
for entity, chains in mmcif.entities.items():
|
| 86 |
+
if chain_id in chains:
|
| 87 |
+
entity_id = entity
|
| 88 |
+
num_res = len(mmcif.chain_to_seqres[chain_id])
|
| 89 |
+
sequence = mmcif.chain_to_seqres[chain_id]
|
| 90 |
+
|
| 91 |
+
atom_positions = np.full([num_res, residue_constants.atom_type_num, 3], np.nan)
|
| 92 |
+
atom_mask = np.full([num_res, residue_constants.atom_type_num], False, dtype=bool)
|
| 93 |
+
residue_index = np.full([num_res], -1, dtype=np.int64)
|
| 94 |
+
insertion_code = np.full([num_res], "", dtype="<U4")
|
| 95 |
+
|
| 96 |
+
confidence = np.ones([num_res], dtype=np.float32)
|
| 97 |
+
|
| 98 |
+
for res_index in range(num_res):
|
| 99 |
+
chain = atom_array[atom_array.chain_id == chain_id]
|
| 100 |
+
assert isinstance(chain, bs.AtomArray)
|
| 101 |
+
res_at_position = mmcif.seqres_to_structure[chain_id][res_index]
|
| 102 |
+
|
| 103 |
+
if res_at_position.residue_number is None:
|
| 104 |
+
continue
|
| 105 |
+
|
| 106 |
+
residue_index[res_index] = res_at_position.residue_number
|
| 107 |
+
insertion_code[res_index] = res_at_position.insertion_code
|
| 108 |
+
res = chain[
|
| 109 |
+
(chain.res_id == res_at_position.residue_number)
|
| 110 |
+
& (chain.ins_code == res_at_position.insertion_code)
|
| 111 |
+
& (chain.hetero == res_at_position.hetflag)
|
| 112 |
+
]
|
| 113 |
+
assert isinstance(res, bs.AtomArray)
|
| 114 |
+
|
| 115 |
+
# Atom level features
|
| 116 |
+
for atom in res:
|
| 117 |
+
atom_name = atom.atom_name
|
| 118 |
+
if atom_name == "SE" and atom.res_name == "MSE":
|
| 119 |
+
# Put the coords of the selenium atom in the sulphur column
|
| 120 |
+
atom_name = "SD"
|
| 121 |
+
|
| 122 |
+
if atom_name in residue_constants.atom_order:
|
| 123 |
+
atom_positions[res_index, residue_constants.atom_order[atom_name]] = (
|
| 124 |
+
atom.coord
|
| 125 |
+
)
|
| 126 |
+
atom_mask[res_index, residue_constants.atom_order[atom_name]] = True
|
| 127 |
+
if is_predicted and atom_name == "CA":
|
| 128 |
+
confidence[res_index] = atom.b_factor
|
| 129 |
+
|
| 130 |
+
assert all(sequence), "Some residue name was not specified correctly"
|
| 131 |
+
return (
|
| 132 |
+
sequence,
|
| 133 |
+
atom_positions,
|
| 134 |
+
atom_mask,
|
| 135 |
+
residue_index,
|
| 136 |
+
insertion_code,
|
| 137 |
+
confidence,
|
| 138 |
+
entity_id,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass(frozen=True)
|
| 143 |
+
class ProteinChain:
|
| 144 |
+
"""Dataclass with atom37 representation of a single protein chain."""
|
| 145 |
+
|
| 146 |
+
id: str
|
| 147 |
+
sequence: str
|
| 148 |
+
chain_id: str # author chain id - mutable
|
| 149 |
+
entity_id: int | None
|
| 150 |
+
residue_index: np.ndarray
|
| 151 |
+
insertion_code: np.ndarray
|
| 152 |
+
atom37_positions: np.ndarray
|
| 153 |
+
atom37_mask: np.ndarray
|
| 154 |
+
confidence: np.ndarray
|
| 155 |
+
mmcif: MmcifWrapper | None = None
|
| 156 |
+
atom37_confidence: np.ndarray | None = None # [L, 37] per-atom pLDDT
|
| 157 |
+
|
| 158 |
+
def __post_init__(self):
|
| 159 |
+
assert self.atom37_mask.dtype == bool, self.atom37_mask.dtype
|
| 160 |
+
assert self.atom37_positions.shape[0] == len(self.sequence), (
|
| 161 |
+
self.atom37_positions.shape,
|
| 162 |
+
len(self.sequence),
|
| 163 |
+
)
|
| 164 |
+
assert self.atom37_mask.shape[0] == len(self.sequence), (
|
| 165 |
+
self.atom37_mask.shape,
|
| 166 |
+
len(self.sequence),
|
| 167 |
+
)
|
| 168 |
+
assert self.residue_index.shape[0] == len(self.sequence), (
|
| 169 |
+
self.residue_index.shape,
|
| 170 |
+
len(self.sequence),
|
| 171 |
+
)
|
| 172 |
+
assert self.insertion_code.shape[0] == len(self.sequence), (
|
| 173 |
+
self.insertion_code.shape,
|
| 174 |
+
len(self.sequence),
|
| 175 |
+
)
|
| 176 |
+
assert self.confidence.shape[0] == len(self.sequence), (
|
| 177 |
+
self.confidence.shape,
|
| 178 |
+
len(self.sequence),
|
| 179 |
+
)
|
| 180 |
+
if self.atom37_confidence is not None:
|
| 181 |
+
assert self.atom37_confidence.shape == self.atom37_mask.shape, (
|
| 182 |
+
self.atom37_confidence.shape,
|
| 183 |
+
self.atom37_mask.shape,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
@cached_property
|
| 187 |
+
def atoms(self) -> AtomIndexer:
|
| 188 |
+
return AtomIndexer(self, property="atom37_positions", dim=-2)
|
| 189 |
+
|
| 190 |
+
@cached_property
|
| 191 |
+
def atom_mask(self) -> AtomIndexer:
|
| 192 |
+
return AtomIndexer(self, property="atom37_mask", dim=-1)
|
| 193 |
+
|
| 194 |
+
@cached_property
|
| 195 |
+
def atom_array(self) -> bs.AtomArray:
|
| 196 |
+
atoms = []
|
| 197 |
+
for res_idx_i, (
|
| 198 |
+
res_name,
|
| 199 |
+
res_idx,
|
| 200 |
+
ins_code,
|
| 201 |
+
positions,
|
| 202 |
+
mask,
|
| 203 |
+
conf,
|
| 204 |
+
) in enumerate(
|
| 205 |
+
zip(
|
| 206 |
+
self.sequence,
|
| 207 |
+
self.residue_index,
|
| 208 |
+
self.insertion_code,
|
| 209 |
+
self.atom37_positions,
|
| 210 |
+
self.atom37_mask.astype(bool),
|
| 211 |
+
self.confidence,
|
| 212 |
+
)
|
| 213 |
+
):
|
| 214 |
+
for i, pos in zip(np.where(mask)[0], positions[mask]):
|
| 215 |
+
b_factor = (
|
| 216 |
+
self.atom37_confidence[res_idx_i, i]
|
| 217 |
+
if self.atom37_confidence is not None
|
| 218 |
+
else conf
|
| 219 |
+
)
|
| 220 |
+
atom = bs.Atom(
|
| 221 |
+
coord=pos,
|
| 222 |
+
chain_id="A" if self.chain_id is None else self.chain_id,
|
| 223 |
+
res_id=res_idx,
|
| 224 |
+
ins_code=ins_code,
|
| 225 |
+
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
|
| 226 |
+
hetero=False,
|
| 227 |
+
atom_name=residue_constants.atom_types[i],
|
| 228 |
+
element=residue_constants.atom_types[i][0],
|
| 229 |
+
b_factor=float(b_factor),
|
| 230 |
+
)
|
| 231 |
+
atoms.append(atom)
|
| 232 |
+
return bs.array(atoms)
|
| 233 |
+
|
| 234 |
+
@cached_property
|
| 235 |
+
def residue_index_no_insertions(self) -> np.ndarray:
|
| 236 |
+
return self.residue_index + np.cumsum(self.insertion_code != "")
|
| 237 |
+
|
| 238 |
+
@cached_property
|
| 239 |
+
def atom_array_no_insertions(self) -> bs.AtomArray:
|
| 240 |
+
atoms = []
|
| 241 |
+
for res_idx, (res_name, positions, mask, conf) in enumerate(
|
| 242 |
+
zip(
|
| 243 |
+
self.sequence,
|
| 244 |
+
self.atom37_positions,
|
| 245 |
+
self.atom37_mask.astype(bool),
|
| 246 |
+
self.confidence,
|
| 247 |
+
)
|
| 248 |
+
):
|
| 249 |
+
for i, pos in zip(np.where(mask)[0], positions[mask]):
|
| 250 |
+
b_factor = (
|
| 251 |
+
self.atom37_confidence[res_idx, i]
|
| 252 |
+
if self.atom37_confidence is not None
|
| 253 |
+
else conf
|
| 254 |
+
)
|
| 255 |
+
atom = bs.Atom(
|
| 256 |
+
coord=pos,
|
| 257 |
+
# hard coded to as we currently only support single chain structures
|
| 258 |
+
chain_id=CHAIN_ID_CONST,
|
| 259 |
+
res_id=res_idx + 1,
|
| 260 |
+
res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
|
| 261 |
+
hetero=False,
|
| 262 |
+
atom_name=residue_constants.atom_types[i],
|
| 263 |
+
element=residue_constants.atom_types[i][0],
|
| 264 |
+
b_factor=float(b_factor),
|
| 265 |
+
)
|
| 266 |
+
atoms.append(atom)
|
| 267 |
+
return bs.array(atoms)
|
| 268 |
+
|
| 269 |
+
def __getitem__(self, idx: int | list[int] | slice | np.ndarray | torch.Tensor):
|
| 270 |
+
if isinstance(idx, int):
|
| 271 |
+
idx = [idx]
|
| 272 |
+
if isinstance(idx, torch.Tensor):
|
| 273 |
+
idx = idx.cpu().numpy()
|
| 274 |
+
|
| 275 |
+
sequence = slice_python_object_as_numpy(self.sequence, idx)
|
| 276 |
+
return replace(
|
| 277 |
+
self,
|
| 278 |
+
sequence=sequence,
|
| 279 |
+
residue_index=self.residue_index[..., idx],
|
| 280 |
+
insertion_code=self.insertion_code[..., idx],
|
| 281 |
+
atom37_positions=self.atom37_positions[..., idx, :, :],
|
| 282 |
+
atom37_mask=self.atom37_mask[..., idx, :],
|
| 283 |
+
confidence=self.confidence[..., idx],
|
| 284 |
+
atom37_confidence=self.atom37_confidence[..., idx, :]
|
| 285 |
+
if self.atom37_confidence is not None
|
| 286 |
+
else None,
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
def __len__(self):
|
| 290 |
+
return len(self.sequence)
|
| 291 |
+
|
| 292 |
+
def cbeta_contacts(self, distance_threshold: float = 8.0) -> np.ndarray:
|
| 293 |
+
distance = self.pdist_CB
|
| 294 |
+
contacts = (distance < distance_threshold).astype(np.int64)
|
| 295 |
+
contacts[np.isnan(distance)] = -1
|
| 296 |
+
np.fill_diagonal(contacts, -1)
|
| 297 |
+
return contacts
|
| 298 |
+
|
| 299 |
+
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
|
| 300 |
+
"""Dssp works better w/o insertions."""
|
| 301 |
+
f = PDBFile()
|
| 302 |
+
if not include_insertions:
|
| 303 |
+
f.set_structure(self.atom_array_no_insertions)
|
| 304 |
+
else:
|
| 305 |
+
f.set_structure(self.atom_array)
|
| 306 |
+
f.write(path)
|
| 307 |
+
|
| 308 |
+
def to_pdb_string(self, include_insertions: bool = True) -> str:
|
| 309 |
+
buf = io.StringIO()
|
| 310 |
+
self.to_pdb(buf, include_insertions=include_insertions)
|
| 311 |
+
buf.seek(0)
|
| 312 |
+
return buf.read()
|
| 313 |
+
|
| 314 |
+
def to_mmcif(self, path: PathOrBuffer):
|
| 315 |
+
f = CIFFile()
|
| 316 |
+
set_structure_pdbx(f, self.atom_array, data_block=self.id)
|
| 317 |
+
|
| 318 |
+
# incantations molstar needs to render pLDDT / confidence onto
|
| 319 |
+
# the structure with "alphafold-view"
|
| 320 |
+
f.block["ma_qa_metric"] = CIFCategory(
|
| 321 |
+
name="ma_qa_metric",
|
| 322 |
+
columns={
|
| 323 |
+
"id": CIFColumn(data=CIFData(array=np.array([1, 2]), dtype=np.int64)),
|
| 324 |
+
"mode": CIFColumn(
|
| 325 |
+
data=CIFData(array=np.array(["global", "local"]), dtype=np.str_)
|
| 326 |
+
),
|
| 327 |
+
"name": CIFColumn(
|
| 328 |
+
data=CIFData(array=np.array(["pLDDT", "pLDDT"]), dtype=np.str_)
|
| 329 |
+
),
|
| 330 |
+
},
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# table is a duplicate of data already in the atom array, but
|
| 334 |
+
# needed by molstar to render pLDDT / confidence
|
| 335 |
+
resid_pldd_table = {
|
| 336 |
+
# hard coded to as we currently only support single chain structures
|
| 337 |
+
"label_asym_id": CIFColumn(
|
| 338 |
+
data=CIFData(
|
| 339 |
+
array=[CHAIN_ID_CONST] * len(self.residue_index), dtype=np.str_
|
| 340 |
+
)
|
| 341 |
+
),
|
| 342 |
+
"label_comp_id": CIFColumn(
|
| 343 |
+
data=CIFData(
|
| 344 |
+
array=[
|
| 345 |
+
residue_constants.restype_1to3.get(c, "UNK")
|
| 346 |
+
for c in self.sequence
|
| 347 |
+
],
|
| 348 |
+
dtype=np.str_,
|
| 349 |
+
)
|
| 350 |
+
),
|
| 351 |
+
"label_seq_id": CIFColumn(
|
| 352 |
+
data=CIFData(array=self.residue_index, dtype=np.int64)
|
| 353 |
+
),
|
| 354 |
+
"ordinal_id": CIFColumn(
|
| 355 |
+
data=CIFData(array=self.residue_index, dtype=np.int64)
|
| 356 |
+
),
|
| 357 |
+
# hard coded to show these are all local plDDT values
|
| 358 |
+
"metric_id": CIFColumn(
|
| 359 |
+
data=CIFData(array=["2"] * len(self.residue_index), dtype=np.str_)
|
| 360 |
+
),
|
| 361 |
+
"metric_value": CIFColumn(
|
| 362 |
+
data=CIFData(array=self.confidence, dtype=np.float32)
|
| 363 |
+
),
|
| 364 |
+
# hard coded to show there are the initial version, there are no revisions
|
| 365 |
+
"model_id": CIFColumn(
|
| 366 |
+
data=CIFData(array=["1"] * len(self.residue_index), dtype=np.str_)
|
| 367 |
+
),
|
| 368 |
+
}
|
| 369 |
+
f.block["ma_qa_metric_local"] = CIFCategory(
|
| 370 |
+
name="ma_qa_metric_local", columns=resid_pldd_table
|
| 371 |
+
)
|
| 372 |
+
f.write(path)
|
| 373 |
+
|
| 374 |
+
def to_mmcif_string(self) -> str:
|
| 375 |
+
buf = io.StringIO()
|
| 376 |
+
self.to_mmcif(buf)
|
| 377 |
+
buf.seek(0)
|
| 378 |
+
return buf.read()
|
| 379 |
+
|
| 380 |
+
def state_dict(self, backbone_only=False, json_serializable=False):
|
| 381 |
+
"""This state dict is optimized for storage, so it turns things to fp16 whenever
|
| 382 |
+
possible. Note that we also only support int32 residue indices, I'm hoping we don't
|
| 383 |
+
need more than 2**32 residues..."""
|
| 384 |
+
dct = {k: v for k, v in asdict(self).items() if k not in ["mmcif"]}
|
| 385 |
+
if backbone_only:
|
| 386 |
+
dct["atom37_mask"][:, 3:] = False
|
| 387 |
+
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
|
| 388 |
+
if dct.get("atom37_confidence") is not None:
|
| 389 |
+
dct["atom37_confidence"] = dct["atom37_confidence"][dct["atom37_mask"]]
|
| 390 |
+
else:
|
| 391 |
+
dct.pop("atom37_confidence", None)
|
| 392 |
+
|
| 393 |
+
for k, v in dct.items():
|
| 394 |
+
if isinstance(v, np.ndarray):
|
| 395 |
+
match v.dtype:
|
| 396 |
+
case np.int64:
|
| 397 |
+
dct[k] = v.astype(np.int32)
|
| 398 |
+
case np.float64 | np.float32:
|
| 399 |
+
dct[k] = v.astype(np.float16)
|
| 400 |
+
case _:
|
| 401 |
+
pass
|
| 402 |
+
if json_serializable:
|
| 403 |
+
dct[k] = v.tolist()
|
| 404 |
+
return dct
|
| 405 |
+
|
| 406 |
+
def to_blob(self, backbone_only=False) -> bytes:
|
| 407 |
+
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
|
| 408 |
+
|
| 409 |
+
@classmethod
|
| 410 |
+
def from_open_source(cls, pc: ProteinChain):
|
| 411 |
+
return cls(**vars(pc))
|
| 412 |
+
|
| 413 |
+
@classmethod
|
| 414 |
+
def from_state_dict(cls, dct):
|
| 415 |
+
# Note: assembly_composition is *supposed* to have string keys.
|
| 416 |
+
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
|
| 417 |
+
|
| 418 |
+
for k, v in dct.items():
|
| 419 |
+
if isinstance(v, list):
|
| 420 |
+
dct[k] = np.array(v)
|
| 421 |
+
|
| 422 |
+
atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan)
|
| 423 |
+
atom37[dct["atom37_mask"]] = dct["atom37_positions"]
|
| 424 |
+
dct["atom37_positions"] = atom37
|
| 425 |
+
if "atom37_confidence" in dct:
|
| 426 |
+
atom37_conf = np.full(dct["atom37_mask"].shape, np.nan, dtype=np.float32)
|
| 427 |
+
atom37_conf[dct["atom37_mask"]] = dct["atom37_confidence"]
|
| 428 |
+
dct["atom37_confidence"] = atom37_conf
|
| 429 |
+
dct = {
|
| 430 |
+
k: (
|
| 431 |
+
v.astype(np.float32)
|
| 432 |
+
if k in ["atom37_positions", "confidence", "atom37_confidence"]
|
| 433 |
+
else v
|
| 434 |
+
)
|
| 435 |
+
for k, v in dct.items()
|
| 436 |
+
if not (k == "atom37_confidence" and v is None)
|
| 437 |
+
}
|
| 438 |
+
return cls(**dct, mmcif=None)
|
| 439 |
+
|
| 440 |
+
@classmethod
|
| 441 |
+
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
|
| 442 |
+
"""NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
|
| 443 |
+
of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
|
| 444 |
+
shot at compressing and dumping chains to disk. I'm sure there's better ways."""
|
| 445 |
+
match input:
|
| 446 |
+
case Path() | str():
|
| 447 |
+
bytes = Path(input).read_bytes()
|
| 448 |
+
case io.BytesIO():
|
| 449 |
+
bytes = input.getvalue()
|
| 450 |
+
case _:
|
| 451 |
+
bytes = input
|
| 452 |
+
return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
|
| 453 |
+
|
| 454 |
+
def sasa(self, by_residue: bool = True):
|
| 455 |
+
arr = self.atom_array_no_insertions
|
| 456 |
+
sasa_per_atom = bs.sasa(arr) # type: ignore
|
| 457 |
+
if by_residue:
|
| 458 |
+
# Sum per-atom SASA into residue "bins", with np.bincount.
|
| 459 |
+
assert arr.res_id is not None
|
| 460 |
+
# NOTE(rverkuil): arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
|
| 461 |
+
# NOTE(aderry): We compute only for residues with coordinates, return NaN otherwise.
|
| 462 |
+
num_trailing_residues = len(self) - arr.res_id.max()
|
| 463 |
+
sasa_per_residue = np.concatenate(
|
| 464 |
+
[
|
| 465 |
+
np.bincount(arr.res_id, weights=sasa_per_atom)[1:],
|
| 466 |
+
np.zeros(num_trailing_residues),
|
| 467 |
+
]
|
| 468 |
+
)
|
| 469 |
+
sasa_per_residue[~self.atom37_mask.any(-1)] = np.nan
|
| 470 |
+
assert len(sasa_per_residue) == len(self)
|
| 471 |
+
return sasa_per_residue
|
| 472 |
+
return sasa_per_atom
|
| 473 |
+
|
| 474 |
+
def sap_score(self, aggregation: str = "atom") -> np.ndarray:
|
| 475 |
+
"""Computes per-atom SAP score.
|
| 476 |
+
Can optionally aggregate by residue (by averaging over atoms. NOTE: this returns values only for residues that have coordinates!)
|
| 477 |
+
or full-protein (sum of SAP score for atoms with SAP > 0, as in Lauer et al. 2011)."""
|
| 478 |
+
sap_radius = 5.0
|
| 479 |
+
arr = self.atom_array_no_insertions
|
| 480 |
+
|
| 481 |
+
# asserts to avoid type errors
|
| 482 |
+
assert arr.res_id is not None
|
| 483 |
+
assert arr.res_name is not None
|
| 484 |
+
assert arr.atom_name is not None
|
| 485 |
+
assert arr.coord is not None
|
| 486 |
+
|
| 487 |
+
# compute SASA and residue-specific properties
|
| 488 |
+
sasa_per_atom = self.sasa(by_residue=False)
|
| 489 |
+
resid_to_resname = dict(zip(arr.res_id, arr.res_name))
|
| 490 |
+
|
| 491 |
+
max_side_chain_asa = np.full(len(self), np.nan)
|
| 492 |
+
res_hydrophobicity = np.full(len(self), np.nan)
|
| 493 |
+
resolved_res_mask = self.atom37_mask.any(-1)
|
| 494 |
+
num_trailing_residues = len(self) - arr.res_id.max()
|
| 495 |
+
|
| 496 |
+
max_side_chain_asa[resolved_res_mask] = np.array(
|
| 497 |
+
[
|
| 498 |
+
residue_constants.side_chain_asa[resid_to_resname[i]]
|
| 499 |
+
for i in np.unique(arr.res_id)
|
| 500 |
+
]
|
| 501 |
+
)
|
| 502 |
+
res_hydrophobicity[resolved_res_mask] = np.array(
|
| 503 |
+
[
|
| 504 |
+
residue_constants.hydrophobicity[resid_to_resname[i]]
|
| 505 |
+
for i in np.unique(arr.res_id)
|
| 506 |
+
]
|
| 507 |
+
)
|
| 508 |
+
assert len(max_side_chain_asa) == len(self)
|
| 509 |
+
assert len(res_hydrophobicity) == len(self)
|
| 510 |
+
|
| 511 |
+
# compute SAP score
|
| 512 |
+
is_side_chain = ~bs.filter_peptide_backbone(arr)
|
| 513 |
+
sasa_per_atom[is_side_chain] = 0
|
| 514 |
+
kdtree = KDTree(arr.coord)
|
| 515 |
+
neighbors = kdtree.query_ball_tree(kdtree, sap_radius, p=2.0)
|
| 516 |
+
sap_by_atom = np.zeros_like(sasa_per_atom)
|
| 517 |
+
for i, nn_list in enumerate(neighbors):
|
| 518 |
+
saa_nn = np.zeros_like(sasa_per_atom)
|
| 519 |
+
saa_nn[nn_list] = sasa_per_atom[nn_list]
|
| 520 |
+
sasa_within_r = np.concatenate(
|
| 521 |
+
[
|
| 522 |
+
np.bincount(arr.res_id, weights=saa_nn)[1:],
|
| 523 |
+
np.zeros(num_trailing_residues),
|
| 524 |
+
]
|
| 525 |
+
)
|
| 526 |
+
sap = np.nansum((sasa_within_r / max_side_chain_asa) * res_hydrophobicity)
|
| 527 |
+
sap_by_atom[i] = sap
|
| 528 |
+
|
| 529 |
+
match aggregation:
|
| 530 |
+
case "atom":
|
| 531 |
+
return sap_by_atom
|
| 532 |
+
case "residue":
|
| 533 |
+
sap_by_residue = np.concatenate(
|
| 534 |
+
[
|
| 535 |
+
np.bincount(arr.res_id, weights=sap_by_atom)[1:],
|
| 536 |
+
np.zeros(num_trailing_residues),
|
| 537 |
+
]
|
| 538 |
+
) / (
|
| 539 |
+
np.concatenate(
|
| 540 |
+
[np.bincount(arr.res_id)[1:], np.zeros(num_trailing_residues)]
|
| 541 |
+
)
|
| 542 |
+
+ 1e-8
|
| 543 |
+
)
|
| 544 |
+
sap_by_residue[~resolved_res_mask] = np.nan
|
| 545 |
+
assert len(sap_by_residue) == len(self)
|
| 546 |
+
return sap_by_residue
|
| 547 |
+
case "protein":
|
| 548 |
+
return sum(sap_by_atom[sap_by_atom > 0]) # pyright: ignore[reportReturnType]
|
| 549 |
+
case _:
|
| 550 |
+
raise ValueError(
|
| 551 |
+
f"Invalid aggregation method: {aggregation}. Must be one of 'atom', 'residue', or 'protein'"
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
def globularity(self) -> float:
|
| 555 |
+
# Computes globularity using total volumes divided by MVEE.
|
| 556 |
+
# We make the simplifying approximation that atoms never overlap.
|
| 557 |
+
# The globularity is only computed where structure exists.
|
| 558 |
+
# Besides the approximation above, this is inspired by:
|
| 559 |
+
|
| 560 |
+
# https://www.mdpi.com/2073-4352/11/12/1539
|
| 561 |
+
# NOTE(@zeming): due to the approximation we make here, that atoms never overlap, you might get >1 globularity
|
| 562 |
+
mask = self.atom37_mask.any(-1)
|
| 563 |
+
points = self.atom37_positions[self.atom37_mask]
|
| 564 |
+
sequence = [aa for aa, m in zip(self.sequence, mask) if m] # type: ignore
|
| 565 |
+
A, _ = self._mvee(points, tol=1e-3)
|
| 566 |
+
mvee_volume = (4 * np.pi) / (3 * np.sqrt(np.linalg.det(A)))
|
| 567 |
+
volume = sum(residue_constants.amino_acid_volumes[x] for x in sequence)
|
| 568 |
+
ratio = volume / mvee_volume
|
| 569 |
+
|
| 570 |
+
# The paper says you must compare the ellipsoidal profile with T, a measurement of
|
| 571 |
+
# how elongated the ellipsoid is. We want a single number, so we multiply by 1/2T, so
|
| 572 |
+
# that value is normalized between 0-1
|
| 573 |
+
eigenvalues = np.linalg.eigvals(A)
|
| 574 |
+
R = 1 / np.sqrt(eigenvalues)
|
| 575 |
+
# ellipsoid radii length triangle inequality coefficient
|
| 576 |
+
T = max(R[0] / (R[1] + R[2]), R[1] / (R[0] + R[2]), R[2] / (R[0] + R[1]))
|
| 577 |
+
elongation_metric = 1 / max(T, 1)
|
| 578 |
+
return ratio * elongation_metric
|
| 579 |
+
|
| 580 |
+
@staticmethod
|
| 581 |
+
def _mvee(P: np.ndarray, tol, max_iter=10000):
|
| 582 |
+
# Finds minimum volume enclosing ellipsoid of a set of points.
|
| 583 |
+
# Returns A, c where the ellipse is defined as:
|
| 584 |
+
# (x-c).T @ A @ (x-c) = 1
|
| 585 |
+
hull = ConvexHull(P)
|
| 586 |
+
P = P[hull.vertices]
|
| 587 |
+
P = P.T
|
| 588 |
+
|
| 589 |
+
# Data points
|
| 590 |
+
d, N = P.shape
|
| 591 |
+
Q = np.zeros((d + 1, N))
|
| 592 |
+
Q[:d, :] = P[:d, :N]
|
| 593 |
+
Q[d, :] = np.ones((1, N))
|
| 594 |
+
|
| 595 |
+
# Initializations
|
| 596 |
+
count = 1
|
| 597 |
+
err = 1.0
|
| 598 |
+
u = np.full((N, 1), 1 / N) # 1st iteration
|
| 599 |
+
|
| 600 |
+
# Khachiyan Algorithm
|
| 601 |
+
for i in range(max_iter):
|
| 602 |
+
X = Q.dot(np.diag(u.squeeze())) @ Q.T
|
| 603 |
+
M = np.diag(Q.T @ np.linalg.inv(X) @ Q)
|
| 604 |
+
maximum, j = np.max(M), np.argmax(M)
|
| 605 |
+
step_size = (maximum - d - 1) / ((d + 1) * (maximum - 1))
|
| 606 |
+
new_u = (1 - step_size) * u
|
| 607 |
+
new_u[j] += step_size
|
| 608 |
+
count += 1
|
| 609 |
+
err = np.linalg.norm(new_u - u)
|
| 610 |
+
u = new_u
|
| 611 |
+
if err < tol:
|
| 612 |
+
break
|
| 613 |
+
else:
|
| 614 |
+
raise ValueError("MVEE did not converge")
|
| 615 |
+
|
| 616 |
+
d = P.shape[0] # Fixed: use P.shape[0] instead of P.shape
|
| 617 |
+
U = np.diag(u.squeeze())
|
| 618 |
+
|
| 619 |
+
# The A matrix for the ellipse
|
| 620 |
+
A = (1 / d) * np.linalg.inv(P @ U @ P.T - (P @ u) @ (P @ u).T)
|
| 621 |
+
|
| 622 |
+
# Center of the ellipse
|
| 623 |
+
c = P @ u
|
| 624 |
+
|
| 625 |
+
return A, c
|
| 626 |
+
|
| 627 |
+
def radius_of_gyration(self):
|
| 628 |
+
arr = self.atom_array_no_insertions
|
| 629 |
+
return bs.gyration_radius(arr)
|
| 630 |
+
|
| 631 |
+
def align(
|
| 632 |
+
self,
|
| 633 |
+
target: ProteinChain,
|
| 634 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 635 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 636 |
+
only_use_backbone: bool = False,
|
| 637 |
+
):
|
| 638 |
+
"""
|
| 639 |
+
Aligns the current protein to the provided target.
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
target (ProteinChain): The target protein to align to.
|
| 643 |
+
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 644 |
+
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
| 645 |
+
only_use_backbone (bool, optional): If True, only align the backbone atoms.
|
| 646 |
+
"""
|
| 647 |
+
aligner = Aligner(
|
| 648 |
+
self if mobile_inds is None else self[mobile_inds],
|
| 649 |
+
target if target_inds is None else target[target_inds],
|
| 650 |
+
only_use_backbone,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
+
return aligner.apply(self)
|
| 654 |
+
|
| 655 |
+
def rmsd(
|
| 656 |
+
self,
|
| 657 |
+
target: ProteinChain,
|
| 658 |
+
also_check_reflection: bool = False,
|
| 659 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 660 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 661 |
+
only_compute_backbone_rmsd: bool = False,
|
| 662 |
+
):
|
| 663 |
+
"""
|
| 664 |
+
Compute the RMSD between this protein chain and another.
|
| 665 |
+
|
| 666 |
+
Args:
|
| 667 |
+
target (ProteinChain): The target (other) protein chain to compare to.
|
| 668 |
+
also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
|
| 669 |
+
mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 670 |
+
target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
|
| 671 |
+
only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
|
| 672 |
+
"""
|
| 673 |
+
if isinstance(target, bs.AtomArray):
|
| 674 |
+
raise ValueError(
|
| 675 |
+
"Support for bs.AtomArray removed, use "
|
| 676 |
+
"ProteinChain.from_atomarry for ProteinChain."
|
| 677 |
+
)
|
| 678 |
+
aligner = Aligner(
|
| 679 |
+
self if mobile_inds is None else self[mobile_inds],
|
| 680 |
+
target if target_inds is None else target[target_inds],
|
| 681 |
+
only_compute_backbone_rmsd,
|
| 682 |
+
)
|
| 683 |
+
avg_rmsd = aligner.rmsd
|
| 684 |
+
|
| 685 |
+
if not also_check_reflection:
|
| 686 |
+
return avg_rmsd
|
| 687 |
+
|
| 688 |
+
aligner = Aligner(
|
| 689 |
+
self if mobile_inds is None else self[mobile_inds],
|
| 690 |
+
target if target_inds is None else target[target_inds],
|
| 691 |
+
only_compute_backbone_rmsd,
|
| 692 |
+
use_reflection=True,
|
| 693 |
+
)
|
| 694 |
+
avg_rmsd_neg = aligner.rmsd
|
| 695 |
+
|
| 696 |
+
return min(avg_rmsd, avg_rmsd_neg)
|
| 697 |
+
|
| 698 |
+
def lddt_ca(
|
| 699 |
+
self,
|
| 700 |
+
native: ProteinChain,
|
| 701 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 702 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 703 |
+
**kwargs,
|
| 704 |
+
) -> float | np.ndarray:
|
| 705 |
+
"""Compute the LDDT between this protein chain and another. NOTE: LDDT IS NOT SYMMETRIC.
|
| 706 |
+
The call should always be prediction.lddt_ca(native).
|
| 707 |
+
|
| 708 |
+
Arguments:
|
| 709 |
+
native (ProteinChain): The ground truth protein chain
|
| 710 |
+
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 711 |
+
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
| 712 |
+
|
| 713 |
+
Returns:
|
| 714 |
+
float | np.ndarray: The LDDT score between the two protein chains, either
|
| 715 |
+
a single float or per-residue LDDT scores if `per_residue` is True.
|
| 716 |
+
"""
|
| 717 |
+
lddt = compute_lddt_ca(
|
| 718 |
+
torch.tensor(self.atom37_positions[mobile_inds]).unsqueeze(0),
|
| 719 |
+
torch.tensor(native.atom37_positions[target_inds]).unsqueeze(0),
|
| 720 |
+
torch.tensor(native.atom37_mask[mobile_inds]).unsqueeze(0),
|
| 721 |
+
**kwargs,
|
| 722 |
+
)
|
| 723 |
+
return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
|
| 724 |
+
|
| 725 |
+
def gdt_ts(
|
| 726 |
+
self,
|
| 727 |
+
target: ProteinChain,
|
| 728 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 729 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 730 |
+
**kwargs,
|
| 731 |
+
) -> float | np.ndarray:
|
| 732 |
+
"""Compute the GDT_TS between this protein chain and another.
|
| 733 |
+
|
| 734 |
+
Arguments:
|
| 735 |
+
target (ProteinChain): The other protein chain to compare to.
|
| 736 |
+
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 737 |
+
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
| 738 |
+
|
| 739 |
+
Returns:
|
| 740 |
+
float: The GDT_TS score between the two protein chains.
|
| 741 |
+
"""
|
| 742 |
+
gdt_ts = compute_gdt_ts(
|
| 743 |
+
mobile=torch.tensor(
|
| 744 |
+
index_by_atom_name(self.atom37_positions[mobile_inds], "CA"),
|
| 745 |
+
dtype=torch.float32,
|
| 746 |
+
).unsqueeze(0),
|
| 747 |
+
target=torch.tensor(
|
| 748 |
+
index_by_atom_name(target.atom37_positions[target_inds], "CA"),
|
| 749 |
+
dtype=torch.float32,
|
| 750 |
+
).unsqueeze(0),
|
| 751 |
+
atom_exists_mask=torch.tensor(
|
| 752 |
+
index_by_atom_name(self.atom37_mask[mobile_inds], "CA", dim=-1)
|
| 753 |
+
& index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
|
| 754 |
+
).unsqueeze(0),
|
| 755 |
+
**kwargs,
|
| 756 |
+
)
|
| 757 |
+
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
| 758 |
+
|
| 759 |
+
@classmethod
|
| 760 |
+
def chain_iterable_from_mmcif(
|
| 761 |
+
cls,
|
| 762 |
+
path: PathOrBuffer | MmcifWrapper,
|
| 763 |
+
id: str | None = None,
|
| 764 |
+
is_predicted: bool = False,
|
| 765 |
+
keep_source: bool = False,
|
| 766 |
+
):
|
| 767 |
+
"""Return a list[ProteinChain] object from an mmcif file, a iterable list of all protein chain
|
| 768 |
+
from an mmcif file
|
| 769 |
+
"""
|
| 770 |
+
if isinstance(path, MmcifWrapper):
|
| 771 |
+
mmcif = path
|
| 772 |
+
else:
|
| 773 |
+
mmcif = MmcifWrapper.read(path, id)
|
| 774 |
+
for chain in bs.chain_iter(mmcif.structure):
|
| 775 |
+
chain = chain[bs.filter_amino_acids(chain) & ~chain.hetero]
|
| 776 |
+
if len(chain) == 0:
|
| 777 |
+
continue
|
| 778 |
+
chain_id = chain.chain_id[0]
|
| 779 |
+
entity_id = None
|
| 780 |
+
for entity, chains in mmcif.entities.items():
|
| 781 |
+
if chain_id in chains:
|
| 782 |
+
entity_id = entity
|
| 783 |
+
assert entity_id is not None
|
| 784 |
+
(
|
| 785 |
+
sequence,
|
| 786 |
+
atom_positions,
|
| 787 |
+
atom_mask,
|
| 788 |
+
residue_index,
|
| 789 |
+
insertion_code,
|
| 790 |
+
confidence,
|
| 791 |
+
_,
|
| 792 |
+
) = chain_to_ndarray(chain, mmcif, chain_id, is_predicted)
|
| 793 |
+
assert all(sequence), "Some residue name was not specified correctly"
|
| 794 |
+
|
| 795 |
+
yield cls(
|
| 796 |
+
id=mmcif.id,
|
| 797 |
+
sequence=sequence,
|
| 798 |
+
chain_id=chain_id,
|
| 799 |
+
entity_id=entity_id,
|
| 800 |
+
atom37_positions=atom_positions,
|
| 801 |
+
atom37_mask=atom_mask,
|
| 802 |
+
residue_index=residue_index,
|
| 803 |
+
insertion_code=insertion_code,
|
| 804 |
+
confidence=confidence,
|
| 805 |
+
mmcif=mmcif if keep_source else None,
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
@classmethod
|
| 809 |
+
def from_mmcif(
|
| 810 |
+
cls,
|
| 811 |
+
path: PathOrBuffer | MmcifWrapper,
|
| 812 |
+
chain_id: str | None = None,
|
| 813 |
+
entity_id: int | None = None,
|
| 814 |
+
id: str | None = None,
|
| 815 |
+
is_predicted: bool = False,
|
| 816 |
+
keep_source: bool = False,
|
| 817 |
+
):
|
| 818 |
+
"""Return a ProteinChain object from an mmcif file.
|
| 819 |
+
|
| 820 |
+
Args:
|
| 821 |
+
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
| 822 |
+
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
| 823 |
+
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
| 824 |
+
chain_id (str, optional): Select a chain corresponding to (author) chain id.
|
| 825 |
+
entity_id (int, optional): Select a chain corresponding to a particular entity.
|
| 826 |
+
|
| 827 |
+
If neither `chain_id` nor `entity_id` is specified, defaults to the first entity.
|
| 828 |
+
"""
|
| 829 |
+
if isinstance(path, MmcifWrapper):
|
| 830 |
+
mmcif = path
|
| 831 |
+
else:
|
| 832 |
+
mmcif = MmcifWrapper.read(path, id)
|
| 833 |
+
|
| 834 |
+
# If neither chain_id nor entity_id is specified, default to the first entity
|
| 835 |
+
if chain_id is None and entity_id is None:
|
| 836 |
+
if not mmcif.entities:
|
| 837 |
+
raise ValueError("Structure contains no entities")
|
| 838 |
+
entity_id = min(mmcif.entities.keys()) # Pick the first entity by ID
|
| 839 |
+
|
| 840 |
+
if entity_id is not None:
|
| 841 |
+
assert chain_id is None
|
| 842 |
+
if entity_id not in mmcif.entities:
|
| 843 |
+
raise ValueError(
|
| 844 |
+
f"Structure does not contain entity `{entity_id}`. Valid entities: {mmcif.entities.keys()}"
|
| 845 |
+
)
|
| 846 |
+
chains = mmcif.entities[entity_id]
|
| 847 |
+
|
| 848 |
+
# Select the chain id corresponding to the longest chain. If all are equal length, selects the first.
|
| 849 |
+
chain_id = max(
|
| 850 |
+
chains,
|
| 851 |
+
key=lambda chain: _num_non_null_residues(
|
| 852 |
+
mmcif.seqres_to_structure[chain]
|
| 853 |
+
),
|
| 854 |
+
)
|
| 855 |
+
else:
|
| 856 |
+
assert chain_id is not None
|
| 857 |
+
for entity, chains in mmcif.entities.items():
|
| 858 |
+
if chain_id in chains:
|
| 859 |
+
entity_id = entity
|
| 860 |
+
if entity_id is None:
|
| 861 |
+
warnings.warn(
|
| 862 |
+
"Failed to detect entity_id from mmcif file, it may be malformed."
|
| 863 |
+
)
|
| 864 |
+
|
| 865 |
+
atom_array = mmcif.structure
|
| 866 |
+
(
|
| 867 |
+
sequence,
|
| 868 |
+
atom_positions,
|
| 869 |
+
atom_mask,
|
| 870 |
+
residue_index,
|
| 871 |
+
insertion_code,
|
| 872 |
+
confidence,
|
| 873 |
+
_,
|
| 874 |
+
) = chain_to_ndarray(atom_array, mmcif, chain_id, is_predicted)
|
| 875 |
+
assert all(sequence), "Some residue name was not specified correctly"
|
| 876 |
+
|
| 877 |
+
return cls(
|
| 878 |
+
id=mmcif.id,
|
| 879 |
+
sequence=sequence,
|
| 880 |
+
chain_id=chain_id,
|
| 881 |
+
entity_id=entity_id,
|
| 882 |
+
atom37_positions=atom_positions,
|
| 883 |
+
atom37_mask=atom_mask.astype(bool),
|
| 884 |
+
residue_index=residue_index,
|
| 885 |
+
insertion_code=insertion_code,
|
| 886 |
+
confidence=confidence,
|
| 887 |
+
mmcif=mmcif if keep_source else None,
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
@classmethod
|
| 891 |
+
def from_atom37(
|
| 892 |
+
cls,
|
| 893 |
+
atom37_positions: np.ndarray | torch.Tensor,
|
| 894 |
+
*,
|
| 895 |
+
id: str | None = None,
|
| 896 |
+
sequence: str | None = None,
|
| 897 |
+
chain_id: str | None = None,
|
| 898 |
+
entity_id: int | None = None,
|
| 899 |
+
residue_index: np.ndarray | torch.Tensor | None = None,
|
| 900 |
+
insertion_code: np.ndarray | None = None,
|
| 901 |
+
confidence: np.ndarray | torch.Tensor | None = None,
|
| 902 |
+
):
|
| 903 |
+
if isinstance(atom37_positions, torch.Tensor):
|
| 904 |
+
atom37_positions = atom37_positions.cpu().numpy()
|
| 905 |
+
if atom37_positions.ndim == 4:
|
| 906 |
+
if atom37_positions.shape[0] != 1:
|
| 907 |
+
raise ValueError(
|
| 908 |
+
f"Cannot handle batched inputs, atom37_positions has shape {atom37_positions.shape}"
|
| 909 |
+
)
|
| 910 |
+
atom37_positions = atom37_positions[0]
|
| 911 |
+
|
| 912 |
+
assert isinstance(atom37_positions, np.ndarray)
|
| 913 |
+
seqlen = atom37_positions.shape[0]
|
| 914 |
+
|
| 915 |
+
atom_mask = np.isfinite(atom37_positions).all(-1)
|
| 916 |
+
|
| 917 |
+
if id is None:
|
| 918 |
+
id = ""
|
| 919 |
+
|
| 920 |
+
if sequence is None:
|
| 921 |
+
sequence = "A" * seqlen
|
| 922 |
+
|
| 923 |
+
if chain_id is None:
|
| 924 |
+
chain_id = "A"
|
| 925 |
+
|
| 926 |
+
if residue_index is None:
|
| 927 |
+
residue_index = np.arange(1, seqlen + 1)
|
| 928 |
+
elif isinstance(residue_index, torch.Tensor):
|
| 929 |
+
residue_index = residue_index.cpu().numpy()
|
| 930 |
+
assert isinstance(residue_index, np.ndarray)
|
| 931 |
+
if residue_index.ndim == 2:
|
| 932 |
+
if residue_index.shape[0] != 1:
|
| 933 |
+
raise ValueError(
|
| 934 |
+
f"Cannot handle batched inputs, residue_index has shape {residue_index.shape}"
|
| 935 |
+
)
|
| 936 |
+
residue_index = residue_index[0]
|
| 937 |
+
assert isinstance(residue_index, np.ndarray)
|
| 938 |
+
|
| 939 |
+
if insertion_code is None:
|
| 940 |
+
insertion_code = np.array(["" for _ in range(seqlen)])
|
| 941 |
+
|
| 942 |
+
if confidence is None:
|
| 943 |
+
confidence = np.ones(seqlen, dtype=np.float32)
|
| 944 |
+
elif isinstance(confidence, torch.Tensor):
|
| 945 |
+
confidence = confidence.cpu().numpy()
|
| 946 |
+
assert isinstance(confidence, np.ndarray)
|
| 947 |
+
if confidence.ndim == 2:
|
| 948 |
+
if confidence.shape[0] != 1:
|
| 949 |
+
raise ValueError(
|
| 950 |
+
f"Cannot handle batched inputs, confidence has shape {confidence.shape}"
|
| 951 |
+
)
|
| 952 |
+
confidence = confidence[0]
|
| 953 |
+
assert isinstance(confidence, np.ndarray)
|
| 954 |
+
|
| 955 |
+
return cls(
|
| 956 |
+
id=id,
|
| 957 |
+
sequence=sequence, # type: ignore
|
| 958 |
+
chain_id=chain_id,
|
| 959 |
+
entity_id=entity_id,
|
| 960 |
+
atom37_positions=atom37_positions,
|
| 961 |
+
atom37_mask=atom_mask.astype(bool),
|
| 962 |
+
residue_index=residue_index,
|
| 963 |
+
insertion_code=insertion_code,
|
| 964 |
+
confidence=confidence,
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
@classmethod
|
| 968 |
+
def from_backbone_atom_coordinates(
|
| 969 |
+
cls, backbone_atom_coordinates: np.ndarray | torch.Tensor, **kwargs
|
| 970 |
+
):
|
| 971 |
+
"""Create a ProteinChain from a set of backbone atom coordinates.
|
| 972 |
+
|
| 973 |
+
This function simply expands the seqlen x 3 x 3 array of backbone atom
|
| 974 |
+
coordinates to a seqlen x 37 x 3 array of all atom coordinates, with the padded
|
| 975 |
+
positions set to infinity. This allows us to use from_atom37 to create the
|
| 976 |
+
appropriate ProteinChain object with the appropriate atom37_mask.
|
| 977 |
+
|
| 978 |
+
This function passes all kwargs to from_atom37.
|
| 979 |
+
"""
|
| 980 |
+
if isinstance(backbone_atom_coordinates, torch.Tensor):
|
| 981 |
+
backbone_atom_coordinates = backbone_atom_coordinates.cpu().numpy()
|
| 982 |
+
if backbone_atom_coordinates.ndim == 4:
|
| 983 |
+
if backbone_atom_coordinates.shape[0] != 1:
|
| 984 |
+
raise ValueError(
|
| 985 |
+
f"Cannot handle batched inputs, backbone_atom_coordinates has "
|
| 986 |
+
f"shape {backbone_atom_coordinates.shape}"
|
| 987 |
+
)
|
| 988 |
+
backbone_atom_coordinates = backbone_atom_coordinates[0]
|
| 989 |
+
|
| 990 |
+
assert isinstance(backbone_atom_coordinates, np.ndarray)
|
| 991 |
+
assert backbone_atom_coordinates.ndim == 3
|
| 992 |
+
assert backbone_atom_coordinates.shape[-2] == 3
|
| 993 |
+
assert backbone_atom_coordinates.shape[-1] == 3
|
| 994 |
+
|
| 995 |
+
atom37_positions = np.full(
|
| 996 |
+
(backbone_atom_coordinates.shape[0], 37, 3),
|
| 997 |
+
np.inf,
|
| 998 |
+
dtype=backbone_atom_coordinates.dtype,
|
| 999 |
+
)
|
| 1000 |
+
atom37_positions[:, :3, :] = backbone_atom_coordinates
|
| 1001 |
+
|
| 1002 |
+
return cls.from_atom37(atom37_positions=atom37_positions, **kwargs)
|
| 1003 |
+
|
| 1004 |
+
@classmethod
|
| 1005 |
+
def from_pdb(
|
| 1006 |
+
cls,
|
| 1007 |
+
path: PathOrBuffer,
|
| 1008 |
+
chain_id: str = "detect",
|
| 1009 |
+
id: str | None = None,
|
| 1010 |
+
is_predicted: bool = False,
|
| 1011 |
+
) -> "ProteinChain":
|
| 1012 |
+
"""Return a ProteinChain object from an pdb file. NOTE: prefer mmcif for rcsb PDB files.
|
| 1013 |
+
This function is mostly to interface with old PDB files and predicted structures -
|
| 1014 |
+
it will not fill out the entity id correctly
|
| 1015 |
+
|
| 1016 |
+
Args:
|
| 1017 |
+
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
| 1018 |
+
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
| 1019 |
+
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
| 1020 |
+
chain_id (str, optional): Select a chain corresponding to (author) chain id. "detect" uses the
|
| 1021 |
+
first detected chain
|
| 1022 |
+
"""
|
| 1023 |
+
|
| 1024 |
+
if id is not None:
|
| 1025 |
+
file_id = id
|
| 1026 |
+
else:
|
| 1027 |
+
match path:
|
| 1028 |
+
case Path() | str():
|
| 1029 |
+
file_id = Path(path).with_suffix("").name
|
| 1030 |
+
case _:
|
| 1031 |
+
file_id = "null"
|
| 1032 |
+
|
| 1033 |
+
atom_array = PDBFile.read(path).get_structure(
|
| 1034 |
+
model=1, extra_fields=["b_factor"]
|
| 1035 |
+
)
|
| 1036 |
+
if chain_id == "detect":
|
| 1037 |
+
chain_id = atom_array.chain_id[0]
|
| 1038 |
+
atom_array = atom_array[
|
| 1039 |
+
bs.filter_amino_acids(atom_array)
|
| 1040 |
+
& ~atom_array.hetero
|
| 1041 |
+
& (atom_array.chain_id == chain_id)
|
| 1042 |
+
]
|
| 1043 |
+
|
| 1044 |
+
entity_id = 1 # Not supplied in PDBfiles
|
| 1045 |
+
|
| 1046 |
+
sequence = "".join(
|
| 1047 |
+
residue_constants.restype_3to1.get(monomer[0].res_name, "X")
|
| 1048 |
+
for monomer in bs.residue_iter(atom_array)
|
| 1049 |
+
)
|
| 1050 |
+
num_res = len(sequence)
|
| 1051 |
+
|
| 1052 |
+
atom_positions = np.full(
|
| 1053 |
+
[num_res, residue_constants.atom_type_num, 3], np.nan, dtype=np.float32
|
| 1054 |
+
)
|
| 1055 |
+
atom_mask = np.full(
|
| 1056 |
+
[num_res, residue_constants.atom_type_num], False, dtype=bool
|
| 1057 |
+
)
|
| 1058 |
+
residue_index = np.full([num_res], -1, dtype=np.int64)
|
| 1059 |
+
insertion_code = np.full([num_res], "", dtype="<U4")
|
| 1060 |
+
|
| 1061 |
+
confidence = np.ones([num_res], dtype=np.float32)
|
| 1062 |
+
|
| 1063 |
+
for i, res in enumerate(bs.residue_iter(atom_array)):
|
| 1064 |
+
chain = atom_array[atom_array.chain_id == chain_id]
|
| 1065 |
+
assert isinstance(chain, bs.AtomArray)
|
| 1066 |
+
|
| 1067 |
+
res_index = res[0].res_id
|
| 1068 |
+
residue_index[i] = res_index
|
| 1069 |
+
insertion_code[i] = res[0].ins_code
|
| 1070 |
+
|
| 1071 |
+
# Atom level features
|
| 1072 |
+
for atom in res:
|
| 1073 |
+
atom_name = atom.atom_name
|
| 1074 |
+
if atom_name == "SE" and atom.res_name == "MSE":
|
| 1075 |
+
# Put the coords of the selenium atom in the sulphur column
|
| 1076 |
+
atom_name = "SD"
|
| 1077 |
+
|
| 1078 |
+
if atom_name in residue_constants.atom_order:
|
| 1079 |
+
atom_positions[i, residue_constants.atom_order[atom_name]] = (
|
| 1080 |
+
atom.coord
|
| 1081 |
+
)
|
| 1082 |
+
atom_mask[i, residue_constants.atom_order[atom_name]] = True
|
| 1083 |
+
if is_predicted and atom_name == "CA":
|
| 1084 |
+
confidence[i] = atom.b_factor
|
| 1085 |
+
|
| 1086 |
+
assert all(sequence), "Some residue name was not specified correctly"
|
| 1087 |
+
|
| 1088 |
+
return cls(
|
| 1089 |
+
id=file_id,
|
| 1090 |
+
sequence=sequence,
|
| 1091 |
+
chain_id=chain_id,
|
| 1092 |
+
entity_id=entity_id,
|
| 1093 |
+
atom37_positions=atom_positions,
|
| 1094 |
+
atom37_mask=atom_mask.astype(bool),
|
| 1095 |
+
residue_index=residue_index,
|
| 1096 |
+
insertion_code=insertion_code,
|
| 1097 |
+
confidence=confidence,
|
| 1098 |
+
mmcif=None,
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
@classmethod
|
| 1102 |
+
def from_mds(cls, data: dict[str, Any]) -> "ProteinChain":
|
| 1103 |
+
return cls(
|
| 1104 |
+
id=data["id"],
|
| 1105 |
+
chain_id=data["chain_id"],
|
| 1106 |
+
entity_id=data["entity_id"],
|
| 1107 |
+
sequence=data["sequence"],
|
| 1108 |
+
residue_index=data["residue_index"],
|
| 1109 |
+
insertion_code=np.asarray(data["insertion_code"]),
|
| 1110 |
+
atom37_positions=data["atom37_positions"],
|
| 1111 |
+
atom37_mask=data["atom37_mask"].astype(bool),
|
| 1112 |
+
confidence=data["confidence"],
|
| 1113 |
+
mmcif=None,
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
+
@classmethod
|
| 1117 |
+
def from_rcsb(
|
| 1118 |
+
cls,
|
| 1119 |
+
pdb_id: str,
|
| 1120 |
+
chain_id: str | None = None,
|
| 1121 |
+
entity_id: int | None = None,
|
| 1122 |
+
keep_source: bool = False,
|
| 1123 |
+
) -> ProteinChain:
|
| 1124 |
+
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
|
| 1125 |
+
return cls.from_mmcif(
|
| 1126 |
+
f,
|
| 1127 |
+
id=pdb_id,
|
| 1128 |
+
chain_id=chain_id,
|
| 1129 |
+
entity_id=entity_id,
|
| 1130 |
+
keep_source=keep_source,
|
| 1131 |
+
is_predicted=False,
|
| 1132 |
+
)
|
| 1133 |
+
|
| 1134 |
+
@classmethod
|
| 1135 |
+
def from_atomarray(
|
| 1136 |
+
cls, atom_array: bs.AtomArray, id: str | None = None, is_predicted: bool = False
|
| 1137 |
+
) -> "ProteinChain":
|
| 1138 |
+
"""A simple converter from bs.AtomArray -> ProteinChain.
|
| 1139 |
+
Uses PDB file format as intermediate."""
|
| 1140 |
+
atom_array = atom_array.copy()
|
| 1141 |
+
atom_array.box = None # remove surrounding box, from_pdb won't handle this
|
| 1142 |
+
pdb_file = PDBFile() # pyright: ignore
|
| 1143 |
+
pdb_file.set_structure(atom_array)
|
| 1144 |
+
|
| 1145 |
+
buf = io.StringIO()
|
| 1146 |
+
pdb_file.write(buf)
|
| 1147 |
+
buf.seek(0)
|
| 1148 |
+
return cls.from_pdb(buf, id=id, is_predicted=is_predicted)
|
| 1149 |
+
|
| 1150 |
+
def get_normalization_frame(self) -> Affine3D:
|
| 1151 |
+
"""Given a set of coordinates, compute a single frame.
|
| 1152 |
+
Specifically, we compute the average position of the N, CA, and C atoms use those 3 points to construct a frame using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.
|
| 1153 |
+
|
| 1154 |
+
Returns:
|
| 1155 |
+
Affine3D: [] tensor of Affine3D frame
|
| 1156 |
+
"""
|
| 1157 |
+
coords = torch.from_numpy(self.atom37_positions)
|
| 1158 |
+
frame = get_protein_normalization_frame(coords)
|
| 1159 |
+
|
| 1160 |
+
return frame
|
| 1161 |
+
|
| 1162 |
+
def apply_frame(self, frame: Affine3D) -> ProteinChain:
|
| 1163 |
+
"""Given a frame, apply the frame to the protein's coordinates.
|
| 1164 |
+
|
| 1165 |
+
Args:
|
| 1166 |
+
frame (Affine3D): [] tensor of Affine3D frame
|
| 1167 |
+
|
| 1168 |
+
Returns:
|
| 1169 |
+
ProteinChain: Transformed protein chain
|
| 1170 |
+
"""
|
| 1171 |
+
coords = torch.from_numpy(self.atom37_positions).to(frame.trans.dtype)
|
| 1172 |
+
coords = apply_frame_to_coords(coords, frame)
|
| 1173 |
+
atom37_positions = coords.numpy()
|
| 1174 |
+
return replace(self, atom37_positions=atom37_positions)
|
| 1175 |
+
|
| 1176 |
+
def normalize_coordinates(self) -> ProteinChain:
|
| 1177 |
+
"""Normalize the coordinates of the protein chain."""
|
| 1178 |
+
return self.apply_frame(self.get_normalization_frame())
|
| 1179 |
+
|
| 1180 |
+
def infer_oxygen(self) -> ProteinChain:
|
| 1181 |
+
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
| 1182 |
+
O_missing_indices = np.argwhere(
|
| 1183 |
+
~np.isfinite(self.atoms["O"]).all(axis=1)
|
| 1184 |
+
).squeeze()
|
| 1185 |
+
|
| 1186 |
+
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
| 1187 |
+
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
| 1188 |
+
N = torch.roll(N, -3)
|
| 1189 |
+
N[..., -1, :] = torch.nan
|
| 1190 |
+
|
| 1191 |
+
# Get the frame defined by the CA-C-N atom
|
| 1192 |
+
frames = Affine3D.from_graham_schmidt(CA, C, N)
|
| 1193 |
+
O = frames.apply(O_vector)
|
| 1194 |
+
atom37_positions = self.atom37_positions.copy()
|
| 1195 |
+
atom37_mask = self.atom37_mask.copy()
|
| 1196 |
+
|
| 1197 |
+
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
|
| 1198 |
+
O_missing_indices
|
| 1199 |
+
].numpy()
|
| 1200 |
+
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
|
| 1201 |
+
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
|
| 1202 |
+
).any(-1)
|
| 1203 |
+
new_chain = replace(
|
| 1204 |
+
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
| 1205 |
+
)
|
| 1206 |
+
return new_chain
|
| 1207 |
+
|
| 1208 |
+
@cached_property
|
| 1209 |
+
def inferred_cbeta(self) -> np.ndarray:
|
| 1210 |
+
"""Infer cbeta positions based on N, C, CA."""
|
| 1211 |
+
N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
|
| 1212 |
+
# See usage in trDesign codebase.
|
| 1213 |
+
# https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
|
| 1214 |
+
CB = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
|
| 1215 |
+
return CB
|
| 1216 |
+
|
| 1217 |
+
def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinChain:
|
| 1218 |
+
"""Return a new chain with inferred CB atoms at all residues except GLY.
|
| 1219 |
+
|
| 1220 |
+
Args:
|
| 1221 |
+
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
|
| 1222 |
+
residues, even though that residue doesn't have one. Default off.
|
| 1223 |
+
|
| 1224 |
+
NOTE(rverkuil): The reason for having this switch in the first place
|
| 1225 |
+
is that sometimes we want a (inferred) CB coordinate for every residue,
|
| 1226 |
+
for example for making a pairwise distance matrix, or doing an RMSD
|
| 1227 |
+
calculation between two designs for a given structural template, w/
|
| 1228 |
+
CB atoms.
|
| 1229 |
+
"""
|
| 1230 |
+
atom37_positions = self.atom37_positions.copy()
|
| 1231 |
+
atom37_mask = self.atom37_mask.copy()
|
| 1232 |
+
|
| 1233 |
+
inferred_cbeta_positions = self.inferred_cbeta
|
| 1234 |
+
if not infer_cbeta_for_glycine:
|
| 1235 |
+
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
|
| 1236 |
+
|
| 1237 |
+
atom37_positions[:, residue_constants.atom_order["CB"]] = (
|
| 1238 |
+
inferred_cbeta_positions
|
| 1239 |
+
)
|
| 1240 |
+
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
|
| 1241 |
+
atom37_positions[:, residue_constants.atom_order["CB"]]
|
| 1242 |
+
).any(-1)
|
| 1243 |
+
new_chain = replace(
|
| 1244 |
+
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
| 1245 |
+
)
|
| 1246 |
+
return new_chain
|
| 1247 |
+
|
| 1248 |
+
@cached_property
|
| 1249 |
+
def pdist_CA(self) -> np.ndarray:
|
| 1250 |
+
CA = self.atoms["CA"]
|
| 1251 |
+
pdist_CA = squareform(pdist(CA))
|
| 1252 |
+
return pdist_CA
|
| 1253 |
+
|
| 1254 |
+
@cached_property
|
| 1255 |
+
def pdist_CB(self) -> np.ndarray:
|
| 1256 |
+
pdist_CB = squareform(pdist(self.inferred_cbeta))
|
| 1257 |
+
return pdist_CB
|
| 1258 |
+
|
| 1259 |
+
@classmethod
|
| 1260 |
+
def as_complex(cls, chains: Sequence[ProteinChain]):
|
| 1261 |
+
raise RuntimeError(
|
| 1262 |
+
".as_complex() has been deprecated in favor of .concat(). "
|
| 1263 |
+
".concat() will eventually be deprecated in favor of ProteinComplex..."
|
| 1264 |
+
)
|
| 1265 |
+
|
| 1266 |
+
@classmethod
|
| 1267 |
+
def concat(cls, chains: Sequence[ProteinChain], use_chainbreak: bool = True):
|
| 1268 |
+
sep_tokens = {
|
| 1269 |
+
"residue_index": np.array([-1]),
|
| 1270 |
+
"insertion_code": np.array([""]),
|
| 1271 |
+
"atom37_positions": np.full([1, 37, 3], np.inf),
|
| 1272 |
+
"atom37_mask": np.zeros([1, 37], dtype=bool),
|
| 1273 |
+
"confidence": np.array([0]),
|
| 1274 |
+
}
|
| 1275 |
+
|
| 1276 |
+
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
|
| 1277 |
+
if use_chainbreak:
|
| 1278 |
+
full_array = []
|
| 1279 |
+
for array in arrays:
|
| 1280 |
+
full_array.append(array)
|
| 1281 |
+
full_array.append(sep)
|
| 1282 |
+
full_array = full_array[:-1]
|
| 1283 |
+
return np.concatenate(full_array, 0)
|
| 1284 |
+
else:
|
| 1285 |
+
return np.concatenate(arrays, 0)
|
| 1286 |
+
|
| 1287 |
+
array_args: dict[str, np.ndarray] = {
|
| 1288 |
+
name: join_arrays([getattr(chain, name) for chain in chains], sep)
|
| 1289 |
+
for name, sep in sep_tokens.items()
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
chain_break = residue_constants.CHAIN_BREAK_TOKEN if use_chainbreak else ""
|
| 1293 |
+
return cls(
|
| 1294 |
+
id=chains[0].id,
|
| 1295 |
+
sequence=chain_break.join(chain.sequence for chain in chains),
|
| 1296 |
+
chain_id="A",
|
| 1297 |
+
entity_id=None,
|
| 1298 |
+
mmcif=None,
|
| 1299 |
+
**array_args,
|
| 1300 |
+
)
|
| 1301 |
+
|
| 1302 |
+
def find_nonpolymer_contacts(self):
|
| 1303 |
+
assert self.mmcif is not None
|
| 1304 |
+
nonpolymer_and_chain_id_to_array = self.mmcif.non_polymer_coords
|
| 1305 |
+
|
| 1306 |
+
results = []
|
| 1307 |
+
for (
|
| 1308 |
+
nonpolymer,
|
| 1309 |
+
_,
|
| 1310 |
+
), nonpolymer_array in nonpolymer_and_chain_id_to_array.items():
|
| 1311 |
+
assert nonpolymer_array.coord is not None
|
| 1312 |
+
chain_coords = self.atom37_positions[self.atom37_mask]
|
| 1313 |
+
distance = cdist(nonpolymer_array.coord, chain_coords)
|
| 1314 |
+
|
| 1315 |
+
is_contact = distance < 5
|
| 1316 |
+
if not is_contact.any():
|
| 1317 |
+
continue
|
| 1318 |
+
contacting_atoms = np.where(is_contact.any(0))[0]
|
| 1319 |
+
chain_index = np.where(self.atom37_mask)[0]
|
| 1320 |
+
contacting_residues = np.unique(chain_index[contacting_atoms])
|
| 1321 |
+
|
| 1322 |
+
result = {
|
| 1323 |
+
"ligand": nonpolymer.name,
|
| 1324 |
+
"ligand_id": nonpolymer.comp_id,
|
| 1325 |
+
"contacting_residues": contacting_residues.tolist(),
|
| 1326 |
+
}
|
| 1327 |
+
results.append(result)
|
| 1328 |
+
return results
|
| 1329 |
+
|
| 1330 |
+
def select_residue_indices(
|
| 1331 |
+
self, indices: list[int | str], ignore_x_mismatch: bool = False
|
| 1332 |
+
) -> ProteinChain:
|
| 1333 |
+
numeric_indices = [
|
| 1334 |
+
idx if isinstance(idx, int) else int(idx[1:]) for idx in indices
|
| 1335 |
+
]
|
| 1336 |
+
mask = np.isin(self.residue_index, numeric_indices)
|
| 1337 |
+
new = self[mask]
|
| 1338 |
+
mismatches = []
|
| 1339 |
+
for aa, idx in zip(new.sequence, indices):
|
| 1340 |
+
if isinstance(idx, int):
|
| 1341 |
+
continue
|
| 1342 |
+
if aa == "X" and ignore_x_mismatch:
|
| 1343 |
+
continue
|
| 1344 |
+
if aa != idx[0]:
|
| 1345 |
+
mismatches.append((aa, idx))
|
| 1346 |
+
if mismatches:
|
| 1347 |
+
mismatch_str = "; ".join(
|
| 1348 |
+
f"Position {idx[1:]}, Expected: {idx[0]}, Received: {aa}"
|
| 1349 |
+
for aa, idx in mismatches
|
| 1350 |
+
)
|
| 1351 |
+
raise RuntimeError(mismatch_str)
|
| 1352 |
+
|
| 1353 |
+
return new
|
| 1354 |
+
|
| 1355 |
+
def to_structure_encoder_inputs(
|
| 1356 |
+
self,
|
| 1357 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 1358 |
+
"""Convert protein chain to structure encoder inputs.
|
| 1359 |
+
|
| 1360 |
+
Returns:
|
| 1361 |
+
tuple: (coordinates, plddt, residue_index) where:
|
| 1362 |
+
- coordinates: (1, L, 37, 3) tensor of atom positions
|
| 1363 |
+
- plddt: (1, L) tensor of confidence scores
|
| 1364 |
+
- residue_index: (1, L) tensor of residue indices
|
| 1365 |
+
"""
|
| 1366 |
+
# Convert to tensors and add batch dimension
|
| 1367 |
+
coordinates = (
|
| 1368 |
+
torch.from_numpy(self.atom37_positions).float().unsqueeze(0)
|
| 1369 |
+
) # (1, L, 37, 3)
|
| 1370 |
+
plddt = torch.from_numpy(self.confidence).float().unsqueeze(0) # (1, L)
|
| 1371 |
+
residue_index = (
|
| 1372 |
+
torch.from_numpy(self.residue_index).long().unsqueeze(0)
|
| 1373 |
+
) # (1, L)
|
| 1374 |
+
|
| 1375 |
+
return coordinates, plddt, residue_index
|
| 1376 |
+
|
esmfold2_protein_complex.py
ADDED
|
@@ -0,0 +1,1241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
import itertools
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import warnings
|
| 8 |
+
from dataclasses import asdict, dataclass, replace
|
| 9 |
+
from functools import cached_property
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from subprocess import check_output
|
| 12 |
+
from tempfile import TemporaryDirectory
|
| 13 |
+
from typing import Any, Iterable, Sequence
|
| 14 |
+
|
| 15 |
+
import biotite.structure as bs
|
| 16 |
+
import brotli
|
| 17 |
+
import msgpack
|
| 18 |
+
import msgpack_numpy
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from biotite.database import rcsb
|
| 22 |
+
from biotite.file import InvalidFileError
|
| 23 |
+
from biotite.structure.io.pdb import PDBFile
|
| 24 |
+
from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
|
| 25 |
+
from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
|
| 26 |
+
from biotite.structure.io.pdbx.convert import _get_transformations, get_structure
|
| 27 |
+
from biotite.structure.util import matrix_rotate
|
| 28 |
+
from scipy.spatial import KDTree
|
| 29 |
+
|
| 30 |
+
from . import esmfold2_residue_constants
|
| 31 |
+
from .esmfold2_misc import slice_python_object_as_numpy
|
| 32 |
+
from .esmfold2_affine3d import Affine3D
|
| 33 |
+
from .esmfold2_aligner import Aligner
|
| 34 |
+
from .esmfold2_atom_indexer import AtomIndexer
|
| 35 |
+
from .esmfold2_metrics import compute_gdt_ts, compute_lddt_ca
|
| 36 |
+
from .esmfold2_mmcif_parsing import MmcifWrapper, NoProteinError
|
| 37 |
+
from .esmfold2_protein_chain import (
|
| 38 |
+
ProteinChain,
|
| 39 |
+
_str_key_to_int_key,
|
| 40 |
+
chain_to_ndarray,
|
| 41 |
+
index_by_atom_name,
|
| 42 |
+
infer_CB,
|
| 43 |
+
)
|
| 44 |
+
from .esmfold2_utils_types import PathOrBuffer
|
| 45 |
+
|
| 46 |
+
msgpack_numpy.patch()
|
| 47 |
+
|
| 48 |
+
SINGLE_LETTER_CHAIN_IDS = (
|
| 49 |
+
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _parse_operation_expression(expression):
|
| 54 |
+
"""
|
| 55 |
+
Get successive operation steps (IDs) for the given
|
| 56 |
+
``oper_expression``.
|
| 57 |
+
Form the cartesian product, if necessary.
|
| 58 |
+
Copied from biotite and fixed a bug
|
| 59 |
+
"""
|
| 60 |
+
# Split groups by parentheses:
|
| 61 |
+
# use the opening parenthesis as delimiter
|
| 62 |
+
# and just remove the closing parenthesis
|
| 63 |
+
expressions_per_step = expression.replace(")", "").split("(")
|
| 64 |
+
expressions_per_step = [e for e in expressions_per_step if len(e) > 0]
|
| 65 |
+
# Important: Operations are applied from right to left
|
| 66 |
+
expressions_per_step.reverse()
|
| 67 |
+
|
| 68 |
+
operations = []
|
| 69 |
+
for expr in expressions_per_step:
|
| 70 |
+
cur_expr = expr.split(",")
|
| 71 |
+
cur_op = []
|
| 72 |
+
# Deal with e='1-10,20-30,40-50' type expressions
|
| 73 |
+
for e in cur_expr:
|
| 74 |
+
if "-" in e:
|
| 75 |
+
first, last = e.split("-")
|
| 76 |
+
cur_op.extend(str(id) for id in range(int(first), int(last) + 1))
|
| 77 |
+
else:
|
| 78 |
+
cur_op.append(e)
|
| 79 |
+
operations.append(cur_op)
|
| 80 |
+
|
| 81 |
+
# Cartesian product of operations
|
| 82 |
+
return list(itertools.product(*operations))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _apply_transformations_fast(chains, transformation_dict, operations):
|
| 86 |
+
"""
|
| 87 |
+
Get subassembly by applying the given operations to the input
|
| 88 |
+
structure containing affected asym IDs.
|
| 89 |
+
"""
|
| 90 |
+
# Additional first dimesion for 'structure.repeat()'
|
| 91 |
+
results = []
|
| 92 |
+
|
| 93 |
+
# Apply corresponding transformation for each copy in the assembly
|
| 94 |
+
for c in chains:
|
| 95 |
+
for operation in operations:
|
| 96 |
+
coord = c.atom37_positions.copy()
|
| 97 |
+
# Execute for each transformation step
|
| 98 |
+
# in the operation expression
|
| 99 |
+
for op_step in operation:
|
| 100 |
+
T = transformation_dict[op_step]
|
| 101 |
+
# Rotate
|
| 102 |
+
coord = matrix_rotate(coord, T.rotation)
|
| 103 |
+
# Translate
|
| 104 |
+
coord += T.target_translation
|
| 105 |
+
new_chain = replace(c, atom37_positions=coord)
|
| 106 |
+
results.append(new_chain)
|
| 107 |
+
|
| 108 |
+
return results
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@dataclass
|
| 112 |
+
class ProteinComplexMetadata:
|
| 113 |
+
entity_lookup: dict[int, int]
|
| 114 |
+
chain_lookup: dict[int, str]
|
| 115 |
+
mmcif: MmcifWrapper | None = None
|
| 116 |
+
# This is a dictionary that maps assembly ids to the list of unique chains
|
| 117 |
+
# in that assembly. Allows for usage of `switch_assembly`.
|
| 118 |
+
assembly_composition: dict[str, list[str]] | None = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclass
|
| 122 |
+
class DockQSingleScore:
|
| 123 |
+
native_chains: tuple[str, str]
|
| 124 |
+
DockQ: float
|
| 125 |
+
interface_rms: float
|
| 126 |
+
ligand_rms: float
|
| 127 |
+
fnat: float
|
| 128 |
+
fnonnat: float
|
| 129 |
+
clashes: float
|
| 130 |
+
F1: float
|
| 131 |
+
DockQ_F1: float
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@dataclass
|
| 135 |
+
class DockQResult:
|
| 136 |
+
total_dockq: float
|
| 137 |
+
native_interfaces: int
|
| 138 |
+
chain_mapping: dict[str, str]
|
| 139 |
+
interfaces: dict[tuple[str, str], DockQSingleScore]
|
| 140 |
+
# zip(aligned.chain_iter(), native.chain_iter()) gives you the pairing
|
| 141 |
+
# aligned.rmsd(native) should give you a low rmsd irrespective of shuffling
|
| 142 |
+
aligned: ProteinComplex
|
| 143 |
+
aligned_rmsd: float
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@dataclass(frozen=True)
|
| 147 |
+
class ProteinComplex:
|
| 148 |
+
"""Dataclass with atom37 representation of an entire protein complex."""
|
| 149 |
+
|
| 150 |
+
id: str
|
| 151 |
+
sequence: str
|
| 152 |
+
entity_id: np.ndarray # entities map to unique sequences
|
| 153 |
+
chain_id: np.ndarray # multiple chains might share an entity id
|
| 154 |
+
sym_id: np.ndarray # complexes might be copies of the same chain
|
| 155 |
+
residue_index: np.ndarray
|
| 156 |
+
insertion_code: np.ndarray
|
| 157 |
+
atom37_positions: np.ndarray
|
| 158 |
+
atom37_mask: np.ndarray
|
| 159 |
+
confidence: np.ndarray
|
| 160 |
+
# This metadata is parsed from the MMCIF file. For synthetic data, we do a best effort.
|
| 161 |
+
metadata: ProteinComplexMetadata
|
| 162 |
+
atom37_confidence: np.ndarray | None = None # [L, 37] per-atom pLDDT
|
| 163 |
+
|
| 164 |
+
def __post_init__(self):
|
| 165 |
+
l = len(self.sequence)
|
| 166 |
+
assert self.atom37_positions.shape[0] == l, (self.atom37_positions.shape, l)
|
| 167 |
+
assert self.atom37_mask.shape[0] == l, (self.atom37_mask.shape, l)
|
| 168 |
+
assert self.residue_index.shape[0] == l, (self.residue_index.shape, l)
|
| 169 |
+
assert self.insertion_code.shape[0] == l, (self.insertion_code.shape, l)
|
| 170 |
+
assert self.confidence.shape[0] == l, (self.confidence.shape, l)
|
| 171 |
+
assert self.entity_id.shape[0] == l, (self.entity_id.shape, l)
|
| 172 |
+
assert self.chain_id.shape[0] == l, (self.chain_id.shape, l)
|
| 173 |
+
assert self.sym_id.shape[0] == l, (self.sym_id.shape, l)
|
| 174 |
+
if self.atom37_confidence is not None:
|
| 175 |
+
assert self.atom37_confidence.shape == self.atom37_mask.shape, (
|
| 176 |
+
self.atom37_confidence.shape,
|
| 177 |
+
self.atom37_mask.shape,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
|
| 181 |
+
"""This function slices protein complexes without consideration of chain breaks
|
| 182 |
+
NOTE: When slicing with a boolean mask, it's possible that the output array won't
|
| 183 |
+
be the expected length. This is because we do our best to preserve chainbreak tokens.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
if isinstance(idx, int):
|
| 187 |
+
idx = [idx]
|
| 188 |
+
if isinstance(idx, list):
|
| 189 |
+
raise ValueError(
|
| 190 |
+
"ProteinComplex doesn't supports indexing with lists of indices"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if isinstance(idx, np.ndarray):
|
| 194 |
+
is_chainbreak = np.asarray([s == "|" for s in self.sequence])
|
| 195 |
+
idx = idx.astype(bool) | is_chainbreak
|
| 196 |
+
|
| 197 |
+
complex = self._unsafe_slice(idx)
|
| 198 |
+
if len(complex) == 0:
|
| 199 |
+
return complex
|
| 200 |
+
|
| 201 |
+
# detect runs of chainbreaks by searching for instances of '||' in complex.sequence
|
| 202 |
+
chainbreak_runs = np.asarray(
|
| 203 |
+
[
|
| 204 |
+
complex.sequence[i : i + 2] == "||"
|
| 205 |
+
for i in range(len(complex.sequence) - 1)
|
| 206 |
+
]
|
| 207 |
+
+ [complex.sequence[-1] == "|"]
|
| 208 |
+
)
|
| 209 |
+
# We should remove as many chainbreaks as possible from the start of the sequence
|
| 210 |
+
for i in range(len(chainbreak_runs)):
|
| 211 |
+
if complex.sequence[i] == "|":
|
| 212 |
+
chainbreak_runs[i] = True
|
| 213 |
+
else:
|
| 214 |
+
break
|
| 215 |
+
complex = complex._unsafe_slice(~chainbreak_runs)
|
| 216 |
+
return complex
|
| 217 |
+
|
| 218 |
+
def _unsafe_slice(self, idx: int | list[int] | slice | np.ndarray):
|
| 219 |
+
sequence = slice_python_object_as_numpy(self.sequence, idx)
|
| 220 |
+
return replace(
|
| 221 |
+
self,
|
| 222 |
+
sequence=sequence,
|
| 223 |
+
entity_id=self.entity_id[..., idx],
|
| 224 |
+
chain_id=self.chain_id[..., idx],
|
| 225 |
+
sym_id=self.sym_id[..., idx],
|
| 226 |
+
residue_index=self.residue_index[..., idx],
|
| 227 |
+
insertion_code=self.insertion_code[..., idx],
|
| 228 |
+
atom37_positions=self.atom37_positions[..., idx, :, :],
|
| 229 |
+
atom37_mask=self.atom37_mask[..., idx, :],
|
| 230 |
+
confidence=self.confidence[..., idx],
|
| 231 |
+
atom37_confidence=self.atom37_confidence[..., idx, :]
|
| 232 |
+
if self.atom37_confidence is not None
|
| 233 |
+
else None,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def __len__(self):
|
| 237 |
+
return len(self.sequence)
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def num_chains(self):
|
| 241 |
+
return len(self.chain_boundaries)
|
| 242 |
+
|
| 243 |
+
@cached_property
|
| 244 |
+
def atoms(self) -> AtomIndexer:
|
| 245 |
+
return AtomIndexer(self, property="atom37_positions", dim=-2)
|
| 246 |
+
|
| 247 |
+
@cached_property
|
| 248 |
+
def atom_mask(self) -> AtomIndexer:
|
| 249 |
+
return AtomIndexer(self, property="atom37_mask", dim=-1)
|
| 250 |
+
|
| 251 |
+
@cached_property
|
| 252 |
+
def chain_lengths(self) -> np.ndarray:
|
| 253 |
+
return np.diff(self.chain_boundaries, axis=1).flatten()
|
| 254 |
+
|
| 255 |
+
@cached_property
|
| 256 |
+
def chain_boundaries(self) -> list[tuple[int, int]]:
|
| 257 |
+
cb = [-1]
|
| 258 |
+
for i, s in enumerate(self.sequence):
|
| 259 |
+
if s == "|":
|
| 260 |
+
cb.append(i)
|
| 261 |
+
cb.append(len(self))
|
| 262 |
+
return [(cb[i] + 1, cb[i + 1]) for i in range(len(cb) - 1)]
|
| 263 |
+
|
| 264 |
+
def get_chain_by_index(self, index: int) -> ProteinChain:
|
| 265 |
+
try:
|
| 266 |
+
start, end = self.chain_boundaries[index]
|
| 267 |
+
return self[start:end].as_chain()
|
| 268 |
+
except IndexError:
|
| 269 |
+
raise IndexError(f"Chain index {index} out of bounds")
|
| 270 |
+
|
| 271 |
+
def get_chain_by_id(
|
| 272 |
+
self, chain_id: str, sample_chain_if_duplicate: bool = True
|
| 273 |
+
) -> ProteinChain:
|
| 274 |
+
valid_indices = [
|
| 275 |
+
index
|
| 276 |
+
for index, id_of_index in self.metadata.chain_lookup.items()
|
| 277 |
+
if id_of_index == chain_id
|
| 278 |
+
]
|
| 279 |
+
if not valid_indices:
|
| 280 |
+
raise KeyError(f"Chain ID {chain_id} not found")
|
| 281 |
+
if sample_chain_if_duplicate:
|
| 282 |
+
index_to_return = random.choice(valid_indices)
|
| 283 |
+
return self.get_chain_by_index(index_to_return)
|
| 284 |
+
else:
|
| 285 |
+
if len(valid_indices) > 1:
|
| 286 |
+
raise ValueError(f"Multiple chains with chain ID {chain_id} found")
|
| 287 |
+
return self.get_chain_by_index(valid_indices[0])
|
| 288 |
+
|
| 289 |
+
def chain_iter(self) -> Iterable[ProteinChain]:
|
| 290 |
+
for start, end in self.chain_boundaries:
|
| 291 |
+
c = self[start:end]
|
| 292 |
+
yield c.as_chain()
|
| 293 |
+
|
| 294 |
+
def as_chain(self, force_conversion: bool = False) -> ProteinChain:
|
| 295 |
+
"""Convert the ProteinComplex to a ProteinChain.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
force_conversion (bool): Forces the conversion into a protein chain even if the complex has multiple chains.
|
| 299 |
+
The purpose of this is to use ProteinChain specific functions (like cbeta_contacts).
|
| 300 |
+
|
| 301 |
+
"""
|
| 302 |
+
if not force_conversion:
|
| 303 |
+
assert len(np.unique(self.chain_id)) == 1, f"{self.id}"
|
| 304 |
+
assert len(np.unique(self.entity_id)) == 1, f"{self.id}"
|
| 305 |
+
if self.chain_id[0] not in self.metadata.chain_lookup:
|
| 306 |
+
warnings.warn("Chain ID not found in metadata, using 'A' as default")
|
| 307 |
+
if self.entity_id[0] not in self.metadata.entity_lookup:
|
| 308 |
+
warnings.warn("Entity ID not found in metadata, using None as default")
|
| 309 |
+
chain_id = self.metadata.chain_lookup.get(self.chain_id[0], "A")
|
| 310 |
+
entity_id = self.metadata.entity_lookup.get(self.entity_id[0], None)
|
| 311 |
+
else:
|
| 312 |
+
chain_id = "A"
|
| 313 |
+
entity_id = None
|
| 314 |
+
|
| 315 |
+
return ProteinChain(
|
| 316 |
+
id=self.id,
|
| 317 |
+
sequence=self.sequence,
|
| 318 |
+
chain_id=chain_id,
|
| 319 |
+
entity_id=entity_id,
|
| 320 |
+
atom37_positions=self.atom37_positions,
|
| 321 |
+
atom37_mask=self.atom37_mask,
|
| 322 |
+
residue_index=self.residue_index,
|
| 323 |
+
insertion_code=self.insertion_code,
|
| 324 |
+
confidence=self.confidence,
|
| 325 |
+
mmcif=self.metadata.mmcif,
|
| 326 |
+
atom37_confidence=self.atom37_confidence,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
@classmethod
|
| 330 |
+
def from_pdb(
|
| 331 |
+
cls, path: PathOrBuffer, id: str | None = None, is_predicted: bool = False
|
| 332 |
+
) -> "ProteinComplex":
|
| 333 |
+
atom_array = PDBFile.read(path).get_structure(
|
| 334 |
+
model=1, extra_fields=["b_factor"]
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
chains = []
|
| 338 |
+
for chain in bs.chain_iter(atom_array):
|
| 339 |
+
chain = chain[~chain.hetero]
|
| 340 |
+
if len(chain) == 0:
|
| 341 |
+
continue
|
| 342 |
+
chains.append(ProteinChain.from_atomarray(chain, id, is_predicted))
|
| 343 |
+
return ProteinComplex.from_chains(chains)
|
| 344 |
+
|
| 345 |
+
def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
|
| 346 |
+
atom_array = None
|
| 347 |
+
for chain in self.chain_iter():
|
| 348 |
+
carr = (
|
| 349 |
+
chain.atom_array
|
| 350 |
+
if include_insertions
|
| 351 |
+
else chain.atom_array_no_insertions
|
| 352 |
+
)
|
| 353 |
+
atom_array = carr if atom_array is None else atom_array + carr
|
| 354 |
+
f = PDBFile()
|
| 355 |
+
f.set_structure(atom_array)
|
| 356 |
+
f.write(path)
|
| 357 |
+
|
| 358 |
+
def to_pdb_string(self, include_insertions: bool = True) -> str:
|
| 359 |
+
buf = io.StringIO()
|
| 360 |
+
self.to_pdb(buf, include_insertions=include_insertions)
|
| 361 |
+
buf.seek(0)
|
| 362 |
+
return buf.read()
|
| 363 |
+
|
| 364 |
+
def normalize_chain_ids_for_pdb(self):
|
| 365 |
+
# Since PDB files have 1-letter chain IDs and don't support the idea of a symmetric index,
|
| 366 |
+
# we can normalize it instead which might be necessary for DockQ and to_pdb.
|
| 367 |
+
ids = SINGLE_LETTER_CHAIN_IDS
|
| 368 |
+
chains = []
|
| 369 |
+
for i, chain in enumerate(self.chain_iter()):
|
| 370 |
+
chain = replace(chain, chain_id=ids[i])
|
| 371 |
+
if i > len(ids):
|
| 372 |
+
raise RuntimeError("Too many chains to write to PDB file")
|
| 373 |
+
chains.append(chain)
|
| 374 |
+
|
| 375 |
+
return ProteinComplex.from_chains(chains)
|
| 376 |
+
|
| 377 |
+
def find_assembly_ids_with_chain(self, id: str) -> list[str]:
|
| 378 |
+
good_chains = []
|
| 379 |
+
if (comp := self.metadata.assembly_composition) is not None:
|
| 380 |
+
for assembly_id, chain_ids in comp.items():
|
| 381 |
+
if id in chain_ids:
|
| 382 |
+
good_chains.append(assembly_id)
|
| 383 |
+
else:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
"Cannot switch assemblies on this ProteinComplex, you must create the assembly from mmcif to support this"
|
| 386 |
+
)
|
| 387 |
+
return good_chains
|
| 388 |
+
|
| 389 |
+
def switch_assembly(self, id: str):
|
| 390 |
+
assert self.metadata.mmcif is not None
|
| 391 |
+
return get_assembly_fast(self.metadata.mmcif, assembly_id=id)
|
| 392 |
+
|
| 393 |
+
def state_dict(self, backbone_only=False, json_serializable=False):
|
| 394 |
+
"""This state dict is optimized for storage, so it turns things to fp16 whenever
|
| 395 |
+
possible. Note that we also only support int32 residue indices, I'm hoping we don't
|
| 396 |
+
need more than 2**32 residues..."""
|
| 397 |
+
dct = {k: v for k, v in vars(self).items()}
|
| 398 |
+
if backbone_only:
|
| 399 |
+
dct["atom37_mask"][:, 3:] = False
|
| 400 |
+
dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
|
| 401 |
+
if dct.get("atom37_confidence") is not None:
|
| 402 |
+
dct["atom37_confidence"] = dct["atom37_confidence"][dct["atom37_mask"]]
|
| 403 |
+
else:
|
| 404 |
+
dct.pop("atom37_confidence", None)
|
| 405 |
+
for k, v in dct.items():
|
| 406 |
+
if isinstance(v, np.ndarray):
|
| 407 |
+
match v.dtype:
|
| 408 |
+
case np.int64:
|
| 409 |
+
dct[k] = v.astype(np.int32)
|
| 410 |
+
case np.float64 | np.float32:
|
| 411 |
+
dct[k] = v.astype(np.float16)
|
| 412 |
+
case _:
|
| 413 |
+
pass
|
| 414 |
+
if json_serializable:
|
| 415 |
+
dct[k] = v.tolist()
|
| 416 |
+
elif isinstance(v, ProteinComplexMetadata):
|
| 417 |
+
dct[k] = asdict(v)
|
| 418 |
+
dct["metadata"]["mmcif"] = None
|
| 419 |
+
# These can be populated with non-serializable objects and are not needed for reconstruction
|
| 420 |
+
dct.pop("atoms", None)
|
| 421 |
+
dct.pop("atom_mask", None)
|
| 422 |
+
dct.pop("per_chain_kd_trees", None)
|
| 423 |
+
return dct
|
| 424 |
+
|
| 425 |
+
def to_blob(self, backbone_only=False) -> bytes:
|
| 426 |
+
return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
|
| 427 |
+
|
| 428 |
+
@classmethod
|
| 429 |
+
def from_state_dict(cls, dct):
|
| 430 |
+
# Note: assembly_composition is *supposed* to have string keys.
|
| 431 |
+
dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
|
| 432 |
+
|
| 433 |
+
for k, v in dct.items():
|
| 434 |
+
if isinstance(v, list):
|
| 435 |
+
dct[k] = np.array(v)
|
| 436 |
+
|
| 437 |
+
atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan)
|
| 438 |
+
atom37[dct["atom37_mask"]] = dct["atom37_positions"]
|
| 439 |
+
dct["atom37_positions"] = atom37
|
| 440 |
+
if "atom37_confidence" in dct:
|
| 441 |
+
atom37_conf = np.full(dct["atom37_mask"].shape, np.nan, dtype=np.float32)
|
| 442 |
+
atom37_conf[dct["atom37_mask"]] = dct["atom37_confidence"]
|
| 443 |
+
dct["atom37_confidence"] = atom37_conf
|
| 444 |
+
dct = {
|
| 445 |
+
k: (
|
| 446 |
+
v.astype(np.float32)
|
| 447 |
+
if k in ["atom37_positions", "confidence", "atom37_confidence"]
|
| 448 |
+
else v
|
| 449 |
+
)
|
| 450 |
+
for k, v in dct.items()
|
| 451 |
+
}
|
| 452 |
+
if "chain_boundaries" in dct:
|
| 453 |
+
del dct["chain_boundaries"]
|
| 454 |
+
if "chain_boundaries" in dct["metadata"]:
|
| 455 |
+
del dct["metadata"]["chain_boundaries"]
|
| 456 |
+
dct["metadata"] = ProteinComplexMetadata(**dct["metadata"])
|
| 457 |
+
return cls(**dct)
|
| 458 |
+
|
| 459 |
+
@classmethod
|
| 460 |
+
def from_blob(cls, input: Path | str | io.BytesIO | bytes):
|
| 461 |
+
"""NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
|
| 462 |
+
of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
|
| 463 |
+
shot at compressing and dumping chains to disk. I'm sure there's better ways."""
|
| 464 |
+
match input:
|
| 465 |
+
case Path() | str():
|
| 466 |
+
bytes = Path(input).read_bytes()
|
| 467 |
+
case io.BytesIO():
|
| 468 |
+
bytes = input.getvalue()
|
| 469 |
+
case _:
|
| 470 |
+
bytes = input
|
| 471 |
+
return cls.from_state_dict(
|
| 472 |
+
msgpack.loads(brotli.decompress(bytes), strict_map_key=False)
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
@classmethod
|
| 476 |
+
def from_rcsb(cls, pdb_id: str, keep_source: bool = False) -> ProteinComplex:
|
| 477 |
+
f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
|
| 478 |
+
return cls.from_mmcif(f, id=pdb_id, keep_source=keep_source, is_predicted=False)
|
| 479 |
+
|
| 480 |
+
@classmethod
|
| 481 |
+
def from_mmcif(
|
| 482 |
+
cls,
|
| 483 |
+
path: PathOrBuffer,
|
| 484 |
+
id: str | None = None,
|
| 485 |
+
assembly_id: str | None = None,
|
| 486 |
+
is_predicted: bool = False,
|
| 487 |
+
keep_source: bool = False,
|
| 488 |
+
):
|
| 489 |
+
"""Return a ProteinComplex object from an mmcif file.
|
| 490 |
+
TODO(@zeming): there's actually multiple complexes per file, but for ease of implementation,
|
| 491 |
+
we only consider the first defined complex!
|
| 492 |
+
|
| 493 |
+
Args:
|
| 494 |
+
path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
|
| 495 |
+
id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
|
| 496 |
+
is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
|
| 497 |
+
chain_id (str, optional): Select a chain corresponding to (author) chain id.
|
| 498 |
+
"""
|
| 499 |
+
mmcif = MmcifWrapper.read(path, id)
|
| 500 |
+
return get_assembly_fast(mmcif, assembly_id=assembly_id)
|
| 501 |
+
|
| 502 |
+
@classmethod
|
| 503 |
+
def from_chains(
|
| 504 |
+
cls,
|
| 505 |
+
chains: Sequence[ProteinChain],
|
| 506 |
+
mmcif: MmcifWrapper | None = None,
|
| 507 |
+
all_assembly_metadata_dictionary: dict[str, list[str]] | None = None,
|
| 508 |
+
):
|
| 509 |
+
if not chains:
|
| 510 |
+
raise ValueError(
|
| 511 |
+
"Cannot create a ProteinComplex from an empty list of chains"
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# TODO(roshan): Make a proper protein complex class
|
| 515 |
+
def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
|
| 516 |
+
full_array = []
|
| 517 |
+
for array in arrays:
|
| 518 |
+
full_array.append(array)
|
| 519 |
+
full_array.append(sep)
|
| 520 |
+
full_array = full_array[:-1]
|
| 521 |
+
return np.concatenate(full_array, 0)
|
| 522 |
+
|
| 523 |
+
sep_tokens = {
|
| 524 |
+
"residue_index": np.array([-1]),
|
| 525 |
+
"insertion_code": np.array([""]),
|
| 526 |
+
"atom37_positions": np.full([1, 37, 3], np.nan),
|
| 527 |
+
"atom37_mask": np.zeros([1, 37], dtype=bool),
|
| 528 |
+
"confidence": np.array([0]),
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
+
any_has_atom37_conf = any(c.atom37_confidence is not None for c in chains)
|
| 532 |
+
if any_has_atom37_conf:
|
| 533 |
+
sep_tokens["atom37_confidence"] = np.full([1, 37], np.nan, dtype=np.float32)
|
| 534 |
+
|
| 535 |
+
def _get_chain_attr(chain: ProteinChain, name: str) -> np.ndarray:
|
| 536 |
+
val = getattr(chain, name)
|
| 537 |
+
if val is None and name == "atom37_confidence":
|
| 538 |
+
return np.full([len(chain), 37], np.nan, dtype=np.float32)
|
| 539 |
+
return val
|
| 540 |
+
|
| 541 |
+
array_args: dict[str, np.ndarray] = {
|
| 542 |
+
name: join_arrays([_get_chain_attr(chain, name) for chain in chains], sep)
|
| 543 |
+
for name, sep in sep_tokens.items()
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
multimer_arrays = []
|
| 547 |
+
chain2num_max = -1
|
| 548 |
+
chain2num = {}
|
| 549 |
+
ent2num_max = -1
|
| 550 |
+
ent2num = {}
|
| 551 |
+
total_index = 0
|
| 552 |
+
for i, c in enumerate(chains):
|
| 553 |
+
num_res = c.residue_index.shape[0]
|
| 554 |
+
if c.chain_id not in chain2num:
|
| 555 |
+
chain2num[c.chain_id] = (chain2num_max := chain2num_max + 1)
|
| 556 |
+
chain_id_array = np.full([num_res], chain2num[c.chain_id], dtype=np.int64)
|
| 557 |
+
|
| 558 |
+
if c.entity_id is None:
|
| 559 |
+
entity_num = (ent2num_max := ent2num_max + 1)
|
| 560 |
+
else:
|
| 561 |
+
if c.entity_id not in ent2num:
|
| 562 |
+
ent2num[c.entity_id] = (ent2num_max := ent2num_max + 1)
|
| 563 |
+
entity_num = ent2num[c.entity_id]
|
| 564 |
+
entity_id_array = np.full([num_res], entity_num, dtype=np.int64)
|
| 565 |
+
|
| 566 |
+
sym_id_array = np.full([num_res], i, dtype=np.int64)
|
| 567 |
+
|
| 568 |
+
multimer_arrays.append(
|
| 569 |
+
{
|
| 570 |
+
"chain_id": chain_id_array,
|
| 571 |
+
"entity_id": entity_id_array,
|
| 572 |
+
"sym_id": sym_id_array,
|
| 573 |
+
}
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
total_index += num_res + 1
|
| 577 |
+
|
| 578 |
+
sep = np.array([-1])
|
| 579 |
+
update = {
|
| 580 |
+
name: join_arrays([dct[name] for dct in multimer_arrays], sep=sep)
|
| 581 |
+
for name in ["chain_id", "entity_id", "sym_id"]
|
| 582 |
+
}
|
| 583 |
+
array_args.update(update)
|
| 584 |
+
|
| 585 |
+
metadata = ProteinComplexMetadata(
|
| 586 |
+
mmcif=mmcif,
|
| 587 |
+
chain_lookup={v: k for k, v in chain2num.items()},
|
| 588 |
+
entity_lookup={v: k for k, v in ent2num.items()},
|
| 589 |
+
assembly_composition=all_assembly_metadata_dictionary,
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
return cls(
|
| 593 |
+
id=chains[0].id,
|
| 594 |
+
sequence=residue_constants.CHAIN_BREAK_TOKEN.join(
|
| 595 |
+
chain.sequence for chain in chains
|
| 596 |
+
),
|
| 597 |
+
metadata=metadata,
|
| 598 |
+
**array_args,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
def infer_oxygen(self) -> ProteinComplex:
|
| 602 |
+
"""Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
|
| 603 |
+
O_missing_indices = np.argwhere(
|
| 604 |
+
~np.isfinite(self.atoms["O"]).all(axis=1)
|
| 605 |
+
).squeeze()
|
| 606 |
+
|
| 607 |
+
O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
|
| 608 |
+
N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
|
| 609 |
+
N = torch.roll(N, -3)
|
| 610 |
+
N[..., -1, :] = torch.nan
|
| 611 |
+
|
| 612 |
+
# Get the frame defined by the CA-C-N atom
|
| 613 |
+
frames = Affine3D.from_graham_schmidt(CA, C, N)
|
| 614 |
+
O = frames.apply(O_vector)
|
| 615 |
+
atom37_positions = self.atom37_positions.copy()
|
| 616 |
+
atom37_mask = self.atom37_mask.copy()
|
| 617 |
+
|
| 618 |
+
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
|
| 619 |
+
O_missing_indices
|
| 620 |
+
].numpy()
|
| 621 |
+
atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
|
| 622 |
+
atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
|
| 623 |
+
).any(-1)
|
| 624 |
+
new_chain = replace(
|
| 625 |
+
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
| 626 |
+
)
|
| 627 |
+
return new_chain
|
| 628 |
+
|
| 629 |
+
def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinComplex:
|
| 630 |
+
"""Return a new chain with inferred CB atoms at all residues except GLY.
|
| 631 |
+
|
| 632 |
+
Args:
|
| 633 |
+
infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
|
| 634 |
+
residues, even though that residue doesn't have one. Default off.
|
| 635 |
+
|
| 636 |
+
NOTE(rverkuil): The reason for having this switch in the first place
|
| 637 |
+
is that sometimes we want a (inferred) CB coordinate for every residue,
|
| 638 |
+
for example for making a pairwise distance matrix, or doing an RMSD
|
| 639 |
+
calculation between two designs for a given structural template, w/
|
| 640 |
+
CB atoms.
|
| 641 |
+
"""
|
| 642 |
+
atom37_positions = self.atom37_positions.copy()
|
| 643 |
+
atom37_mask = self.atom37_mask.copy()
|
| 644 |
+
|
| 645 |
+
N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
|
| 646 |
+
# See usage in trDesign codebase.
|
| 647 |
+
# https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
|
| 648 |
+
inferred_cbeta_positions = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
|
| 649 |
+
if not infer_cbeta_for_glycine:
|
| 650 |
+
inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
|
| 651 |
+
|
| 652 |
+
atom37_positions[:, residue_constants.atom_order["CB"]] = (
|
| 653 |
+
inferred_cbeta_positions
|
| 654 |
+
)
|
| 655 |
+
atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
|
| 656 |
+
atom37_positions[:, residue_constants.atom_order["CB"]]
|
| 657 |
+
).any(-1)
|
| 658 |
+
new_chain = replace(
|
| 659 |
+
self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
|
| 660 |
+
)
|
| 661 |
+
return new_chain
|
| 662 |
+
|
| 663 |
+
@classmethod
|
| 664 |
+
def from_open_source(cls, pc: ProteinComplex):
|
| 665 |
+
# TODO(@zeming): deprecated, should delete
|
| 666 |
+
return pc
|
| 667 |
+
|
| 668 |
+
@classmethod
|
| 669 |
+
def concat(cls, objs: list[ProteinComplex]) -> ProteinComplex:
|
| 670 |
+
pdb_ids = [obj.id for obj in objs]
|
| 671 |
+
if len(set(pdb_ids)) > 1:
|
| 672 |
+
raise RuntimeError(
|
| 673 |
+
"Concatention of protein complexes across different PDB ids is unsupported"
|
| 674 |
+
)
|
| 675 |
+
return ProteinComplex.from_chains(
|
| 676 |
+
list(itertools.chain.from_iterable(obj.chain_iter() for obj in objs))
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
def _sanity_check_complexes_are_comparable(self, other: ProteinComplex):
|
| 680 |
+
assert len(self) == len(other), "Protein complexes must have the same length"
|
| 681 |
+
assert len(list(self.chain_iter())) == len(
|
| 682 |
+
list(other.chain_iter())
|
| 683 |
+
), "Protein complexes must have the same number of chains"
|
| 684 |
+
|
| 685 |
+
def rmsd(
|
| 686 |
+
self,
|
| 687 |
+
target: ProteinComplex,
|
| 688 |
+
also_check_reflection: bool = False,
|
| 689 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 690 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 691 |
+
only_compute_backbone_rmsd: bool = False,
|
| 692 |
+
compute_chain_assignment: bool = True,
|
| 693 |
+
):
|
| 694 |
+
"""
|
| 695 |
+
Compute the RMSD between this protein chain and another.
|
| 696 |
+
|
| 697 |
+
Args:
|
| 698 |
+
target (ProteinComplex): The target (other) protein complex to compare to.
|
| 699 |
+
also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
|
| 700 |
+
mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 701 |
+
target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
|
| 702 |
+
only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
|
| 703 |
+
"""
|
| 704 |
+
if compute_chain_assignment:
|
| 705 |
+
aligned = self.dockq(target).aligned
|
| 706 |
+
else:
|
| 707 |
+
aligned = self
|
| 708 |
+
|
| 709 |
+
aligner = Aligner(
|
| 710 |
+
aligned if mobile_inds is None else aligned[mobile_inds],
|
| 711 |
+
target if target_inds is None else target[target_inds],
|
| 712 |
+
only_compute_backbone_rmsd,
|
| 713 |
+
)
|
| 714 |
+
avg_rmsd = aligner.rmsd
|
| 715 |
+
|
| 716 |
+
if not also_check_reflection:
|
| 717 |
+
return avg_rmsd
|
| 718 |
+
|
| 719 |
+
aligner = Aligner(
|
| 720 |
+
aligned if mobile_inds is None else aligned[mobile_inds],
|
| 721 |
+
target if target_inds is None else target[target_inds],
|
| 722 |
+
only_compute_backbone_rmsd,
|
| 723 |
+
use_reflection=True,
|
| 724 |
+
)
|
| 725 |
+
avg_rmsd_neg = aligner.rmsd
|
| 726 |
+
|
| 727 |
+
return min(avg_rmsd, avg_rmsd_neg)
|
| 728 |
+
|
| 729 |
+
def lddt_ca(
|
| 730 |
+
self,
|
| 731 |
+
target: ProteinComplex,
|
| 732 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 733 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 734 |
+
compute_chain_assignment: bool = True,
|
| 735 |
+
**kwargs,
|
| 736 |
+
) -> float | np.ndarray:
|
| 737 |
+
"""Compute the LDDT between this protein complex and another.
|
| 738 |
+
|
| 739 |
+
Arguments:
|
| 740 |
+
target (ProteinComplex): The other protein complex to compare to.
|
| 741 |
+
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 742 |
+
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
| 743 |
+
|
| 744 |
+
Returns:
|
| 745 |
+
float | np.ndarray: The LDDT score between the two protein chains, either
|
| 746 |
+
a single float or per-residue LDDT scores if `per_residue` is True.
|
| 747 |
+
"""
|
| 748 |
+
if compute_chain_assignment:
|
| 749 |
+
aligned = self.dockq(target).aligned
|
| 750 |
+
else:
|
| 751 |
+
aligned = self
|
| 752 |
+
lddt = compute_lddt_ca(
|
| 753 |
+
torch.tensor(aligned.atom37_positions[mobile_inds]).unsqueeze(0),
|
| 754 |
+
torch.tensor(target.atom37_positions[target_inds]).unsqueeze(0),
|
| 755 |
+
torch.tensor(aligned.atom37_mask[mobile_inds]).unsqueeze(0),
|
| 756 |
+
**kwargs,
|
| 757 |
+
)
|
| 758 |
+
return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
|
| 759 |
+
|
| 760 |
+
def gdt_ts(
|
| 761 |
+
self,
|
| 762 |
+
target: ProteinComplex,
|
| 763 |
+
mobile_inds: list[int] | np.ndarray | None = None,
|
| 764 |
+
target_inds: list[int] | np.ndarray | None = None,
|
| 765 |
+
compute_chain_assignment: bool = True,
|
| 766 |
+
**kwargs,
|
| 767 |
+
) -> float | np.ndarray:
|
| 768 |
+
"""Compute the GDT_TS between this protein complex and another.
|
| 769 |
+
|
| 770 |
+
Arguments:
|
| 771 |
+
target (ProteinComplex): The other protein complex to compare to.
|
| 772 |
+
mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
|
| 773 |
+
target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
|
| 774 |
+
|
| 775 |
+
Returns:
|
| 776 |
+
float: The GDT_TS score between the two protein chains.
|
| 777 |
+
"""
|
| 778 |
+
if compute_chain_assignment:
|
| 779 |
+
aligned = self.dockq(target).aligned
|
| 780 |
+
else:
|
| 781 |
+
aligned = self
|
| 782 |
+
gdt_ts = compute_gdt_ts(
|
| 783 |
+
mobile=torch.tensor(
|
| 784 |
+
index_by_atom_name(aligned.atom37_positions[mobile_inds], "CA"),
|
| 785 |
+
dtype=torch.float32,
|
| 786 |
+
).unsqueeze(0),
|
| 787 |
+
target=torch.tensor(
|
| 788 |
+
index_by_atom_name(target.atom37_positions[target_inds], "CA"),
|
| 789 |
+
dtype=torch.float32,
|
| 790 |
+
).unsqueeze(0),
|
| 791 |
+
atom_exists_mask=torch.tensor(
|
| 792 |
+
index_by_atom_name(aligned.atom37_mask[mobile_inds], "CA", dim=-1)
|
| 793 |
+
& index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
|
| 794 |
+
).unsqueeze(0),
|
| 795 |
+
**kwargs,
|
| 796 |
+
)
|
| 797 |
+
return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
|
| 798 |
+
|
| 799 |
+
def dockq(self, native: ProteinComplex):
|
| 800 |
+
# This function uses dockqv2 to compute the DockQ score. Because it does a mapping
|
| 801 |
+
# over all possible chains, it's quite slow. Be careful not to use this in an inference loop
|
| 802 |
+
# or something that requires fast scoring. It defaults to 8 CPUs.
|
| 803 |
+
#
|
| 804 |
+
# TODO(@zeming): Because we haven't properly implemented protein complexes for mmcif,
|
| 805 |
+
# if your protein has multi-letter or repeated chain IDs, this will fail. Please call
|
| 806 |
+
# pc = pc.normalize_chain_ids_for_pdb() before calling this function in that case (limit is 62 chains)
|
| 807 |
+
|
| 808 |
+
try:
|
| 809 |
+
pass
|
| 810 |
+
except BaseException:
|
| 811 |
+
raise RuntimeError(
|
| 812 |
+
"DockQ is not installed. Please update your environment."
|
| 813 |
+
)
|
| 814 |
+
self._sanity_check_complexes_are_comparable(native)
|
| 815 |
+
|
| 816 |
+
def sanity_check_chain_ids(pc: ProteinComplex):
|
| 817 |
+
ids = []
|
| 818 |
+
for i, chain in enumerate(pc.chain_iter()):
|
| 819 |
+
if i > len(SINGLE_LETTER_CHAIN_IDS):
|
| 820 |
+
raise ValueError("Too many chains to write to PDB file")
|
| 821 |
+
if len(chain.chain_id) > 1:
|
| 822 |
+
raise ValueError(
|
| 823 |
+
"We only supports single letter chain IDs for DockQ"
|
| 824 |
+
)
|
| 825 |
+
ids.append(chain.chain_id)
|
| 826 |
+
if len(set(ids)) != len(ids):
|
| 827 |
+
raise ValueError(f"Duplicate chain IDs in protein complex: {ids}")
|
| 828 |
+
return ids
|
| 829 |
+
|
| 830 |
+
sanity_check_chain_ids(self)
|
| 831 |
+
sanity_check_chain_ids(native)
|
| 832 |
+
|
| 833 |
+
with TemporaryDirectory() as tdir:
|
| 834 |
+
dir = Path(tdir)
|
| 835 |
+
self.to_pdb(dir / "self.pdb")
|
| 836 |
+
native.to_pdb(dir / "native.pdb")
|
| 837 |
+
|
| 838 |
+
output = check_output(["DockQ", dir / "self.pdb", dir / "native.pdb"])
|
| 839 |
+
lines = output.decode().split("\n")
|
| 840 |
+
|
| 841 |
+
# Remove the header comments
|
| 842 |
+
start_index = next(
|
| 843 |
+
i for i, line in enumerate(lines) if line.startswith("Model")
|
| 844 |
+
)
|
| 845 |
+
lines = lines[start_index:]
|
| 846 |
+
|
| 847 |
+
result = {}
|
| 848 |
+
interfaces = []
|
| 849 |
+
current_interface: dict = {}
|
| 850 |
+
|
| 851 |
+
for line in lines:
|
| 852 |
+
line = line.strip()
|
| 853 |
+
if not line:
|
| 854 |
+
continue
|
| 855 |
+
|
| 856 |
+
if line.startswith("Model :"):
|
| 857 |
+
pass # Tmp pdb file location, it's useless...
|
| 858 |
+
elif line.startswith("Native :"):
|
| 859 |
+
pass # Tmp pdb file location, it's useless...
|
| 860 |
+
elif line.startswith("Total DockQ"):
|
| 861 |
+
total_dockq_match = re.search(
|
| 862 |
+
r"Total DockQ over (\d+) native interfaces: ([\d.]+) with (.*) model:native mapping",
|
| 863 |
+
line,
|
| 864 |
+
)
|
| 865 |
+
if total_dockq_match:
|
| 866 |
+
result["value"] = float(total_dockq_match.group(2))
|
| 867 |
+
result["native interfaces"] = int(total_dockq_match.group(1))
|
| 868 |
+
native_chains, self_chains = total_dockq_match.group(3).split(":")
|
| 869 |
+
result["mapping"] = dict(zip(native_chains, self_chains))
|
| 870 |
+
else:
|
| 871 |
+
raise RuntimeError(
|
| 872 |
+
"Failed to parse DockQ output, maybe your DockQ version is wrong?"
|
| 873 |
+
)
|
| 874 |
+
elif line.startswith("Native chains:"):
|
| 875 |
+
if current_interface:
|
| 876 |
+
interfaces.append(current_interface)
|
| 877 |
+
current_interface = {
|
| 878 |
+
"Native chains": line.split(":")[1].strip().split(", ")
|
| 879 |
+
}
|
| 880 |
+
elif line.startswith("Model chains:"):
|
| 881 |
+
current_interface["Model chains"] = (
|
| 882 |
+
line.split(":")[1].strip().split(", ")
|
| 883 |
+
)
|
| 884 |
+
elif ":" in line:
|
| 885 |
+
key, value = line.split(":", 1)
|
| 886 |
+
current_interface[key.strip()] = float(value.strip())
|
| 887 |
+
|
| 888 |
+
if current_interface:
|
| 889 |
+
interfaces.append(current_interface)
|
| 890 |
+
|
| 891 |
+
def parse_dict(d: dict[str, Any]) -> DockQSingleScore:
|
| 892 |
+
return DockQSingleScore(
|
| 893 |
+
native_chains=tuple(d["Native chains"]), # type: ignore
|
| 894 |
+
DockQ=float(d["DockQ"]),
|
| 895 |
+
interface_rms=float(d["irms"]),
|
| 896 |
+
ligand_rms=float(d["Lrms"]), # Note the capitalization difference
|
| 897 |
+
fnat=float(d["fnat"]),
|
| 898 |
+
fnonnat=float(d["fnonnat"]),
|
| 899 |
+
clashes=float(d["clashes"]),
|
| 900 |
+
F1=float(d["F1"]),
|
| 901 |
+
DockQ_F1=float(d["DockQ_F1"]),
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
inv_mapping = {v: k for k, v in result["mapping"].items()}
|
| 905 |
+
|
| 906 |
+
self_chain_map = {c.chain_id: c for c in self.chain_iter()}
|
| 907 |
+
realigned = []
|
| 908 |
+
for chain in native.chain_iter():
|
| 909 |
+
realigned.append(self_chain_map[inv_mapping[chain.chain_id]])
|
| 910 |
+
|
| 911 |
+
realigned = ProteinComplex.from_chains(realigned)
|
| 912 |
+
aligner = Aligner(realigned, native)
|
| 913 |
+
realigned = aligner.apply(realigned)
|
| 914 |
+
|
| 915 |
+
result = DockQResult(
|
| 916 |
+
total_dockq=result["value"],
|
| 917 |
+
native_interfaces=result["native interfaces"],
|
| 918 |
+
chain_mapping=result["mapping"],
|
| 919 |
+
interfaces={
|
| 920 |
+
(i["Model chains"][0], i["Model chains"][1]): parse_dict(i)
|
| 921 |
+
for i in interfaces
|
| 922 |
+
},
|
| 923 |
+
aligned=realigned,
|
| 924 |
+
aligned_rmsd=aligner.rmsd,
|
| 925 |
+
)
|
| 926 |
+
|
| 927 |
+
return result
|
| 928 |
+
|
| 929 |
+
@cached_property
|
| 930 |
+
def per_chain_kd_trees(self):
|
| 931 |
+
# Iterate over chains, build KDTree for each chain
|
| 932 |
+
kdtrees = []
|
| 933 |
+
|
| 934 |
+
CA = self.atoms["CA"]
|
| 935 |
+
|
| 936 |
+
for start, end in self.chain_boundaries:
|
| 937 |
+
chain_CA = CA[start:end]
|
| 938 |
+
chain_CA = chain_CA[np.isfinite(chain_CA).all(axis=-1)]
|
| 939 |
+
kdtrees.append(KDTree(chain_CA))
|
| 940 |
+
|
| 941 |
+
return kdtrees
|
| 942 |
+
|
| 943 |
+
def chain_adjacency(self, cutoff: float = 8.0) -> np.ndarray:
|
| 944 |
+
# Compute adjacency matrix for protein complex
|
| 945 |
+
num_chains = self.num_chains
|
| 946 |
+
adjacency = np.zeros((num_chains, num_chains), dtype=bool)
|
| 947 |
+
for (i, kdtree), (j, kdtree2) in itertools.combinations(
|
| 948 |
+
enumerate(self.per_chain_kd_trees), 2
|
| 949 |
+
):
|
| 950 |
+
adj = kdtree.query_ball_tree(kdtree2, cutoff)
|
| 951 |
+
any_is_adjacent = any(len(a) > 0 for a in adj)
|
| 952 |
+
adjacency[i, j] = any_is_adjacent
|
| 953 |
+
adjacency[j, i] = any_is_adjacent
|
| 954 |
+
return adjacency
|
| 955 |
+
|
| 956 |
+
def chain_adjacency_by_index(self, index: int, cutoff: float = 8.0) -> np.ndarray:
|
| 957 |
+
num_chains = len(self.chain_boundaries)
|
| 958 |
+
adjacency = np.zeros(num_chains, dtype=bool)
|
| 959 |
+
for i, kdtree in enumerate(self.per_chain_kd_trees):
|
| 960 |
+
if i == index:
|
| 961 |
+
continue
|
| 962 |
+
adj = kdtree.query_ball_tree(self.per_chain_kd_trees[index], cutoff)
|
| 963 |
+
adjacency[i] = any(len(a) > 0 for a in adj)
|
| 964 |
+
return adjacency
|
| 965 |
+
|
| 966 |
+
def add_prefix_to_chain_ids(self, prefix: str) -> ProteinComplex:
|
| 967 |
+
"""Rename all chains in the complex with a given prefix.
|
| 968 |
+
|
| 969 |
+
Args:
|
| 970 |
+
prefix (str): The prefix to use for the new chain IDs. Each chain will be
|
| 971 |
+
named as "{prefix}_{chain_id}".
|
| 972 |
+
|
| 973 |
+
Returns:
|
| 974 |
+
ProteinComplex: A new protein complex with renamed chains.
|
| 975 |
+
"""
|
| 976 |
+
new_chains = []
|
| 977 |
+
for chain in self.chain_iter():
|
| 978 |
+
# Create new chain with updated chain_id
|
| 979 |
+
new_chain = replace(chain, chain_id=f"{prefix}_{chain.chain_id}")
|
| 980 |
+
new_chains.append(new_chain)
|
| 981 |
+
return ProteinComplex.from_chains(new_chains)
|
| 982 |
+
|
| 983 |
+
def sasa(self, by_residue: bool = True):
|
| 984 |
+
chain = self.as_chain(force_conversion=True)
|
| 985 |
+
return chain.sasa(by_residue=by_residue)
|
| 986 |
+
|
| 987 |
+
def to_mmcif_string(self) -> str:
|
| 988 |
+
"""Convert the ProteinComplex to mmCIF format.
|
| 989 |
+
|
| 990 |
+
Returns:
|
| 991 |
+
str: The mmCIF content as a string.
|
| 992 |
+
"""
|
| 993 |
+
# Convert the ProteinComplex to a biotite AtomArray
|
| 994 |
+
# Collect all atoms from all chains
|
| 995 |
+
all_atoms = []
|
| 996 |
+
for chain in self.chain_iter():
|
| 997 |
+
chain_atom_array = chain.atom_array
|
| 998 |
+
# Convert AtomArray to list of atoms and add to collection
|
| 999 |
+
all_atoms.extend(chain_atom_array)
|
| 1000 |
+
|
| 1001 |
+
# Create combined AtomArray from all atoms
|
| 1002 |
+
if not all_atoms:
|
| 1003 |
+
raise ValueError("No atoms found in protein complex")
|
| 1004 |
+
|
| 1005 |
+
atom_array = bs.array(all_atoms)
|
| 1006 |
+
|
| 1007 |
+
# Create CIF file
|
| 1008 |
+
f = CIFFile()
|
| 1009 |
+
set_structure_pdbx(f, atom_array, data_block=self.id)
|
| 1010 |
+
|
| 1011 |
+
# Add entity information for proper mmCIF structure
|
| 1012 |
+
self._add_entity_information(f)
|
| 1013 |
+
|
| 1014 |
+
# Write to string
|
| 1015 |
+
output = io.StringIO()
|
| 1016 |
+
f.write(output)
|
| 1017 |
+
return output.getvalue()
|
| 1018 |
+
|
| 1019 |
+
def _add_entity_information(self, cif_file: CIFFile) -> None:
|
| 1020 |
+
"""Add entity, entity_poly, and struct_asym sections to CIF file."""
|
| 1021 |
+
|
| 1022 |
+
# Group chains by sequence to create unique entities
|
| 1023 |
+
entity_map = {} # sequence -> entity_id
|
| 1024 |
+
chain_to_entity = {} # chain_id -> entity_id
|
| 1025 |
+
entity_sequences = {} # entity_id -> sequence
|
| 1026 |
+
entity_id_counter = 1
|
| 1027 |
+
|
| 1028 |
+
for chain in self.chain_iter():
|
| 1029 |
+
sequence = chain.sequence
|
| 1030 |
+
if sequence not in entity_map:
|
| 1031 |
+
entity_map[sequence] = entity_id_counter
|
| 1032 |
+
entity_sequences[entity_id_counter] = sequence
|
| 1033 |
+
entity_id_counter += 1
|
| 1034 |
+
chain_to_entity[chain.chain_id] = entity_map[sequence]
|
| 1035 |
+
|
| 1036 |
+
# Create _entity section
|
| 1037 |
+
entity_ids = []
|
| 1038 |
+
entity_types = []
|
| 1039 |
+
entity_descriptions = []
|
| 1040 |
+
|
| 1041 |
+
for entity_id in sorted(entity_sequences.keys()):
|
| 1042 |
+
entity_ids.append(str(entity_id))
|
| 1043 |
+
entity_types.append("polymer")
|
| 1044 |
+
entity_descriptions.append(f"Protein chain (entity {entity_id})")
|
| 1045 |
+
|
| 1046 |
+
cif_file.block["entity"] = CIFCategory(
|
| 1047 |
+
name="entity",
|
| 1048 |
+
columns={
|
| 1049 |
+
"id": CIFColumn(
|
| 1050 |
+
data=CIFData(array=np.array(entity_ids), dtype=np.str_)
|
| 1051 |
+
),
|
| 1052 |
+
"type": CIFColumn(
|
| 1053 |
+
data=CIFData(array=np.array(entity_types), dtype=np.str_)
|
| 1054 |
+
),
|
| 1055 |
+
"pdbx_description": CIFColumn(
|
| 1056 |
+
data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
|
| 1057 |
+
),
|
| 1058 |
+
},
|
| 1059 |
+
)
|
| 1060 |
+
|
| 1061 |
+
# Create _entity_poly section
|
| 1062 |
+
poly_entity_ids = []
|
| 1063 |
+
poly_types = []
|
| 1064 |
+
poly_nstd_linkages = []
|
| 1065 |
+
poly_sequences = []
|
| 1066 |
+
|
| 1067 |
+
for entity_id in sorted(entity_sequences.keys()):
|
| 1068 |
+
poly_entity_ids.append(str(entity_id))
|
| 1069 |
+
poly_types.append("polypeptide(L)")
|
| 1070 |
+
poly_nstd_linkages.append("no")
|
| 1071 |
+
poly_sequences.append(entity_sequences[entity_id])
|
| 1072 |
+
|
| 1073 |
+
cif_file.block["entity_poly"] = CIFCategory(
|
| 1074 |
+
name="entity_poly",
|
| 1075 |
+
columns={
|
| 1076 |
+
"entity_id": CIFColumn(
|
| 1077 |
+
data=CIFData(array=np.array(poly_entity_ids), dtype=np.str_)
|
| 1078 |
+
),
|
| 1079 |
+
"type": CIFColumn(
|
| 1080 |
+
data=CIFData(array=np.array(poly_types), dtype=np.str_)
|
| 1081 |
+
),
|
| 1082 |
+
"nstd_linkage": CIFColumn(
|
| 1083 |
+
data=CIFData(array=np.array(poly_nstd_linkages), dtype=np.str_)
|
| 1084 |
+
),
|
| 1085 |
+
"pdbx_seq_one_letter_code": CIFColumn(
|
| 1086 |
+
data=CIFData(array=np.array(poly_sequences), dtype=np.str_)
|
| 1087 |
+
),
|
| 1088 |
+
},
|
| 1089 |
+
)
|
| 1090 |
+
|
| 1091 |
+
# Create _struct_asym section
|
| 1092 |
+
asym_ids = []
|
| 1093 |
+
asym_entity_ids = []
|
| 1094 |
+
asym_details = []
|
| 1095 |
+
|
| 1096 |
+
for chain in self.chain_iter():
|
| 1097 |
+
asym_ids.append(chain.chain_id)
|
| 1098 |
+
asym_entity_ids.append(str(chain_to_entity[chain.chain_id]))
|
| 1099 |
+
asym_details.append("")
|
| 1100 |
+
|
| 1101 |
+
cif_file.block["struct_asym"] = CIFCategory(
|
| 1102 |
+
name="struct_asym",
|
| 1103 |
+
columns={
|
| 1104 |
+
"id": CIFColumn(data=CIFData(array=np.array(asym_ids), dtype=np.str_)),
|
| 1105 |
+
"entity_id": CIFColumn(
|
| 1106 |
+
data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
|
| 1107 |
+
),
|
| 1108 |
+
"details": CIFColumn(
|
| 1109 |
+
data=CIFData(array=np.array(asym_details), dtype=np.str_)
|
| 1110 |
+
),
|
| 1111 |
+
},
|
| 1112 |
+
)
|
| 1113 |
+
|
| 1114 |
+
|
| 1115 |
+
def get_assembly_fast(
|
| 1116 |
+
mmcif: MmcifWrapper,
|
| 1117 |
+
assembly_id=None,
|
| 1118 |
+
model=None,
|
| 1119 |
+
data_block=None,
|
| 1120 |
+
altloc="first",
|
| 1121 |
+
use_author_fields=True,
|
| 1122 |
+
):
|
| 1123 |
+
pdbx_file = mmcif.raw
|
| 1124 |
+
if pdbx_file is None:
|
| 1125 |
+
raise InvalidFileError("No mmCIF data loaded")
|
| 1126 |
+
assembly_gen_category = pdbx_file.block["pdbx_struct_assembly_gen"]
|
| 1127 |
+
if assembly_gen_category is None:
|
| 1128 |
+
raise InvalidFileError("File has no 'pdbx_struct_assembly_gen' category")
|
| 1129 |
+
|
| 1130 |
+
struct_oper_category = pdbx_file.block["pdbx_struct_oper_list"]
|
| 1131 |
+
if struct_oper_category is None:
|
| 1132 |
+
raise InvalidFileError("File has no 'pdbx_struct_oper_list' category")
|
| 1133 |
+
|
| 1134 |
+
if assembly_id is None:
|
| 1135 |
+
assembly_id = assembly_gen_category["assembly_id"].data.array[0]
|
| 1136 |
+
elif assembly_id not in assembly_gen_category["assembly_id"].data.array:
|
| 1137 |
+
raise KeyError(f"File has no Assembly ID '{assembly_id}'")
|
| 1138 |
+
|
| 1139 |
+
### Calculate all possible transformations
|
| 1140 |
+
transformations = _get_transformations(struct_oper_category)
|
| 1141 |
+
|
| 1142 |
+
### Get structure according to additional parameters
|
| 1143 |
+
structure = get_structure(
|
| 1144 |
+
pdbx_file, model, data_block, altloc, ["label_asym_id"], use_author_fields
|
| 1145 |
+
)[0] # type: ignore
|
| 1146 |
+
# TODO(@zeming) This line will remove all non-protein structural elements,
|
| 1147 |
+
# we should remove this when we want to parse these too.
|
| 1148 |
+
structure: bs.AtomArray = structure[
|
| 1149 |
+
bs.filter_amino_acids(structure) & ~structure.hetero # type: ignore
|
| 1150 |
+
]
|
| 1151 |
+
if len(structure) == 0:
|
| 1152 |
+
raise NoProteinError
|
| 1153 |
+
unique_asym_ids = np.unique(structure.label_asym_id) # type: ignore
|
| 1154 |
+
asym2chain = {}
|
| 1155 |
+
asym2auth = {}
|
| 1156 |
+
for asym_id in unique_asym_ids:
|
| 1157 |
+
sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] # type: ignore
|
| 1158 |
+
chain_id: str = sub_structure[0].chain_id # type: ignore
|
| 1159 |
+
(
|
| 1160 |
+
sequence,
|
| 1161 |
+
atom_positions,
|
| 1162 |
+
atom_mask,
|
| 1163 |
+
residue_index,
|
| 1164 |
+
insertion_code,
|
| 1165 |
+
confidence,
|
| 1166 |
+
entity_id,
|
| 1167 |
+
) = chain_to_ndarray(sub_structure, mmcif, chain_id, False)
|
| 1168 |
+
|
| 1169 |
+
asym2chain[asym_id] = ProteinChain(
|
| 1170 |
+
id=mmcif.id or "unknown",
|
| 1171 |
+
sequence=sequence,
|
| 1172 |
+
chain_id=chain_id,
|
| 1173 |
+
entity_id=entity_id,
|
| 1174 |
+
atom37_positions=atom_positions,
|
| 1175 |
+
atom37_mask=atom_mask,
|
| 1176 |
+
residue_index=residue_index,
|
| 1177 |
+
insertion_code=insertion_code,
|
| 1178 |
+
confidence=confidence,
|
| 1179 |
+
mmcif=None,
|
| 1180 |
+
)
|
| 1181 |
+
asym2auth[asym_id] = chain_id
|
| 1182 |
+
|
| 1183 |
+
### Get transformations and apply them to the affected asym IDs
|
| 1184 |
+
assembly = []
|
| 1185 |
+
assembly_id_dict: dict[str, list[str]] = {}
|
| 1186 |
+
|
| 1187 |
+
# Process the target assembly ID
|
| 1188 |
+
for aid, op_expr, asym_id_expr in zip(
|
| 1189 |
+
assembly_gen_category["assembly_id"].data.array,
|
| 1190 |
+
assembly_gen_category["oper_expression"].data.array,
|
| 1191 |
+
assembly_gen_category["asym_id_list"].data.array,
|
| 1192 |
+
):
|
| 1193 |
+
if aid == assembly_id:
|
| 1194 |
+
# Parse operations and asym IDs for this specific entry
|
| 1195 |
+
operations = _parse_operation_expression(op_expr)
|
| 1196 |
+
asym_ids = asym_id_expr.split(",")
|
| 1197 |
+
|
| 1198 |
+
# Filter affected asym IDs to only protein chains, preserving order
|
| 1199 |
+
sub_structures = [
|
| 1200 |
+
asym2chain[asym_id] for asym_id in asym_ids if asym_id in asym2chain
|
| 1201 |
+
]
|
| 1202 |
+
|
| 1203 |
+
# Apply transformations
|
| 1204 |
+
sub_assembly = _apply_transformations_fast(
|
| 1205 |
+
sub_structures, transformations, operations
|
| 1206 |
+
)
|
| 1207 |
+
assembly.extend(sub_assembly)
|
| 1208 |
+
|
| 1209 |
+
# Build assembly_id_dict for this entry
|
| 1210 |
+
assembly_id_dict[aid] = assembly_id_dict.get(aid, []) + [
|
| 1211 |
+
asym2auth[id_] for id_ in asym_ids if id_ in asym2auth
|
| 1212 |
+
]
|
| 1213 |
+
|
| 1214 |
+
if len(assembly) == 0:
|
| 1215 |
+
raise NoProteinError
|
| 1216 |
+
return ProteinComplex.from_chains(assembly, mmcif, assembly_id_dict)
|
| 1217 |
+
|
| 1218 |
+
|
| 1219 |
+
def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
|
| 1220 |
+
if "|" not in chain.sequence:
|
| 1221 |
+
return ProteinComplex.from_chains([chain])
|
| 1222 |
+
chain_breaks = np.array(list(chain.sequence)) == "|"
|
| 1223 |
+
chain_break_inds = np.where(chain_breaks)[0]
|
| 1224 |
+
chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
|
| 1225 |
+
chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
|
| 1226 |
+
complex_chains = []
|
| 1227 |
+
for start, end in chain_break_inds:
|
| 1228 |
+
if start != 0:
|
| 1229 |
+
start += 1
|
| 1230 |
+
complex_chains.append(chain[start:end])
|
| 1231 |
+
complex_chains = [
|
| 1232 |
+
ProteinChain.from_atom37(
|
| 1233 |
+
chain.atom37_positions,
|
| 1234 |
+
sequence=chain.sequence,
|
| 1235 |
+
chain_id=SINGLE_LETTER_CHAIN_IDS[i],
|
| 1236 |
+
entity_id=i,
|
| 1237 |
+
)
|
| 1238 |
+
for i, chain in enumerate(complex_chains)
|
| 1239 |
+
]
|
| 1240 |
+
return ProteinComplex.from_chains(complex_chains)
|
| 1241 |
+
|
esmfold2_protein_structure.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Tuple, TypeVar
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.amp import autocast # type: ignore
|
| 10 |
+
|
| 11 |
+
from . import esmfold2_residue_constants
|
| 12 |
+
from .esmfold2_misc import unbinpack
|
| 13 |
+
from .esmfold2_affine3d import Affine3D
|
| 14 |
+
|
| 15 |
+
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def index_by_atom_name(
|
| 19 |
+
atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
|
| 20 |
+
) -> ArrayOrTensor:
|
| 21 |
+
squeeze = False
|
| 22 |
+
if isinstance(atom_names, str):
|
| 23 |
+
atom_names = [atom_names]
|
| 24 |
+
squeeze = True
|
| 25 |
+
indices = [residue_constants.atom_order[atom_name] for atom_name in atom_names]
|
| 26 |
+
dim = dim % atom37.ndim
|
| 27 |
+
index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
|
| 28 |
+
result = atom37[index] # type: ignore
|
| 29 |
+
if squeeze:
|
| 30 |
+
result = result.squeeze(dim)
|
| 31 |
+
return result
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def infer_cbeta_from_atom37(
|
| 35 |
+
atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Inspired by a util in trDesign:
|
| 39 |
+
https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92
|
| 40 |
+
|
| 41 |
+
input: atom37, (L)ength, (A)ngle, and (D)ihedral
|
| 42 |
+
output: 4th coord
|
| 43 |
+
"""
|
| 44 |
+
N = index_by_atom_name(atom37, "N", dim=-2)
|
| 45 |
+
CA = index_by_atom_name(atom37, "CA", dim=-2)
|
| 46 |
+
C = index_by_atom_name(atom37, "C", dim=-2)
|
| 47 |
+
|
| 48 |
+
if isinstance(atom37, np.ndarray):
|
| 49 |
+
|
| 50 |
+
def normalize(x: ArrayOrTensor):
|
| 51 |
+
return x / np.linalg.norm(x, axis=-1, keepdims=True)
|
| 52 |
+
|
| 53 |
+
cross = np.cross
|
| 54 |
+
else:
|
| 55 |
+
normalize = F.normalize # type: ignore
|
| 56 |
+
cross = torch.cross
|
| 57 |
+
|
| 58 |
+
with np.errstate(invalid="ignore"): # inf - inf = nan is ok here
|
| 59 |
+
vec_nca = N - CA
|
| 60 |
+
vec_nc = N - C
|
| 61 |
+
nca = normalize(vec_nca)
|
| 62 |
+
n = normalize(cross(vec_nc, nca)) # type: ignore
|
| 63 |
+
m = [nca, cross(n, nca), n]
|
| 64 |
+
d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
|
| 65 |
+
return CA + sum([m * d for m, d in zip(m, d)])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@torch.no_grad()
|
| 69 |
+
@autocast("cuda", enabled=False)
|
| 70 |
+
def compute_alignment_tensors(
|
| 71 |
+
mobile: torch.Tensor,
|
| 72 |
+
target: torch.Tensor,
|
| 73 |
+
atom_exists_mask: torch.Tensor | None = None,
|
| 74 |
+
sequence_id: torch.Tensor | None = None,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
Align two batches of structures with support for masking invalid atoms using PyTorch.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 81 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 82 |
+
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
|
| 83 |
+
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
- centered_mobile (torch.Tensor): Batch of coordinates of structure centered mobile (B, N, 3)
|
| 87 |
+
- centroid_mobile (torch.Tensor): Batch of coordinates of mobile centeroid (B, 3)
|
| 88 |
+
- centered_target (torch.Tensor): Batch of coordinates of structure centered target (B, N, 3)
|
| 89 |
+
- centroid_target (torch.Tensor): Batch of coordinates of target centeroid (B, 3)
|
| 90 |
+
- rotation_matrix (torch.Tensor): Batch of coordinates of rotation matrix (B, 3, 3)
|
| 91 |
+
- num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,)
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
# Ensure both batches have the same number of structures, atoms, and dimensions
|
| 95 |
+
if sequence_id is not None:
|
| 96 |
+
mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan)
|
| 97 |
+
target = unbinpack(target, sequence_id, pad_value=torch.nan)
|
| 98 |
+
if atom_exists_mask is not None:
|
| 99 |
+
atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0)
|
| 100 |
+
else:
|
| 101 |
+
atom_exists_mask = torch.isfinite(target).all(-1)
|
| 102 |
+
|
| 103 |
+
assert mobile.shape == target.shape, "Batch structure shapes do not match!"
|
| 104 |
+
|
| 105 |
+
# Number of structures in the batch
|
| 106 |
+
batch_size = mobile.shape[0]
|
| 107 |
+
|
| 108 |
+
# if [B, Nres, Natom, 3], resize
|
| 109 |
+
if mobile.dim() == 4:
|
| 110 |
+
mobile = mobile.view(batch_size, -1, 3)
|
| 111 |
+
if target.dim() == 4:
|
| 112 |
+
target = target.view(batch_size, -1, 3)
|
| 113 |
+
if atom_exists_mask is not None and atom_exists_mask.dim() == 3:
|
| 114 |
+
atom_exists_mask = atom_exists_mask.view(batch_size, -1)
|
| 115 |
+
|
| 116 |
+
# Number of atoms
|
| 117 |
+
num_atoms = mobile.shape[1]
|
| 118 |
+
|
| 119 |
+
# Apply masks if provided
|
| 120 |
+
if atom_exists_mask is not None:
|
| 121 |
+
mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
|
| 122 |
+
target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
|
| 123 |
+
else:
|
| 124 |
+
atom_exists_mask = torch.ones(
|
| 125 |
+
batch_size, num_atoms, dtype=torch.bool, device=mobile.device
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True)
|
| 129 |
+
# Compute centroids for each batch
|
| 130 |
+
centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
|
| 131 |
+
centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
|
| 132 |
+
|
| 133 |
+
# Handle potential division by zero if all atoms are invalid in a structure
|
| 134 |
+
centroid_mobile[num_valid_atoms == 0] = 0
|
| 135 |
+
centroid_target[num_valid_atoms == 0] = 0
|
| 136 |
+
|
| 137 |
+
# Center structures by subtracting centroids
|
| 138 |
+
centered_mobile = mobile - centroid_mobile
|
| 139 |
+
centered_target = target - centroid_target
|
| 140 |
+
|
| 141 |
+
centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
|
| 142 |
+
centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
|
| 143 |
+
|
| 144 |
+
# Compute covariance matrix for each batch
|
| 145 |
+
covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target)
|
| 146 |
+
|
| 147 |
+
# Singular Value Decomposition for each batch
|
| 148 |
+
u, _, v = torch.svd(covariance_matrix)
|
| 149 |
+
|
| 150 |
+
# Calculate rotation matrices for each batch
|
| 151 |
+
rotation_matrix = torch.matmul(u, v.transpose(1, 2))
|
| 152 |
+
|
| 153 |
+
return (
|
| 154 |
+
centered_mobile,
|
| 155 |
+
centroid_mobile,
|
| 156 |
+
centered_target,
|
| 157 |
+
centroid_target,
|
| 158 |
+
rotation_matrix,
|
| 159 |
+
num_valid_atoms,
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
|
| 164 |
+
@autocast("cuda", enabled=False)
|
| 165 |
+
def compute_rmsd_no_alignment(
|
| 166 |
+
aligned: torch.Tensor,
|
| 167 |
+
target: torch.Tensor,
|
| 168 |
+
num_valid_atoms: torch.Tensor,
|
| 169 |
+
reduction: str = "batch",
|
| 170 |
+
) -> torch.Tensor:
|
| 171 |
+
"""
|
| 172 |
+
Compute RMSD between two batches of structures without alignment.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 176 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 177 |
+
- num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,)
|
| 178 |
+
- reduction (str): One of "batch", "per_sample", "per_residue".
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
|
| 182 |
+
If reduction == "batch":
|
| 183 |
+
(torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
|
| 184 |
+
If reduction == "per_sample":
|
| 185 |
+
(torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch
|
| 186 |
+
If reduction == "per_residue":
|
| 187 |
+
(torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for residue in the batch
|
| 188 |
+
"""
|
| 189 |
+
if reduction not in ("per_residue", "per_sample", "batch"):
|
| 190 |
+
raise ValueError("Unrecognized reduction: '{reduction}'")
|
| 191 |
+
# Compute RMSD for each batch
|
| 192 |
+
diff = aligned - target
|
| 193 |
+
if reduction == "per_residue":
|
| 194 |
+
mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
|
| 195 |
+
else:
|
| 196 |
+
mean_squared_error = diff.square().sum(dim=(1, 2)) / (
|
| 197 |
+
num_valid_atoms.squeeze(-1)
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
rmsd = torch.sqrt(mean_squared_error)
|
| 201 |
+
if reduction in ("per_sample", "per_residue"):
|
| 202 |
+
return rmsd
|
| 203 |
+
elif reduction == "batch":
|
| 204 |
+
avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / (
|
| 205 |
+
(num_valid_atoms > 0).sum() + 1e-8
|
| 206 |
+
)
|
| 207 |
+
return avg_rmsd
|
| 208 |
+
else:
|
| 209 |
+
raise ValueError(reduction)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@torch.no_grad()
|
| 213 |
+
@autocast("cuda", enabled=False)
|
| 214 |
+
def compute_affine_and_rmsd(
|
| 215 |
+
mobile: torch.Tensor,
|
| 216 |
+
target: torch.Tensor,
|
| 217 |
+
atom_exists_mask: torch.Tensor | None = None,
|
| 218 |
+
sequence_id: torch.Tensor | None = None,
|
| 219 |
+
) -> Tuple[Affine3D, torch.Tensor]:
|
| 220 |
+
"""
|
| 221 |
+
Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 225 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 226 |
+
- atom_exists_mask (torch.Tensor, optional): Mask for Whether an atom exists of shape (B, N)
|
| 227 |
+
- sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
- affine (Affine3D): Transformation between mobile and target structure
|
| 231 |
+
- avg_rmsd (torch.Tensor): Average Root Mean Square Deviation between the structures for each batch
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
(
|
| 235 |
+
centered_mobile,
|
| 236 |
+
centroid_mobile,
|
| 237 |
+
centered_target,
|
| 238 |
+
centroid_target,
|
| 239 |
+
rotation_matrix,
|
| 240 |
+
num_valid_atoms,
|
| 241 |
+
) = compute_alignment_tensors(
|
| 242 |
+
mobile=mobile,
|
| 243 |
+
target=target,
|
| 244 |
+
atom_exists_mask=atom_exists_mask,
|
| 245 |
+
sequence_id=sequence_id,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Apply rotation to mobile centroid
|
| 249 |
+
translation = torch.matmul(-centroid_mobile, rotation_matrix) + centroid_target
|
| 250 |
+
affine = Affine3D.from_tensor_pair(
|
| 251 |
+
translation, rotation_matrix.unsqueeze(dim=-3).transpose(-2, -1)
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# Apply transformation to centered structure to compute rmsd
|
| 255 |
+
rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
|
| 256 |
+
avg_rmsd = compute_rmsd_no_alignment(
|
| 257 |
+
rotated_mobile, centered_target, num_valid_atoms, reduction="batch"
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
return affine, avg_rmsd
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def compute_gdt_ts_no_alignment(
|
| 264 |
+
aligned: torch.Tensor,
|
| 265 |
+
target: torch.Tensor,
|
| 266 |
+
atom_exists_mask: torch.Tensor,
|
| 267 |
+
reduction: str = "batch",
|
| 268 |
+
) -> torch.Tensor:
|
| 269 |
+
"""
|
| 270 |
+
Compute GDT_TS between two batches of structures without alignment.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
- mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
|
| 274 |
+
- target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
|
| 275 |
+
- atom_exists_mask (torch.Tensor): Mask for Whether an atom exists of shape (B, N). noo
|
| 276 |
+
- reduction (str): One of "batch", "per_sample".
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
If reduction == "batch":
|
| 280 |
+
(torch.Tensor): 0-dim, GDT_TS between the structures for each batch
|
| 281 |
+
If reduction == "per_sample":
|
| 282 |
+
(torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
|
| 283 |
+
"""
|
| 284 |
+
if reduction not in ("per_sample", "batch"):
|
| 285 |
+
raise ValueError("Unrecognized reduction: '{reduction}'")
|
| 286 |
+
|
| 287 |
+
if atom_exists_mask is None:
|
| 288 |
+
atom_exists_mask = torch.isfinite(target).all(dim=-1)
|
| 289 |
+
|
| 290 |
+
deviation = torch.linalg.vector_norm(aligned - target, dim=-1)
|
| 291 |
+
num_valid_atoms = atom_exists_mask.sum(dim=-1)
|
| 292 |
+
|
| 293 |
+
# Compute GDT_TS
|
| 294 |
+
score = (
|
| 295 |
+
((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
| 296 |
+
+ ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
| 297 |
+
+ ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
| 298 |
+
+ ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
|
| 299 |
+
) * 0.25
|
| 300 |
+
|
| 301 |
+
if reduction == "batch":
|
| 302 |
+
return score.mean()
|
| 303 |
+
elif reduction == "per_sample":
|
| 304 |
+
return score
|
| 305 |
+
else:
|
| 306 |
+
raise ValueError("Unrecognized reduction: '{reduction}'")
|
| 307 |
+
|
esmfold2_residue_constants.py
ADDED
|
@@ -0,0 +1,1224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 EvolutionaryScale
|
| 2 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 3 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""Constants used in AlphaFold."""
|
| 18 |
+
|
| 19 |
+
import collections
|
| 20 |
+
import functools
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import List, Mapping, Tuple
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
# import tree
|
| 27 |
+
|
| 28 |
+
# Internal import (35fd).
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Distance from one CA to next CA [trans configuration: omega = 180].
|
| 32 |
+
ca_ca = 3.80209737096
|
| 33 |
+
|
| 34 |
+
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
|
| 35 |
+
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
|
| 36 |
+
# chi angles so their chi angle lists are empty.
|
| 37 |
+
chi_angles_atoms = {
|
| 38 |
+
"ALA": [],
|
| 39 |
+
# Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
|
| 40 |
+
"ARG": [
|
| 41 |
+
["N", "CA", "CB", "CG"],
|
| 42 |
+
["CA", "CB", "CG", "CD"],
|
| 43 |
+
["CB", "CG", "CD", "NE"],
|
| 44 |
+
["CG", "CD", "NE", "CZ"],
|
| 45 |
+
],
|
| 46 |
+
"ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
|
| 47 |
+
"ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
|
| 48 |
+
"CYS": [["N", "CA", "CB", "SG"]],
|
| 49 |
+
"GLN": [
|
| 50 |
+
["N", "CA", "CB", "CG"],
|
| 51 |
+
["CA", "CB", "CG", "CD"],
|
| 52 |
+
["CB", "CG", "CD", "OE1"],
|
| 53 |
+
],
|
| 54 |
+
"GLU": [
|
| 55 |
+
["N", "CA", "CB", "CG"],
|
| 56 |
+
["CA", "CB", "CG", "CD"],
|
| 57 |
+
["CB", "CG", "CD", "OE1"],
|
| 58 |
+
],
|
| 59 |
+
"GLY": [],
|
| 60 |
+
"HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
|
| 61 |
+
"ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
|
| 62 |
+
"LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 63 |
+
"LYS": [
|
| 64 |
+
["N", "CA", "CB", "CG"],
|
| 65 |
+
["CA", "CB", "CG", "CD"],
|
| 66 |
+
["CB", "CG", "CD", "CE"],
|
| 67 |
+
["CG", "CD", "CE", "NZ"],
|
| 68 |
+
],
|
| 69 |
+
"MET": [
|
| 70 |
+
["N", "CA", "CB", "CG"],
|
| 71 |
+
["CA", "CB", "CG", "SD"],
|
| 72 |
+
["CB", "CG", "SD", "CE"],
|
| 73 |
+
],
|
| 74 |
+
"PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 75 |
+
"PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
|
| 76 |
+
"SER": [["N", "CA", "CB", "OG"]],
|
| 77 |
+
"THR": [["N", "CA", "CB", "OG1"]],
|
| 78 |
+
"TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 79 |
+
"TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
|
| 80 |
+
"VAL": [["N", "CA", "CB", "CG1"]],
|
| 81 |
+
"UNK": [],
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
# If chi angles given in fixed-length array, this matrix determines how to mask
|
| 85 |
+
# them for each AA type. The order is as per restype_order (see below).
|
| 86 |
+
chi_angles_mask = [
|
| 87 |
+
[0.0, 0.0, 0.0, 0.0], # ALA
|
| 88 |
+
[1.0, 1.0, 1.0, 1.0], # ARG
|
| 89 |
+
[1.0, 1.0, 0.0, 0.0], # ASN
|
| 90 |
+
[1.0, 1.0, 0.0, 0.0], # ASP
|
| 91 |
+
[1.0, 0.0, 0.0, 0.0], # CYS
|
| 92 |
+
[1.0, 1.0, 1.0, 0.0], # GLN
|
| 93 |
+
[1.0, 1.0, 1.0, 0.0], # GLU
|
| 94 |
+
[0.0, 0.0, 0.0, 0.0], # GLY
|
| 95 |
+
[1.0, 1.0, 0.0, 0.0], # HIS
|
| 96 |
+
[1.0, 1.0, 0.0, 0.0], # ILE
|
| 97 |
+
[1.0, 1.0, 0.0, 0.0], # LEU
|
| 98 |
+
[1.0, 1.0, 1.0, 1.0], # LYS
|
| 99 |
+
[1.0, 1.0, 1.0, 0.0], # MET
|
| 100 |
+
[1.0, 1.0, 0.0, 0.0], # PHE
|
| 101 |
+
[1.0, 1.0, 0.0, 0.0], # PRO
|
| 102 |
+
[1.0, 0.0, 0.0, 0.0], # SER
|
| 103 |
+
[1.0, 0.0, 0.0, 0.0], # THR
|
| 104 |
+
[1.0, 1.0, 0.0, 0.0], # TRP
|
| 105 |
+
[1.0, 1.0, 0.0, 0.0], # TYR
|
| 106 |
+
[1.0, 0.0, 0.0, 0.0], # VAL
|
| 107 |
+
[0.0, 0.0, 0.0, 0.0], # UNK
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
# The following chi angles are pi periodic: they can be rotated by a multiple
|
| 111 |
+
# of pi without affecting the structure.
|
| 112 |
+
chi_pi_periodic = [
|
| 113 |
+
[0.0, 0.0, 0.0, 0.0], # ALA
|
| 114 |
+
[0.0, 0.0, 0.0, 0.0], # ARG
|
| 115 |
+
[0.0, 0.0, 0.0, 0.0], # ASN
|
| 116 |
+
[0.0, 1.0, 0.0, 0.0], # ASP
|
| 117 |
+
[0.0, 0.0, 0.0, 0.0], # CYS
|
| 118 |
+
[0.0, 0.0, 0.0, 0.0], # GLN
|
| 119 |
+
[0.0, 0.0, 1.0, 0.0], # GLU
|
| 120 |
+
[0.0, 0.0, 0.0, 0.0], # GLY
|
| 121 |
+
[0.0, 0.0, 0.0, 0.0], # HIS
|
| 122 |
+
[0.0, 0.0, 0.0, 0.0], # ILE
|
| 123 |
+
[0.0, 0.0, 0.0, 0.0], # LEU
|
| 124 |
+
[0.0, 0.0, 0.0, 0.0], # LYS
|
| 125 |
+
[0.0, 0.0, 0.0, 0.0], # MET
|
| 126 |
+
[0.0, 1.0, 0.0, 0.0], # PHE
|
| 127 |
+
[0.0, 0.0, 0.0, 0.0], # PRO
|
| 128 |
+
[0.0, 0.0, 0.0, 0.0], # SER
|
| 129 |
+
[0.0, 0.0, 0.0, 0.0], # THR
|
| 130 |
+
[0.0, 0.0, 0.0, 0.0], # TRP
|
| 131 |
+
[0.0, 1.0, 0.0, 0.0], # TYR
|
| 132 |
+
[0.0, 0.0, 0.0, 0.0], # VAL
|
| 133 |
+
[0.0, 0.0, 0.0, 0.0], # UNK
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
|
| 137 |
+
# psi and chi angles:
|
| 138 |
+
# 0: 'backbone group',
|
| 139 |
+
# 1: 'pre-omega-group', (empty)
|
| 140 |
+
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
|
| 141 |
+
# 3: 'psi-group',
|
| 142 |
+
# 4,5,6,7: 'chi1,2,3,4-group'
|
| 143 |
+
# The atom positions are relative to the axis-end-atom of the corresponding
|
| 144 |
+
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
|
| 145 |
+
# is defined such that the dihedral-angle-definiting atom (the last entry in
|
| 146 |
+
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
|
| 147 |
+
# format: [atomname, group_idx, rel_position]
|
| 148 |
+
rigid_group_atom_positions = {
|
| 149 |
+
"ALA": [
|
| 150 |
+
["N", 0, (-0.525, 1.363, 0.000)],
|
| 151 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 152 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 153 |
+
["CB", 0, (-0.529, -0.774, -1.205)],
|
| 154 |
+
["O", 3, (0.627, 1.062, 0.000)],
|
| 155 |
+
],
|
| 156 |
+
"ARG": [
|
| 157 |
+
["N", 0, (-0.524, 1.362, -0.000)],
|
| 158 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 159 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 160 |
+
["CB", 0, (-0.524, -0.778, -1.209)],
|
| 161 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 162 |
+
["CG", 4, (0.616, 1.390, -0.000)],
|
| 163 |
+
["CD", 5, (0.564, 1.414, 0.000)],
|
| 164 |
+
["NE", 6, (0.539, 1.357, -0.000)],
|
| 165 |
+
["NH1", 7, (0.206, 2.301, 0.000)],
|
| 166 |
+
["NH2", 7, (2.078, 0.978, -0.000)],
|
| 167 |
+
["CZ", 7, (0.758, 1.093, -0.000)],
|
| 168 |
+
],
|
| 169 |
+
"ASN": [
|
| 170 |
+
["N", 0, (-0.536, 1.357, 0.000)],
|
| 171 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 172 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 173 |
+
["CB", 0, (-0.531, -0.787, -1.200)],
|
| 174 |
+
["O", 3, (0.625, 1.062, 0.000)],
|
| 175 |
+
["CG", 4, (0.584, 1.399, 0.000)],
|
| 176 |
+
["ND2", 5, (0.593, -1.188, 0.001)],
|
| 177 |
+
["OD1", 5, (0.633, 1.059, 0.000)],
|
| 178 |
+
],
|
| 179 |
+
"ASP": [
|
| 180 |
+
["N", 0, (-0.525, 1.362, -0.000)],
|
| 181 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 182 |
+
["C", 0, (1.527, 0.000, -0.000)],
|
| 183 |
+
["CB", 0, (-0.526, -0.778, -1.208)],
|
| 184 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 185 |
+
["CG", 4, (0.593, 1.398, -0.000)],
|
| 186 |
+
["OD1", 5, (0.610, 1.091, 0.000)],
|
| 187 |
+
["OD2", 5, (0.592, -1.101, -0.003)],
|
| 188 |
+
],
|
| 189 |
+
"CYS": [
|
| 190 |
+
["N", 0, (-0.522, 1.362, -0.000)],
|
| 191 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 192 |
+
["C", 0, (1.524, 0.000, 0.000)],
|
| 193 |
+
["CB", 0, (-0.519, -0.773, -1.212)],
|
| 194 |
+
["O", 3, (0.625, 1.062, -0.000)],
|
| 195 |
+
["SG", 4, (0.728, 1.653, 0.000)],
|
| 196 |
+
],
|
| 197 |
+
"GLN": [
|
| 198 |
+
["N", 0, (-0.526, 1.361, -0.000)],
|
| 199 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 200 |
+
["C", 0, (1.526, 0.000, 0.000)],
|
| 201 |
+
["CB", 0, (-0.525, -0.779, -1.207)],
|
| 202 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 203 |
+
["CG", 4, (0.615, 1.393, 0.000)],
|
| 204 |
+
["CD", 5, (0.587, 1.399, -0.000)],
|
| 205 |
+
["NE2", 6, (0.593, -1.189, -0.001)],
|
| 206 |
+
["OE1", 6, (0.634, 1.060, 0.000)],
|
| 207 |
+
],
|
| 208 |
+
"GLU": [
|
| 209 |
+
["N", 0, (-0.528, 1.361, 0.000)],
|
| 210 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 211 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 212 |
+
["CB", 0, (-0.526, -0.781, -1.207)],
|
| 213 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 214 |
+
["CG", 4, (0.615, 1.392, 0.000)],
|
| 215 |
+
["CD", 5, (0.600, 1.397, 0.000)],
|
| 216 |
+
["OE1", 6, (0.607, 1.095, -0.000)],
|
| 217 |
+
["OE2", 6, (0.589, -1.104, -0.001)],
|
| 218 |
+
],
|
| 219 |
+
"GLY": [
|
| 220 |
+
["N", 0, (-0.572, 1.337, 0.000)],
|
| 221 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 222 |
+
["C", 0, (1.517, -0.000, -0.000)],
|
| 223 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 224 |
+
],
|
| 225 |
+
"HIS": [
|
| 226 |
+
["N", 0, (-0.527, 1.360, 0.000)],
|
| 227 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 228 |
+
["C", 0, (1.525, 0.000, 0.000)],
|
| 229 |
+
["CB", 0, (-0.525, -0.778, -1.208)],
|
| 230 |
+
["O", 3, (0.625, 1.063, 0.000)],
|
| 231 |
+
["CG", 4, (0.600, 1.370, -0.000)],
|
| 232 |
+
["CD2", 5, (0.889, -1.021, 0.003)],
|
| 233 |
+
["ND1", 5, (0.744, 1.160, -0.000)],
|
| 234 |
+
["CE1", 5, (2.030, 0.851, 0.002)],
|
| 235 |
+
["NE2", 5, (2.145, -0.466, 0.004)],
|
| 236 |
+
],
|
| 237 |
+
"ILE": [
|
| 238 |
+
["N", 0, (-0.493, 1.373, -0.000)],
|
| 239 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 240 |
+
["C", 0, (1.527, -0.000, -0.000)],
|
| 241 |
+
["CB", 0, (-0.536, -0.793, -1.213)],
|
| 242 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 243 |
+
["CG1", 4, (0.534, 1.437, -0.000)],
|
| 244 |
+
["CG2", 4, (0.540, -0.785, -1.199)],
|
| 245 |
+
["CD1", 5, (0.619, 1.391, 0.000)],
|
| 246 |
+
],
|
| 247 |
+
"LEU": [
|
| 248 |
+
["N", 0, (-0.520, 1.363, 0.000)],
|
| 249 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 250 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 251 |
+
["CB", 0, (-0.522, -0.773, -1.214)],
|
| 252 |
+
["O", 3, (0.625, 1.063, -0.000)],
|
| 253 |
+
["CG", 4, (0.678, 1.371, 0.000)],
|
| 254 |
+
["CD1", 5, (0.530, 1.430, -0.000)],
|
| 255 |
+
["CD2", 5, (0.535, -0.774, 1.200)],
|
| 256 |
+
],
|
| 257 |
+
"LYS": [
|
| 258 |
+
["N", 0, (-0.526, 1.362, -0.000)],
|
| 259 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 260 |
+
["C", 0, (1.526, 0.000, 0.000)],
|
| 261 |
+
["CB", 0, (-0.524, -0.778, -1.208)],
|
| 262 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 263 |
+
["CG", 4, (0.619, 1.390, 0.000)],
|
| 264 |
+
["CD", 5, (0.559, 1.417, 0.000)],
|
| 265 |
+
["CE", 6, (0.560, 1.416, 0.000)],
|
| 266 |
+
["NZ", 7, (0.554, 1.387, 0.000)],
|
| 267 |
+
],
|
| 268 |
+
"MET": [
|
| 269 |
+
["N", 0, (-0.521, 1.364, -0.000)],
|
| 270 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 271 |
+
["C", 0, (1.525, 0.000, 0.000)],
|
| 272 |
+
["CB", 0, (-0.523, -0.776, -1.210)],
|
| 273 |
+
["O", 3, (0.625, 1.062, -0.000)],
|
| 274 |
+
["CG", 4, (0.613, 1.391, -0.000)],
|
| 275 |
+
["SD", 5, (0.703, 1.695, 0.000)],
|
| 276 |
+
["CE", 6, (0.320, 1.786, -0.000)],
|
| 277 |
+
],
|
| 278 |
+
"PHE": [
|
| 279 |
+
["N", 0, (-0.518, 1.363, 0.000)],
|
| 280 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 281 |
+
["C", 0, (1.524, 0.000, -0.000)],
|
| 282 |
+
["CB", 0, (-0.525, -0.776, -1.212)],
|
| 283 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 284 |
+
["CG", 4, (0.607, 1.377, 0.000)],
|
| 285 |
+
["CD1", 5, (0.709, 1.195, -0.000)],
|
| 286 |
+
["CD2", 5, (0.706, -1.196, 0.000)],
|
| 287 |
+
["CE1", 5, (2.102, 1.198, -0.000)],
|
| 288 |
+
["CE2", 5, (2.098, -1.201, -0.000)],
|
| 289 |
+
["CZ", 5, (2.794, -0.003, -0.001)],
|
| 290 |
+
],
|
| 291 |
+
"PRO": [
|
| 292 |
+
["N", 0, (-0.566, 1.351, -0.000)],
|
| 293 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 294 |
+
["C", 0, (1.527, -0.000, 0.000)],
|
| 295 |
+
["CB", 0, (-0.546, -0.611, -1.293)],
|
| 296 |
+
["O", 3, (0.621, 1.066, 0.000)],
|
| 297 |
+
["CG", 4, (0.382, 1.445, 0.0)],
|
| 298 |
+
# ['CD', 5, (0.427, 1.440, 0.0)],
|
| 299 |
+
["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
|
| 300 |
+
],
|
| 301 |
+
"SER": [
|
| 302 |
+
["N", 0, (-0.529, 1.360, -0.000)],
|
| 303 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 304 |
+
["C", 0, (1.525, -0.000, -0.000)],
|
| 305 |
+
["CB", 0, (-0.518, -0.777, -1.211)],
|
| 306 |
+
["O", 3, (0.626, 1.062, -0.000)],
|
| 307 |
+
["OG", 4, (0.503, 1.325, 0.000)],
|
| 308 |
+
],
|
| 309 |
+
"THR": [
|
| 310 |
+
["N", 0, (-0.517, 1.364, 0.000)],
|
| 311 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 312 |
+
["C", 0, (1.526, 0.000, -0.000)],
|
| 313 |
+
["CB", 0, (-0.516, -0.793, -1.215)],
|
| 314 |
+
["O", 3, (0.626, 1.062, 0.000)],
|
| 315 |
+
["CG2", 4, (0.550, -0.718, -1.228)],
|
| 316 |
+
["OG1", 4, (0.472, 1.353, 0.000)],
|
| 317 |
+
],
|
| 318 |
+
"TRP": [
|
| 319 |
+
["N", 0, (-0.521, 1.363, 0.000)],
|
| 320 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 321 |
+
["C", 0, (1.525, -0.000, 0.000)],
|
| 322 |
+
["CB", 0, (-0.523, -0.776, -1.212)],
|
| 323 |
+
["O", 3, (0.627, 1.062, 0.000)],
|
| 324 |
+
["CG", 4, (0.609, 1.370, -0.000)],
|
| 325 |
+
["CD1", 5, (0.824, 1.091, 0.000)],
|
| 326 |
+
["CD2", 5, (0.854, -1.148, -0.005)],
|
| 327 |
+
["CE2", 5, (2.186, -0.678, -0.007)],
|
| 328 |
+
["CE3", 5, (0.622, -2.530, -0.007)],
|
| 329 |
+
["NE1", 5, (2.140, 0.690, -0.004)],
|
| 330 |
+
["CH2", 5, (3.028, -2.890, -0.013)],
|
| 331 |
+
["CZ2", 5, (3.283, -1.543, -0.011)],
|
| 332 |
+
["CZ3", 5, (1.715, -3.389, -0.011)],
|
| 333 |
+
],
|
| 334 |
+
"TYR": [
|
| 335 |
+
["N", 0, (-0.522, 1.362, 0.000)],
|
| 336 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 337 |
+
["C", 0, (1.524, -0.000, -0.000)],
|
| 338 |
+
["CB", 0, (-0.522, -0.776, -1.213)],
|
| 339 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 340 |
+
["CG", 4, (0.607, 1.382, -0.000)],
|
| 341 |
+
["CD1", 5, (0.716, 1.195, -0.000)],
|
| 342 |
+
["CD2", 5, (0.713, -1.194, -0.001)],
|
| 343 |
+
["CE1", 5, (2.107, 1.200, -0.002)],
|
| 344 |
+
["CE2", 5, (2.104, -1.201, -0.003)],
|
| 345 |
+
["OH", 5, (4.168, -0.002, -0.005)],
|
| 346 |
+
["CZ", 5, (2.791, -0.001, -0.003)],
|
| 347 |
+
],
|
| 348 |
+
"VAL": [
|
| 349 |
+
["N", 0, (-0.494, 1.373, -0.000)],
|
| 350 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 351 |
+
["C", 0, (1.527, -0.000, -0.000)],
|
| 352 |
+
["CB", 0, (-0.533, -0.795, -1.213)],
|
| 353 |
+
["O", 3, (0.627, 1.062, -0.000)],
|
| 354 |
+
["CG1", 4, (0.540, 1.429, -0.000)],
|
| 355 |
+
["CG2", 4, (0.533, -0.776, 1.203)],
|
| 356 |
+
],
|
| 357 |
+
# Assume alanine positions for unknown AA
|
| 358 |
+
"UNK": [
|
| 359 |
+
["N", 0, (-0.525, 1.363, 0.000)],
|
| 360 |
+
["CA", 0, (0.000, 0.000, 0.000)],
|
| 361 |
+
["C", 0, (1.526, -0.000, -0.000)],
|
| 362 |
+
],
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
|
| 366 |
+
residue_atoms = {
|
| 367 |
+
"ALA": ["C", "CA", "CB", "N", "O"],
|
| 368 |
+
"ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
|
| 369 |
+
"ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
|
| 370 |
+
"ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
|
| 371 |
+
"CYS": ["C", "CA", "CB", "N", "O", "SG"],
|
| 372 |
+
"GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
|
| 373 |
+
"GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
|
| 374 |
+
"GLY": ["C", "CA", "N", "O"],
|
| 375 |
+
"HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
|
| 376 |
+
"ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
|
| 377 |
+
"LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
|
| 378 |
+
"LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
|
| 379 |
+
"MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
|
| 380 |
+
"PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
|
| 381 |
+
"PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
|
| 382 |
+
"SER": ["C", "CA", "CB", "N", "O", "OG"],
|
| 383 |
+
"THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
|
| 384 |
+
"TRP": [
|
| 385 |
+
"C",
|
| 386 |
+
"CA",
|
| 387 |
+
"CB",
|
| 388 |
+
"CG",
|
| 389 |
+
"CD1",
|
| 390 |
+
"CD2",
|
| 391 |
+
"CE2",
|
| 392 |
+
"CE3",
|
| 393 |
+
"CZ2",
|
| 394 |
+
"CZ3",
|
| 395 |
+
"CH2",
|
| 396 |
+
"N",
|
| 397 |
+
"NE1",
|
| 398 |
+
"O",
|
| 399 |
+
],
|
| 400 |
+
"TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
|
| 401 |
+
"VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
|
| 402 |
+
"UNK": ["C", "CA", "N"],
|
| 403 |
+
}
|
| 404 |
+
|
| 405 |
+
# Naming swaps for ambiguous atom names.
|
| 406 |
+
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
|
| 407 |
+
# 4 of the 20 amino acids.
|
| 408 |
+
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
|
| 409 |
+
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
|
| 410 |
+
# the 'ambiguous' atoms and their neighbours)
|
| 411 |
+
# TODO: ^ interpret this
|
| 412 |
+
residue_atom_renaming_swaps = {
|
| 413 |
+
"ASP": {"OD1": "OD2"},
|
| 414 |
+
"GLU": {"OE1": "OE2"},
|
| 415 |
+
"PHE": {"CD1": "CD2", "CE1": "CE2"},
|
| 416 |
+
"TYR": {"CD1": "CD2", "CE1": "CE2"},
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
|
| 420 |
+
van_der_waals_radius = {"C": 1.7, "N": 1.55, "O": 1.52, "S": 1.8}
|
| 421 |
+
|
| 422 |
+
Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
|
| 423 |
+
BondAngle = collections.namedtuple(
|
| 424 |
+
"BondAngle", ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"]
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@functools.lru_cache(maxsize=None)
|
| 429 |
+
def load_stereo_chemical_props() -> (
|
| 430 |
+
Tuple[
|
| 431 |
+
Mapping[str, List[Bond]],
|
| 432 |
+
Mapping[str, List[Bond]],
|
| 433 |
+
Mapping[str, List[BondAngle]],
|
| 434 |
+
]
|
| 435 |
+
):
|
| 436 |
+
"""Load stereo_chemical_props.txt into a nice structure.
|
| 437 |
+
|
| 438 |
+
Load literature values for bond lengths and bond angles and translate
|
| 439 |
+
bond angles into the length of the opposite edge of the triangle
|
| 440 |
+
("residue_virtual_bonds").
|
| 441 |
+
|
| 442 |
+
Returns:
|
| 443 |
+
residue_bonds: dict that maps resname --> list of Bond tuples
|
| 444 |
+
residue_virtual_bonds: dict that maps resname --> list of Bond tuples
|
| 445 |
+
residue_bond_angles: dict that maps resname --> list of BondAngle tuples
|
| 446 |
+
"""
|
| 447 |
+
stereo_chemical_props = Path(
|
| 448 |
+
"evolutionaryscale/structure/stereo_chemical_props.txt"
|
| 449 |
+
).read_text()
|
| 450 |
+
|
| 451 |
+
lines_iter = iter(stereo_chemical_props.splitlines())
|
| 452 |
+
# Load bond lengths.
|
| 453 |
+
residue_bonds = {}
|
| 454 |
+
next(lines_iter) # Skip header line.
|
| 455 |
+
for line in lines_iter:
|
| 456 |
+
if line.strip() == "-":
|
| 457 |
+
break
|
| 458 |
+
bond, resname, length, stddev = line.split()
|
| 459 |
+
atom1, atom2 = bond.split("-")
|
| 460 |
+
if resname not in residue_bonds:
|
| 461 |
+
residue_bonds[resname] = []
|
| 462 |
+
residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
|
| 463 |
+
residue_bonds["UNK"] = []
|
| 464 |
+
|
| 465 |
+
# Load bond angles.
|
| 466 |
+
residue_bond_angles = {}
|
| 467 |
+
next(lines_iter) # Skip empty line.
|
| 468 |
+
next(lines_iter) # Skip header line.
|
| 469 |
+
for line in lines_iter:
|
| 470 |
+
if line.strip() == "-":
|
| 471 |
+
break
|
| 472 |
+
bond, resname, angle_degree, stddev_degree = line.split()
|
| 473 |
+
atom1, atom2, atom3 = bond.split("-")
|
| 474 |
+
if resname not in residue_bond_angles:
|
| 475 |
+
residue_bond_angles[resname] = []
|
| 476 |
+
residue_bond_angles[resname].append(
|
| 477 |
+
BondAngle(
|
| 478 |
+
atom1,
|
| 479 |
+
atom2,
|
| 480 |
+
atom3,
|
| 481 |
+
float(angle_degree) / 180.0 * np.pi,
|
| 482 |
+
float(stddev_degree) / 180.0 * np.pi,
|
| 483 |
+
)
|
| 484 |
+
)
|
| 485 |
+
residue_bond_angles["UNK"] = []
|
| 486 |
+
|
| 487 |
+
def make_bond_key(atom1_name, atom2_name):
|
| 488 |
+
"""Unique key to lookup bonds."""
|
| 489 |
+
return "-".join(sorted([atom1_name, atom2_name]))
|
| 490 |
+
|
| 491 |
+
# Translate bond angles into distances ("virtual bonds").
|
| 492 |
+
residue_virtual_bonds = {}
|
| 493 |
+
for resname, bond_angles in residue_bond_angles.items():
|
| 494 |
+
# Create a fast lookup dict for bond lengths.
|
| 495 |
+
bond_cache = {}
|
| 496 |
+
for b in residue_bonds[resname]:
|
| 497 |
+
bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
|
| 498 |
+
residue_virtual_bonds[resname] = []
|
| 499 |
+
for ba in bond_angles:
|
| 500 |
+
bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
|
| 501 |
+
bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
|
| 502 |
+
|
| 503 |
+
# Compute distance between atom1 and atom3 using the law of cosines
|
| 504 |
+
# c^2 = a^2 + b^2 - 2ab*cos(gamma).
|
| 505 |
+
gamma = ba.angle_rad
|
| 506 |
+
length = np.sqrt(
|
| 507 |
+
bond1.length**2
|
| 508 |
+
+ bond2.length**2
|
| 509 |
+
- 2 * bond1.length * bond2.length * np.cos(gamma)
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
# Propagation of uncertainty assuming uncorrelated errors.
|
| 513 |
+
dl_outer = 0.5 / length
|
| 514 |
+
dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
|
| 515 |
+
dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
|
| 516 |
+
dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
|
| 517 |
+
stddev = np.sqrt(
|
| 518 |
+
(dl_dgamma * ba.stddev) ** 2
|
| 519 |
+
+ (dl_db1 * bond1.stddev) ** 2
|
| 520 |
+
+ (dl_db2 * bond2.stddev) ** 2
|
| 521 |
+
)
|
| 522 |
+
residue_virtual_bonds[resname].append(
|
| 523 |
+
Bond(ba.atom1_name, ba.atom3name, length, stddev)
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
# Between-residue bond lengths for general bonds (first element) and for Proline
|
| 530 |
+
# (second element).
|
| 531 |
+
between_res_bond_length_c_n = [1.329, 1.341]
|
| 532 |
+
between_res_bond_length_stddev_c_n = [0.014, 0.016]
|
| 533 |
+
|
| 534 |
+
# Between-residue cos_angles.
|
| 535 |
+
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
|
| 536 |
+
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
|
| 537 |
+
|
| 538 |
+
# This mapping is used when we need to store atom data in a format that requires
|
| 539 |
+
# fixed atom data size for every residue (e.g. a numpy array).
|
| 540 |
+
atom_types = [
|
| 541 |
+
"N",
|
| 542 |
+
"CA",
|
| 543 |
+
"C",
|
| 544 |
+
"CB",
|
| 545 |
+
"O",
|
| 546 |
+
"CG",
|
| 547 |
+
"CG1",
|
| 548 |
+
"CG2",
|
| 549 |
+
"OG",
|
| 550 |
+
"OG1",
|
| 551 |
+
"SG",
|
| 552 |
+
"CD",
|
| 553 |
+
"CD1",
|
| 554 |
+
"CD2",
|
| 555 |
+
"ND1",
|
| 556 |
+
"ND2",
|
| 557 |
+
"OD1",
|
| 558 |
+
"OD2",
|
| 559 |
+
"SD",
|
| 560 |
+
"CE",
|
| 561 |
+
"CE1",
|
| 562 |
+
"CE2",
|
| 563 |
+
"CE3",
|
| 564 |
+
"NE",
|
| 565 |
+
"NE1",
|
| 566 |
+
"NE2",
|
| 567 |
+
"OE1",
|
| 568 |
+
"OE2",
|
| 569 |
+
"CH2",
|
| 570 |
+
"NH1",
|
| 571 |
+
"NH2",
|
| 572 |
+
"OH",
|
| 573 |
+
"CZ",
|
| 574 |
+
"CZ2",
|
| 575 |
+
"CZ3",
|
| 576 |
+
"NZ",
|
| 577 |
+
"OXT",
|
| 578 |
+
]
|
| 579 |
+
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
|
| 580 |
+
atom_type_num = len(atom_types) # := 37.
|
| 581 |
+
|
| 582 |
+
# A compact atom encoding with 14 columns
|
| 583 |
+
# pylint: disable=line-too-long
|
| 584 |
+
# pylint: disable=bad-whitespace
|
| 585 |
+
restype_name_to_atom14_names = {
|
| 586 |
+
"ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
|
| 587 |
+
"ARG": [
|
| 588 |
+
"N",
|
| 589 |
+
"CA",
|
| 590 |
+
"C",
|
| 591 |
+
"O",
|
| 592 |
+
"CB",
|
| 593 |
+
"CG",
|
| 594 |
+
"CD",
|
| 595 |
+
"NE",
|
| 596 |
+
"CZ",
|
| 597 |
+
"NH1",
|
| 598 |
+
"NH2",
|
| 599 |
+
"",
|
| 600 |
+
"",
|
| 601 |
+
"",
|
| 602 |
+
],
|
| 603 |
+
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
|
| 604 |
+
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
|
| 605 |
+
"CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
|
| 606 |
+
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
|
| 607 |
+
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
|
| 608 |
+
"GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
|
| 609 |
+
"HIS": [
|
| 610 |
+
"N",
|
| 611 |
+
"CA",
|
| 612 |
+
"C",
|
| 613 |
+
"O",
|
| 614 |
+
"CB",
|
| 615 |
+
"CG",
|
| 616 |
+
"ND1",
|
| 617 |
+
"CD2",
|
| 618 |
+
"CE1",
|
| 619 |
+
"NE2",
|
| 620 |
+
"",
|
| 621 |
+
"",
|
| 622 |
+
"",
|
| 623 |
+
"",
|
| 624 |
+
],
|
| 625 |
+
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
|
| 626 |
+
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
|
| 627 |
+
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
|
| 628 |
+
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
|
| 629 |
+
"PHE": [
|
| 630 |
+
"N",
|
| 631 |
+
"CA",
|
| 632 |
+
"C",
|
| 633 |
+
"O",
|
| 634 |
+
"CB",
|
| 635 |
+
"CG",
|
| 636 |
+
"CD1",
|
| 637 |
+
"CD2",
|
| 638 |
+
"CE1",
|
| 639 |
+
"CE2",
|
| 640 |
+
"CZ",
|
| 641 |
+
"",
|
| 642 |
+
"",
|
| 643 |
+
"",
|
| 644 |
+
],
|
| 645 |
+
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
|
| 646 |
+
"SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
|
| 647 |
+
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
|
| 648 |
+
"TRP": [
|
| 649 |
+
"N",
|
| 650 |
+
"CA",
|
| 651 |
+
"C",
|
| 652 |
+
"O",
|
| 653 |
+
"CB",
|
| 654 |
+
"CG",
|
| 655 |
+
"CD1",
|
| 656 |
+
"CD2",
|
| 657 |
+
"NE1",
|
| 658 |
+
"CE2",
|
| 659 |
+
"CE3",
|
| 660 |
+
"CZ2",
|
| 661 |
+
"CZ3",
|
| 662 |
+
"CH2",
|
| 663 |
+
],
|
| 664 |
+
"TYR": [
|
| 665 |
+
"N",
|
| 666 |
+
"CA",
|
| 667 |
+
"C",
|
| 668 |
+
"O",
|
| 669 |
+
"CB",
|
| 670 |
+
"CG",
|
| 671 |
+
"CD1",
|
| 672 |
+
"CD2",
|
| 673 |
+
"CE1",
|
| 674 |
+
"CE2",
|
| 675 |
+
"CZ",
|
| 676 |
+
"OH",
|
| 677 |
+
"",
|
| 678 |
+
"",
|
| 679 |
+
],
|
| 680 |
+
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
|
| 681 |
+
"UNK": ["N", "CA", "C", "", "", "", "", "", "", "", "", "", "", ""],
|
| 682 |
+
}
|
| 683 |
+
# pylint: enable=line-too-long
|
| 684 |
+
# pylint: enable=bad-whitespace
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
# This is the standard residue order when coding AA type as a number.
|
| 688 |
+
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
|
| 689 |
+
restypes = [
|
| 690 |
+
"A",
|
| 691 |
+
"R",
|
| 692 |
+
"N",
|
| 693 |
+
"D",
|
| 694 |
+
"C",
|
| 695 |
+
"Q",
|
| 696 |
+
"E",
|
| 697 |
+
"G",
|
| 698 |
+
"H",
|
| 699 |
+
"I",
|
| 700 |
+
"L",
|
| 701 |
+
"K",
|
| 702 |
+
"M",
|
| 703 |
+
"F",
|
| 704 |
+
"P",
|
| 705 |
+
"S",
|
| 706 |
+
"T",
|
| 707 |
+
"W",
|
| 708 |
+
"Y",
|
| 709 |
+
"V",
|
| 710 |
+
]
|
| 711 |
+
restype_order = {restype: i for i, restype in enumerate(restypes)}
|
| 712 |
+
restype_num = len(restypes) # := 20.
|
| 713 |
+
unk_restype_index = restype_num # Catch-all index for unknown restypes.
|
| 714 |
+
|
| 715 |
+
restypes_with_x = restypes + ["X"]
|
| 716 |
+
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
|
| 717 |
+
|
| 718 |
+
bb_atoms = ["N", "CA", "C", "O"]
|
| 719 |
+
|
| 720 |
+
# Hydrophobicity by residue (positive values are hydrophobic). Derived from Black & Mould (1991), normalized by subtracting 0.5.
|
| 721 |
+
hydrophobicity = {
|
| 722 |
+
"ALA": 0.116,
|
| 723 |
+
"ARG": -0.5,
|
| 724 |
+
"ASN": -0.264,
|
| 725 |
+
"ASP": -0.472,
|
| 726 |
+
"CYS": 0.18,
|
| 727 |
+
"GLN": -0.249,
|
| 728 |
+
"GLU": -0.457,
|
| 729 |
+
"GLY": 0.001,
|
| 730 |
+
"HIS": -0.335,
|
| 731 |
+
"ILE": 0.443,
|
| 732 |
+
"LEU": 0.443,
|
| 733 |
+
"LYS": -0.217,
|
| 734 |
+
"MET": 0.238,
|
| 735 |
+
"PHE": 0.5,
|
| 736 |
+
"PRO": 0.211,
|
| 737 |
+
"SER": -0.141,
|
| 738 |
+
"THR": -0.05,
|
| 739 |
+
"TRP": 0.378,
|
| 740 |
+
"TYR": 0.38,
|
| 741 |
+
"VAL": 0.325,
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
# Side chain max accessible surface area in Ala-X-Ala tripeptide (from Chennamsetty et al. 2010).
|
| 745 |
+
side_chain_asa = {
|
| 746 |
+
"ALA": 64.7809,
|
| 747 |
+
"ARG": 210.02,
|
| 748 |
+
"ASN": 113.187,
|
| 749 |
+
"ASP": 110.209,
|
| 750 |
+
"CYS": 95.2439,
|
| 751 |
+
"GLN": 147.855,
|
| 752 |
+
"GLU": 143.924,
|
| 753 |
+
"GLY": 23.1338,
|
| 754 |
+
"HIS": 146.449,
|
| 755 |
+
"ILE": 151.242,
|
| 756 |
+
"LEU": 139.524,
|
| 757 |
+
"LYS": 177.366,
|
| 758 |
+
"MET": 164.674,
|
| 759 |
+
"PHE": 186.7,
|
| 760 |
+
"PRO": 111.533,
|
| 761 |
+
"SER": 81.2159,
|
| 762 |
+
"THR": 111.597,
|
| 763 |
+
"TRP": 229.619,
|
| 764 |
+
"TYR": 200.306,
|
| 765 |
+
"VAL": 124.237,
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
# Approximate Volumes of amino acids in cubic angstroms.
|
| 769 |
+
# https://www.imgt.org/IMGTeducation/Aide-memoire/_UK/aminoacids/abbreviation.html
|
| 770 |
+
amino_acid_volumes = {
|
| 771 |
+
"A": 88.6, # Alanine
|
| 772 |
+
"R": 173.4, # Arginine
|
| 773 |
+
"N": 114.1, # Asparagine
|
| 774 |
+
"D": 111.1, # Aspartic acid
|
| 775 |
+
"C": 108.5, # Cysteine
|
| 776 |
+
"Q": 143.8, # Glutamine
|
| 777 |
+
"E": 138.4, # Glutamic acid
|
| 778 |
+
"G": 60.1, # Glycine
|
| 779 |
+
"H": 153.2, # Histidine
|
| 780 |
+
"I": 166.7, # Isoleucine
|
| 781 |
+
"L": 166.7, # Leucine
|
| 782 |
+
"K": 168.6, # Lysine
|
| 783 |
+
"M": 162.9, # Methionine
|
| 784 |
+
"F": 189.9, # Phenylalanine
|
| 785 |
+
"P": 112.7, # Proline
|
| 786 |
+
"S": 89.0, # Serine
|
| 787 |
+
"T": 116.1, # Threonine
|
| 788 |
+
"W": 227.8, # Tryptophan
|
| 789 |
+
"Y": 193.6, # Tyrosine
|
| 790 |
+
"V": 140.0, # Valine
|
| 791 |
+
"X": 88.6, # Unknown, use Alanine as approximation
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def sequence_to_onehot(
|
| 796 |
+
sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
|
| 797 |
+
) -> np.ndarray:
|
| 798 |
+
"""Maps the given sequence into a one-hot encoded matrix.
|
| 799 |
+
|
| 800 |
+
Args:
|
| 801 |
+
sequence: An amino acid sequence.
|
| 802 |
+
mapping: A dictionary mapping amino acids to integers.
|
| 803 |
+
map_unknown_to_x: If True, any amino acid that is not in the mapping will be
|
| 804 |
+
mapped to the unknown amino acid 'X'. If the mapping doesn't contain
|
| 805 |
+
amino acid 'X', an error will be thrown. If False, any amino acid not in
|
| 806 |
+
the mapping will throw an error.
|
| 807 |
+
|
| 808 |
+
Returns:
|
| 809 |
+
A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
|
| 810 |
+
the sequence.
|
| 811 |
+
|
| 812 |
+
Raises:
|
| 813 |
+
ValueError: If the mapping doesn't contain values from 0 to
|
| 814 |
+
num_unique_aas - 1 without any gaps.
|
| 815 |
+
"""
|
| 816 |
+
num_entries = max(mapping.values()) + 1
|
| 817 |
+
|
| 818 |
+
if sorted(set(mapping.values())) != list(range(num_entries)):
|
| 819 |
+
raise ValueError(
|
| 820 |
+
"The mapping must have values from 0 to num_unique_aas-1 "
|
| 821 |
+
"without any gaps. Got: %s" % sorted(mapping.values())
|
| 822 |
+
)
|
| 823 |
+
|
| 824 |
+
one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
|
| 825 |
+
|
| 826 |
+
for aa_index, aa_type in enumerate(sequence):
|
| 827 |
+
if map_unknown_to_x:
|
| 828 |
+
if aa_type.isalpha() and aa_type.isupper():
|
| 829 |
+
aa_id = mapping.get(aa_type, mapping["X"])
|
| 830 |
+
else:
|
| 831 |
+
raise ValueError(f"Invalid character in the sequence: {aa_type}")
|
| 832 |
+
else:
|
| 833 |
+
aa_id = mapping[aa_type]
|
| 834 |
+
one_hot_arr[aa_index, aa_id] = 1
|
| 835 |
+
|
| 836 |
+
return one_hot_arr
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
restype_1to3 = {
|
| 840 |
+
"A": "ALA",
|
| 841 |
+
"R": "ARG",
|
| 842 |
+
"N": "ASN",
|
| 843 |
+
"D": "ASP",
|
| 844 |
+
"C": "CYS",
|
| 845 |
+
"Q": "GLN",
|
| 846 |
+
"E": "GLU",
|
| 847 |
+
"G": "GLY",
|
| 848 |
+
"H": "HIS",
|
| 849 |
+
"I": "ILE",
|
| 850 |
+
"L": "LEU",
|
| 851 |
+
"K": "LYS",
|
| 852 |
+
"M": "MET",
|
| 853 |
+
"F": "PHE",
|
| 854 |
+
"P": "PRO",
|
| 855 |
+
"S": "SER",
|
| 856 |
+
"T": "THR",
|
| 857 |
+
"W": "TRP",
|
| 858 |
+
"Y": "TYR",
|
| 859 |
+
"V": "VAL",
|
| 860 |
+
"X": "UNK",
|
| 861 |
+
}
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
|
| 865 |
+
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
|
| 866 |
+
# many more, and less common, three letter names as keys and maps many of these
|
| 867 |
+
# to the same one letter name (including 'X' and 'U' which we don't use here).
|
| 868 |
+
restype_3to1 = {v: k for k, v in restype_1to3.items()}
|
| 869 |
+
|
| 870 |
+
# Define a restype name for all unknown residues.
|
| 871 |
+
unk_restype = "UNK"
|
| 872 |
+
|
| 873 |
+
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
|
| 874 |
+
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
|
| 875 |
+
|
| 876 |
+
hydrophobic_resnames = {"VAL", "ILE", "LEU", "PHE", "MET", "TRP"}
|
| 877 |
+
|
| 878 |
+
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
|
| 879 |
+
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
|
| 880 |
+
# remaining 20 amino acids are kept in alphabetical order.
|
| 881 |
+
# There are 2 non-amino acid codes, X (representing any amino acid) and
|
| 882 |
+
# "-" representing a missing amino acid in an alignment. The id for these
|
| 883 |
+
# codes is put at the end (20 and 21) so that they can easily be ignored if
|
| 884 |
+
# desired.
|
| 885 |
+
HHBLITS_AA_TO_ID = {
|
| 886 |
+
"A": 0,
|
| 887 |
+
"B": 2,
|
| 888 |
+
"C": 1,
|
| 889 |
+
"D": 2,
|
| 890 |
+
"E": 3,
|
| 891 |
+
"F": 4,
|
| 892 |
+
"G": 5,
|
| 893 |
+
"H": 6,
|
| 894 |
+
"I": 7,
|
| 895 |
+
"J": 20,
|
| 896 |
+
"K": 8,
|
| 897 |
+
"L": 9,
|
| 898 |
+
"M": 10,
|
| 899 |
+
"N": 11,
|
| 900 |
+
"O": 20,
|
| 901 |
+
"P": 12,
|
| 902 |
+
"Q": 13,
|
| 903 |
+
"R": 14,
|
| 904 |
+
"S": 15,
|
| 905 |
+
"T": 16,
|
| 906 |
+
"U": 1,
|
| 907 |
+
"V": 17,
|
| 908 |
+
"W": 18,
|
| 909 |
+
"X": 20,
|
| 910 |
+
"Y": 19,
|
| 911 |
+
"Z": 3,
|
| 912 |
+
"-": 21,
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
# Partial inversion of HHBLITS_AA_TO_ID.
|
| 916 |
+
ID_TO_HHBLITS_AA = {
|
| 917 |
+
0: "A",
|
| 918 |
+
1: "C", # Also U.
|
| 919 |
+
2: "D", # Also B.
|
| 920 |
+
3: "E", # Also Z.
|
| 921 |
+
4: "F",
|
| 922 |
+
5: "G",
|
| 923 |
+
6: "H",
|
| 924 |
+
7: "I",
|
| 925 |
+
8: "K",
|
| 926 |
+
9: "L",
|
| 927 |
+
10: "M",
|
| 928 |
+
11: "N",
|
| 929 |
+
12: "P",
|
| 930 |
+
13: "Q",
|
| 931 |
+
14: "R",
|
| 932 |
+
15: "S",
|
| 933 |
+
16: "T",
|
| 934 |
+
17: "V",
|
| 935 |
+
18: "W",
|
| 936 |
+
19: "Y",
|
| 937 |
+
20: "X", # Includes J and O.
|
| 938 |
+
21: "-",
|
| 939 |
+
}
|
| 940 |
+
|
| 941 |
+
restypes_with_x_and_gap = restypes + ["X", "-"]
|
| 942 |
+
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
|
| 943 |
+
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
|
| 944 |
+
for i in range(len(restypes_with_x_and_gap))
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def _make_standard_atom_mask() -> np.ndarray:
|
| 949 |
+
"""Returns [num_res_types, num_atom_types] mask array."""
|
| 950 |
+
# +1 to account for unknown (all 0s).
|
| 951 |
+
mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
|
| 952 |
+
for restype, restype_letter in enumerate(restypes):
|
| 953 |
+
restype_name = restype_1to3[restype_letter]
|
| 954 |
+
atom_names = residue_atoms[restype_name]
|
| 955 |
+
for atom_name in atom_names:
|
| 956 |
+
atom_type = atom_order[atom_name]
|
| 957 |
+
mask[restype, atom_type] = 1
|
| 958 |
+
return mask
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
STANDARD_ATOM_MASK = _make_standard_atom_mask()
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
# A one hot representation for the first and second atoms defining the axis
|
| 965 |
+
# of rotation for each chi-angle in each residue.
|
| 966 |
+
def chi_angle_atom(atom_index: int) -> np.ndarray:
|
| 967 |
+
"""Define chi-angle rigid groups via one-hot representations."""
|
| 968 |
+
chi_angles_index = {}
|
| 969 |
+
one_hots = []
|
| 970 |
+
|
| 971 |
+
for k, v in chi_angles_atoms.items():
|
| 972 |
+
indices = [atom_types.index(s[atom_index]) for s in v]
|
| 973 |
+
indices.extend([-1] * (4 - len(indices)))
|
| 974 |
+
chi_angles_index[k] = indices
|
| 975 |
+
|
| 976 |
+
for r in restypes:
|
| 977 |
+
res3 = restype_1to3[r]
|
| 978 |
+
one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
|
| 979 |
+
one_hots.append(one_hot)
|
| 980 |
+
|
| 981 |
+
one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
|
| 982 |
+
one_hot = np.stack(one_hots, axis=0)
|
| 983 |
+
one_hot = np.transpose(one_hot, [0, 2, 1])
|
| 984 |
+
|
| 985 |
+
return one_hot
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
chi_atom_1_one_hot = chi_angle_atom(1)
|
| 989 |
+
chi_atom_2_one_hot = chi_angle_atom(2)
|
| 990 |
+
|
| 991 |
+
# An array like chi_angles_atoms but using indices rather than names.
|
| 992 |
+
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
|
| 993 |
+
# chi_angles_atom_indices = tree.map_structure(
|
| 994 |
+
# lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
|
| 995 |
+
# )
|
| 996 |
+
chi_angles_atom_indices = np.array(
|
| 997 |
+
[
|
| 998 |
+
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
|
| 999 |
+
for chi_atoms in chi_angles_atom_indices
|
| 1000 |
+
]
|
| 1001 |
+
)
|
| 1002 |
+
|
| 1003 |
+
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
|
| 1004 |
+
# and atom index within that group.
|
| 1005 |
+
chi_groups_for_atom = collections.defaultdict(list)
|
| 1006 |
+
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
|
| 1007 |
+
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
|
| 1008 |
+
for atom_i, atom in enumerate(chi_group):
|
| 1009 |
+
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
|
| 1010 |
+
chi_groups_for_atom = dict(chi_groups_for_atom)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def _make_rigid_transformation_4x4(ex, ey, translation):
|
| 1014 |
+
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
|
| 1015 |
+
# Normalize ex.
|
| 1016 |
+
ex_normalized = ex / np.linalg.norm(ex)
|
| 1017 |
+
|
| 1018 |
+
# make ey perpendicular to ex
|
| 1019 |
+
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
|
| 1020 |
+
ey_normalized /= np.linalg.norm(ey_normalized)
|
| 1021 |
+
|
| 1022 |
+
# compute ez as cross product
|
| 1023 |
+
eznorm = np.cross(ex_normalized, ey_normalized)
|
| 1024 |
+
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
|
| 1025 |
+
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
|
| 1026 |
+
return m
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
# create an array with (restype, atomtype) --> rigid_group_idx
|
| 1030 |
+
# and an array with (restype, atomtype, coord) for the atom positions
|
| 1031 |
+
# and compute affine transformation matrices (4,4) from one rigid group to the
|
| 1032 |
+
# previous group
|
| 1033 |
+
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
|
| 1034 |
+
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
|
| 1035 |
+
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
|
| 1036 |
+
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
|
| 1037 |
+
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
|
| 1038 |
+
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
|
| 1039 |
+
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
def _make_rigid_group_constants():
|
| 1043 |
+
"""Fill the arrays above."""
|
| 1044 |
+
for restype, restype_letter in enumerate(restypes_with_x):
|
| 1045 |
+
resname = restype_1to3[restype_letter]
|
| 1046 |
+
for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
|
| 1047 |
+
atomtype = atom_order[atomname]
|
| 1048 |
+
restype_atom37_to_rigid_group[restype, atomtype] = group_idx
|
| 1049 |
+
restype_atom37_mask[restype, atomtype] = 1
|
| 1050 |
+
restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
|
| 1051 |
+
|
| 1052 |
+
atom14idx = restype_name_to_atom14_names[resname].index(atomname)
|
| 1053 |
+
restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
|
| 1054 |
+
restype_atom14_mask[restype, atom14idx] = 1
|
| 1055 |
+
restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
|
| 1056 |
+
|
| 1057 |
+
for restype, restype_letter in enumerate(restypes_with_x):
|
| 1058 |
+
resname = restype_1to3[restype_letter]
|
| 1059 |
+
atom_positions = {
|
| 1060 |
+
name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
|
| 1061 |
+
}
|
| 1062 |
+
|
| 1063 |
+
# backbone to backbone is the identity transform
|
| 1064 |
+
restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
|
| 1065 |
+
|
| 1066 |
+
# pre-omega-frame to backbone (currently dummy identity matrix)
|
| 1067 |
+
restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
|
| 1068 |
+
|
| 1069 |
+
# phi-frame to backbone
|
| 1070 |
+
mat = _make_rigid_transformation_4x4(
|
| 1071 |
+
ex=atom_positions["N"] - atom_positions["CA"],
|
| 1072 |
+
ey=np.array([1.0, 0.0, 0.0]),
|
| 1073 |
+
translation=atom_positions["N"],
|
| 1074 |
+
)
|
| 1075 |
+
restype_rigid_group_default_frame[restype, 2, :, :] = mat
|
| 1076 |
+
|
| 1077 |
+
# psi-frame to backbone
|
| 1078 |
+
mat = _make_rigid_transformation_4x4(
|
| 1079 |
+
ex=atom_positions["C"] - atom_positions["CA"],
|
| 1080 |
+
ey=atom_positions["CA"] - atom_positions["N"],
|
| 1081 |
+
translation=atom_positions["C"],
|
| 1082 |
+
)
|
| 1083 |
+
restype_rigid_group_default_frame[restype, 3, :, :] = mat
|
| 1084 |
+
|
| 1085 |
+
# chi1-frame to backbone
|
| 1086 |
+
if chi_angles_mask[restype][0]:
|
| 1087 |
+
base_atom_names = chi_angles_atoms[resname][0]
|
| 1088 |
+
base_atom_positions = [atom_positions[name] for name in base_atom_names]
|
| 1089 |
+
mat = _make_rigid_transformation_4x4(
|
| 1090 |
+
ex=base_atom_positions[2] - base_atom_positions[1],
|
| 1091 |
+
ey=base_atom_positions[0] - base_atom_positions[1],
|
| 1092 |
+
translation=base_atom_positions[2],
|
| 1093 |
+
)
|
| 1094 |
+
restype_rigid_group_default_frame[restype, 4, :, :] = mat
|
| 1095 |
+
|
| 1096 |
+
# chi2-frame to chi1-frame
|
| 1097 |
+
# chi3-frame to chi2-frame
|
| 1098 |
+
# chi4-frame to chi3-frame
|
| 1099 |
+
# luckily all rotation axes for the next frame start at (0,0,0) of the
|
| 1100 |
+
# previous frame
|
| 1101 |
+
for chi_idx in range(1, 4):
|
| 1102 |
+
if chi_angles_mask[restype][chi_idx]:
|
| 1103 |
+
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
|
| 1104 |
+
axis_end_atom_position = atom_positions[axis_end_atom_name]
|
| 1105 |
+
mat = _make_rigid_transformation_4x4(
|
| 1106 |
+
ex=axis_end_atom_position,
|
| 1107 |
+
ey=np.array([-1.0, 0.0, 0.0]),
|
| 1108 |
+
translation=axis_end_atom_position,
|
| 1109 |
+
)
|
| 1110 |
+
restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
_make_rigid_group_constants()
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15.0):
|
| 1117 |
+
"""compute upper and lower bounds for bonds to assess violations."""
|
| 1118 |
+
restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
|
| 1119 |
+
restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
|
| 1120 |
+
restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
|
| 1121 |
+
residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
|
| 1122 |
+
for restype, restype_letter in enumerate(restypes):
|
| 1123 |
+
resname = restype_1to3[restype_letter]
|
| 1124 |
+
atom_list = restype_name_to_atom14_names[resname]
|
| 1125 |
+
|
| 1126 |
+
# create lower and upper bounds for clashes
|
| 1127 |
+
for atom1_idx, atom1_name in enumerate(atom_list):
|
| 1128 |
+
if not atom1_name:
|
| 1129 |
+
continue
|
| 1130 |
+
atom1_radius = van_der_waals_radius[atom1_name[0]]
|
| 1131 |
+
for atom2_idx, atom2_name in enumerate(atom_list):
|
| 1132 |
+
if (not atom2_name) or atom1_idx == atom2_idx:
|
| 1133 |
+
continue
|
| 1134 |
+
atom2_radius = van_der_waals_radius[atom2_name[0]]
|
| 1135 |
+
lower = atom1_radius + atom2_radius - overlap_tolerance
|
| 1136 |
+
upper = 1e10
|
| 1137 |
+
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
|
| 1138 |
+
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
|
| 1139 |
+
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
|
| 1140 |
+
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
|
| 1141 |
+
|
| 1142 |
+
# overwrite lower and upper bounds for bonds and angles
|
| 1143 |
+
for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
|
| 1144 |
+
atom1_idx = atom_list.index(b.atom1_name)
|
| 1145 |
+
atom2_idx = atom_list.index(b.atom2_name)
|
| 1146 |
+
lower = b.length - bond_length_tolerance_factor * b.stddev
|
| 1147 |
+
upper = b.length + bond_length_tolerance_factor * b.stddev
|
| 1148 |
+
restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
|
| 1149 |
+
restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
|
| 1150 |
+
restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
|
| 1151 |
+
restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
|
| 1152 |
+
restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
|
| 1153 |
+
restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
|
| 1154 |
+
return {
|
| 1155 |
+
"lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
|
| 1156 |
+
"upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
|
| 1157 |
+
"stddev": restype_atom14_bond_stddev, # shape (21,14,14)
|
| 1158 |
+
}
|
| 1159 |
+
|
| 1160 |
+
|
| 1161 |
+
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
|
| 1162 |
+
restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1))
|
| 1163 |
+
|
| 1164 |
+
|
| 1165 |
+
def _make_atom14_ambiguity_feats():
|
| 1166 |
+
for res, pairs in residue_atom_renaming_swaps.items():
|
| 1167 |
+
res_idx = restype_order[restype_3to1[res]]
|
| 1168 |
+
for atom1, atom2 in pairs.items():
|
| 1169 |
+
atom1_idx = restype_name_to_atom14_names[res].index(atom1)
|
| 1170 |
+
atom2_idx = restype_name_to_atom14_names[res].index(atom2)
|
| 1171 |
+
restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
|
| 1172 |
+
restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
|
| 1173 |
+
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
|
| 1174 |
+
restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
_make_atom14_ambiguity_feats()
|
| 1178 |
+
|
| 1179 |
+
|
| 1180 |
+
def aatype_to_str_sequence(aatype):
|
| 1181 |
+
return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))])
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
# NOTE(thayes): These are computed based on the average CA->C and CA->N norm from rigid_group_atom_positions
|
| 1185 |
+
CA_TO_N_NORM = 1.4591
|
| 1186 |
+
CA_TO_C_NORM = 1.5252
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
def _make_restype_atom37_to_atom14():
|
| 1190 |
+
"""Map from atom37 to atom14 per residue type."""
|
| 1191 |
+
restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
|
| 1192 |
+
for rt in restypes:
|
| 1193 |
+
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
|
| 1194 |
+
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
|
| 1195 |
+
restype_atom37_to_atom14.append(
|
| 1196 |
+
[
|
| 1197 |
+
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
|
| 1198 |
+
for name in atom_types
|
| 1199 |
+
]
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
restype_atom37_to_atom14.append([0] * 37)
|
| 1203 |
+
restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
|
| 1204 |
+
return restype_atom37_to_atom14
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
def _make_restype_atom14_to_atom37():
|
| 1208 |
+
"""Map from atom14 to atom37 per residue type."""
|
| 1209 |
+
restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
|
| 1210 |
+
for rt in restypes:
|
| 1211 |
+
atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
|
| 1212 |
+
restype_atom14_to_atom37.append(
|
| 1213 |
+
[(atom_order[name] if name else 0) for name in atom_names]
|
| 1214 |
+
)
|
| 1215 |
+
# Add dummy mapping for restype 'UNK'
|
| 1216 |
+
restype_atom14_to_atom37.append([0] * 14)
|
| 1217 |
+
restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
|
| 1218 |
+
return restype_atom14_to_atom37
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
|
| 1222 |
+
RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
|
| 1223 |
+
CHAIN_BREAK_TOKEN = "|"
|
| 1224 |
+
|
esmfold2_sequential_dataclass.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass, fields, replace
|
| 3 |
+
from typing import TypeVar
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from .esmfold2_misc import concat_objects, slice_any_object
|
| 8 |
+
|
| 9 |
+
T = TypeVar("T")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass(frozen=True)
|
| 13 |
+
class SequentialDataclass(ABC):
|
| 14 |
+
"""
|
| 15 |
+
This is a builder on a dataclass that allows for automatic slicing and concatenation.
|
| 16 |
+
|
| 17 |
+
When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
|
| 18 |
+
|
| 19 |
+
When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
|
| 20 |
+
|
| 21 |
+
We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
|
| 22 |
+
|
| 23 |
+
The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
|
| 24 |
+
|
| 25 |
+
This is done through the `metadata` field, which can take 3 values:
|
| 26 |
+
`sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
|
| 27 |
+
`sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
|
| 28 |
+
`join_token` (Any): What token to use to join when concatenating elements. Default: None.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
Example:
|
| 32 |
+
|
| 33 |
+
@dataclass(frozen=True)
|
| 34 |
+
class Foo(SequentialDataclass):
|
| 35 |
+
id: str
|
| 36 |
+
sequence: str = field(metadata={"sequence": True, "join_token": "|"})
|
| 37 |
+
tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
|
| 38 |
+
|
| 39 |
+
def __len__(self):
|
| 40 |
+
# Must implement the __len__ method
|
| 41 |
+
return len(self.sequence)
|
| 42 |
+
|
| 43 |
+
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
|
| 44 |
+
Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
|
| 45 |
+
|
| 46 |
+
>>> foo[1:4]
|
| 47 |
+
Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
|
| 48 |
+
|
| 49 |
+
>>> foo[np.arange(5) < 3]
|
| 50 |
+
Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
|
| 51 |
+
|
| 52 |
+
>>> Foo.concat([foo[:2], foo[3:]])
|
| 53 |
+
Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
|
| 54 |
+
|
| 55 |
+
# Trying to create a type where the sequence lengths do not match raises an error
|
| 56 |
+
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
|
| 57 |
+
ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
|
| 58 |
+
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __post_init__(self):
|
| 62 |
+
self._check_sequence_lengths_match()
|
| 63 |
+
|
| 64 |
+
@abstractmethod
|
| 65 |
+
def __len__(self):
|
| 66 |
+
raise NotImplementedError
|
| 67 |
+
|
| 68 |
+
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
|
| 69 |
+
updated_fields = {}
|
| 70 |
+
if isinstance(idx, int):
|
| 71 |
+
# make it so that things remain sequential
|
| 72 |
+
idx = [idx]
|
| 73 |
+
|
| 74 |
+
for fld in fields(self):
|
| 75 |
+
if fld.metadata.get("sequence", False):
|
| 76 |
+
# this is a sequence, should be the same length as all other sequences
|
| 77 |
+
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
| 78 |
+
value = getattr(self, fld.name)
|
| 79 |
+
if value is None:
|
| 80 |
+
continue
|
| 81 |
+
match sequence_dim:
|
| 82 |
+
case 0:
|
| 83 |
+
# sequence is first dimension
|
| 84 |
+
value = getattr(self, fld.name)
|
| 85 |
+
value = slice_any_object(value, idx)
|
| 86 |
+
updated_fields[fld.name] = value
|
| 87 |
+
case 1:
|
| 88 |
+
new_value = [slice_any_object(item, idx) for item in value]
|
| 89 |
+
updated_fields[fld.name] = value.__class__(new_value)
|
| 90 |
+
case _:
|
| 91 |
+
raise NotImplementedError(
|
| 92 |
+
"Arbitrary slicing for different sequence length fields is not implemented"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
return replace(self, **updated_fields)
|
| 96 |
+
|
| 97 |
+
def _check_sequence_lengths_match(self):
|
| 98 |
+
"""Checks if sequence lengths of all "sequence" fields match."""
|
| 99 |
+
for fld in fields(self):
|
| 100 |
+
if fld.metadata.get("sequence", False) and fld.name != "complex":
|
| 101 |
+
# this is a sequence, should be the same length as all other sequences
|
| 102 |
+
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
| 103 |
+
value = getattr(self, fld.name)
|
| 104 |
+
if value is None:
|
| 105 |
+
continue
|
| 106 |
+
match sequence_dim:
|
| 107 |
+
case 0:
|
| 108 |
+
# sequence is first dimension
|
| 109 |
+
value = getattr(self, fld.name)
|
| 110 |
+
if len(value) != len(self):
|
| 111 |
+
raise ValueError(
|
| 112 |
+
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
|
| 113 |
+
)
|
| 114 |
+
case 1:
|
| 115 |
+
for item in value:
|
| 116 |
+
if len(item) != len(self):
|
| 117 |
+
raise ValueError(
|
| 118 |
+
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
|
| 119 |
+
)
|
| 120 |
+
case _:
|
| 121 |
+
raise NotImplementedError(
|
| 122 |
+
"Arbitrary matching for different sequence length fields is not implemented"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
@classmethod
|
| 126 |
+
def concat(cls, items: list[T], **kwargs) -> T:
|
| 127 |
+
updated_fields = {}
|
| 128 |
+
for fld in fields(cls):
|
| 129 |
+
if fld.metadata.get("sequence", False):
|
| 130 |
+
# this is a sequence, should be the same length as all other sequences
|
| 131 |
+
sequence_dim = fld.metadata.get("sequence_dim", 0)
|
| 132 |
+
join_value = fld.metadata.get("join_token", None)
|
| 133 |
+
if getattr(items[0], fld.name) is None:
|
| 134 |
+
continue
|
| 135 |
+
values = [getattr(item, fld.name) for item in items]
|
| 136 |
+
match sequence_dim:
|
| 137 |
+
case 0:
|
| 138 |
+
# sequence is first dimension
|
| 139 |
+
value = concat_objects(values, join_value)
|
| 140 |
+
updated_fields[fld.name] = value
|
| 141 |
+
case 1:
|
| 142 |
+
new_value = [
|
| 143 |
+
concat_objects(item, join_value) for item in zip(*values)
|
| 144 |
+
]
|
| 145 |
+
updated_fields[fld.name] = getattr(
|
| 146 |
+
items[0], fld.name
|
| 147 |
+
).__class__(new_value)
|
| 148 |
+
case _:
|
| 149 |
+
raise NotImplementedError(
|
| 150 |
+
"Arbitrary joining for different sequence length fields is not implemented"
|
| 151 |
+
)
|
| 152 |
+
updated_fields.update(kwargs)
|
| 153 |
+
|
| 154 |
+
return replace(
|
| 155 |
+
items[0], # type: ignore
|
| 156 |
+
**updated_fields,
|
| 157 |
+
)
|
| 158 |
+
|
esmfold2_system.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import subprocess
|
| 3 |
+
import typing as T
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
PathLike = T.Union[str, Path]
|
| 7 |
+
PathOrBuffer = T.Union[PathLike, io.StringIO]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def run_subprocess_with_errorcheck(
|
| 11 |
+
*popenargs,
|
| 12 |
+
capture_output: bool = False,
|
| 13 |
+
quiet: bool = False,
|
| 14 |
+
env: dict[str, str] | None = None,
|
| 15 |
+
shell: bool = False,
|
| 16 |
+
executable: str | None = None,
|
| 17 |
+
**kws,
|
| 18 |
+
) -> subprocess.CompletedProcess:
|
| 19 |
+
"""A command similar to subprocess.run, however the errormessage will
|
| 20 |
+
contain the stderr when using this function. This makes it significantly
|
| 21 |
+
easier to diagnose issues.
|
| 22 |
+
"""
|
| 23 |
+
try:
|
| 24 |
+
if capture_output:
|
| 25 |
+
stdout = subprocess.PIPE
|
| 26 |
+
elif quiet:
|
| 27 |
+
stdout = subprocess.DEVNULL
|
| 28 |
+
else:
|
| 29 |
+
stdout = None
|
| 30 |
+
|
| 31 |
+
p = subprocess.run(
|
| 32 |
+
*popenargs,
|
| 33 |
+
stderr=subprocess.PIPE,
|
| 34 |
+
stdout=stdout,
|
| 35 |
+
check=True,
|
| 36 |
+
env=env,
|
| 37 |
+
shell=shell,
|
| 38 |
+
executable=executable,
|
| 39 |
+
**kws,
|
| 40 |
+
)
|
| 41 |
+
except subprocess.CalledProcessError as e:
|
| 42 |
+
raise RuntimeError(
|
| 43 |
+
f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}"
|
| 44 |
+
)
|
| 45 |
+
return p
|
| 46 |
+
|
esmfold2_types.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Re-exports of the canonical SPI dataclasses from input_builder.
|
| 2 |
+
|
| 3 |
+
This module exists so the HF processor and downstream code can import the
|
| 4 |
+
ESMFold2 input types from a single namespace without picking up internal-only
|
| 5 |
+
sibling utilities. The actual definitions live in
|
| 6 |
+
``esm.utils.structure.input_builder``.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from .esmfold2_msa import MSA
|
| 10 |
+
from .esmfold2_parsing import FastaEntry
|
| 11 |
+
from .esmfold2_input_builder import (
|
| 12 |
+
CovalentBond,
|
| 13 |
+
DistogramConditioning,
|
| 14 |
+
DNAInput,
|
| 15 |
+
LigandInput,
|
| 16 |
+
Modification,
|
| 17 |
+
ProteinInput,
|
| 18 |
+
RNAInput,
|
| 19 |
+
StructurePredictionInput,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"FastaEntry",
|
| 24 |
+
"MSA",
|
| 25 |
+
"Modification",
|
| 26 |
+
"ProteinInput",
|
| 27 |
+
"RNAInput",
|
| 28 |
+
"DNAInput",
|
| 29 |
+
"LigandInput",
|
| 30 |
+
"DistogramConditioning",
|
| 31 |
+
"CovalentBond",
|
| 32 |
+
"StructurePredictionInput",
|
| 33 |
+
]
|
| 34 |
+
|
esmfold2_utils_types.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import io
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Union
|
| 7 |
+
|
| 8 |
+
from cloudpathlib import CloudPath
|
| 9 |
+
|
| 10 |
+
PathLike = Union[str, Path, CloudPath]
|
| 11 |
+
PathOrBuffer = Union[PathLike, io.StringIO]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class FunctionAnnotation:
|
| 16 |
+
"""Represents an annotation of a protein's function over a range of residues.
|
| 17 |
+
|
| 18 |
+
Fields:
|
| 19 |
+
label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs
|
| 20 |
+
start (int): Start index of this annotation. 1-indexed, inclusive.
|
| 21 |
+
end (int): End index of this annotation. 1-indexed, inclusive.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
label: str
|
| 25 |
+
start: int
|
| 26 |
+
end: int
|
| 27 |
+
|
| 28 |
+
def to_tuple(self) -> tuple[str, int, int]:
|
| 29 |
+
return self.label, self.start, self.end
|
| 30 |
+
|
| 31 |
+
def __len__(self) -> int:
|
| 32 |
+
"""Length of the annotation."""
|
| 33 |
+
return self.end - self.start + 1
|
| 34 |
+
|
modeling_esmc.py
ADDED
|
@@ -0,0 +1,1667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""PyTorch ESMC model."""
|
| 15 |
+
|
| 16 |
+
import importlib
|
| 17 |
+
import math
|
| 18 |
+
import re
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Optional, cast
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
|
| 27 |
+
from transformers.modeling_outputs import (
|
| 28 |
+
MaskedLMOutput,
|
| 29 |
+
ModelOutput,
|
| 30 |
+
SequenceClassifierOutput,
|
| 31 |
+
TokenClassifierOutput,
|
| 32 |
+
)
|
| 33 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 34 |
+
from transformers.utils import (
|
| 35 |
+
auto_docstring,
|
| 36 |
+
can_return_tuple,
|
| 37 |
+
is_flash_attn_2_available,
|
| 38 |
+
logging,
|
| 39 |
+
)
|
| 40 |
+
from .configuration_esmc import ESMCConfig
|
| 41 |
+
from .modeling_esmc_sae import _ESMCSAELayer
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
_CONFIG_FOR_DOC = "ESMCConfig"
|
| 46 |
+
|
| 47 |
+
# Optional accelerated kernels. Pure-PyTorch fallbacks below if absent.
|
| 48 |
+
if is_flash_attn_2_available():
|
| 49 |
+
flash_attn_module = importlib.import_module("flash_attn")
|
| 50 |
+
flash_bert_padding = importlib.import_module("flash_attn.bert_padding")
|
| 51 |
+
flash_attn_varlen_qkvpacked_func = (
|
| 52 |
+
flash_attn_module.flash_attn_varlen_qkvpacked_func
|
| 53 |
+
)
|
| 54 |
+
pad_input = flash_bert_padding.pad_input
|
| 55 |
+
unpad_input = flash_bert_padding.unpad_input
|
| 56 |
+
|
| 57 |
+
_flash_attn_available = True
|
| 58 |
+
else:
|
| 59 |
+
pad_input = unpad_input = flash_attn_varlen_qkvpacked_func = None
|
| 60 |
+
_flash_attn_available = False
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
flash_rotary = importlib.import_module("flash_attn.ops.triton.rotary")
|
| 64 |
+
apply_triton_rotary = flash_rotary.apply_rotary
|
| 65 |
+
|
| 66 |
+
_flash_attn_rotary_available = torch.cuda.is_available()
|
| 67 |
+
except ImportError:
|
| 68 |
+
apply_triton_rotary = None # type: ignore[assignment]
|
| 69 |
+
_flash_attn_rotary_available = False
|
| 70 |
+
|
| 71 |
+
# Transformer Engine: fused LayerNorm+Linear / LayerNorm+MLP kernels with
|
| 72 |
+
# fp32 reduction inside the LayerNorm. Recommended on GPU for accurate bf16
|
| 73 |
+
# inference; without it the pure-PyTorch fallback drifts ~O(10) in fp32 and
|
| 74 |
+
# ~O(100) in bf16 on the unnormalized residual stream (perplexity stays
|
| 75 |
+
# within rounding noise).
|
| 76 |
+
try:
|
| 77 |
+
te = importlib.import_module("transformer_engine.pytorch")
|
| 78 |
+
|
| 79 |
+
_te_available = True
|
| 80 |
+
except ImportError:
|
| 81 |
+
te = None # type: ignore[assignment]
|
| 82 |
+
_te_available = False
|
| 83 |
+
|
| 84 |
+
# xformers: preferred SDPA implementation on GPU. Provides a fused
|
| 85 |
+
# bf16 attention kernel with deterministic reduction order. Flash
|
| 86 |
+
# Attention 2 and PyTorch's ``F.scaled_dot_product_attention`` are
|
| 87 |
+
# progressively-less-preferred fallbacks.
|
| 88 |
+
try:
|
| 89 |
+
xops = importlib.import_module("xformers.ops")
|
| 90 |
+
|
| 91 |
+
_xformers_available = True
|
| 92 |
+
except ImportError:
|
| 93 |
+
xops = None # type: ignore[assignment]
|
| 94 |
+
_xformers_available = False
|
| 95 |
+
|
| 96 |
+
# Flash Attention 2: secondary SDPA fallback. Used when xformers is not
|
| 97 |
+
# installed; fp16 / bf16 only.
|
| 98 |
+
if _flash_attn_available:
|
| 99 |
+
flash_attn_func = flash_attn_module.flash_attn_func
|
| 100 |
+
else:
|
| 101 |
+
flash_attn_func = None # type: ignore[assignment]
|
| 102 |
+
|
| 103 |
+
if not _te_available:
|
| 104 |
+
logger.warning(
|
| 105 |
+
"ESMC: transformer_engine is not installed; falling back to "
|
| 106 |
+
"pure-PyTorch LayerNorm+Linear / LayerNorm+MLP. Outputs will differ "
|
| 107 |
+
"numerically — measured on the unnormalized residual stream (before "
|
| 108 |
+
"the final LayerNorm), ~O(10) max-diff in fp32 and ~O(100) in bf16; "
|
| 109 |
+
"after the final LayerNorm these shrink to a few ULP and perplexity "
|
| 110 |
+
"stays within rounding noise. Install with "
|
| 111 |
+
"`pip install transformer-engine[pytorch]` to enable fused fp32-"
|
| 112 |
+
"reduction LayerNorm."
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if not _xformers_available and not _flash_attn_available:
|
| 116 |
+
logger.warning(
|
| 117 |
+
"ESMC: neither xformers nor flash-attn is installed; falling back "
|
| 118 |
+
"to PyTorch ``F.scaled_dot_product_attention``. The attention "
|
| 119 |
+
"reduction order in bf16 differs from a fused kernel by ~1 bf16 "
|
| 120 |
+
"ULP per attention block; compounded across the 80-block stack "
|
| 121 |
+
"this reaches ~O(100) max-diff on the unnormalized residual stream. "
|
| 122 |
+
"Install xformers (preferred) with `pip install xformers` for a "
|
| 123 |
+
"fused attention kernel."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
if torch.cuda.is_available() and not _flash_attn_rotary_available:
|
| 127 |
+
logger.warning(
|
| 128 |
+
"ESMC: flash-attn rotary kernel not installed; falling back to "
|
| 129 |
+
"pure-PyTorch RoPE. For faster GPU inference run `pip install flash-attn`."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ---------------------------------------------------------------------------
|
| 134 |
+
# Output dataclasses
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
|
| 139 |
+
class ESMCOutput(ModelOutput):
|
| 140 |
+
"""
|
| 141 |
+
Args:
|
| 142 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
|
| 143 |
+
Sequence of hidden states at the output of the last layer, after layer normalisation.
|
| 144 |
+
hidden_states (`torch.FloatTensor`, *optional*):
|
| 145 |
+
Stacked hidden states for all encoder layers.
|
| 146 |
+
Shape ``(n_layers, batch_size, sequence_length, d_model)``.
|
| 147 |
+
Returned when ``output_hidden_states=True``.
|
| 148 |
+
sae_outputs (`dict[str, torch.Tensor]`, *optional*):
|
| 149 |
+
SAE feature magnitudes keyed by SAE model name (sparse tensors).
|
| 150 |
+
Only populated when SAE models have been registered via
|
| 151 |
+
``add_sae_models`` and ``compute_sae=True``.
|
| 152 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
| 153 |
+
Per-layer attention weights of shape
|
| 154 |
+
``(batch_size, num_heads, sequence_length, sequence_length)``.
|
| 155 |
+
Returned when ``output_attentions=True``. Not available on the
|
| 156 |
+
``flash_attention_2`` path.
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
last_hidden_state: torch.FloatTensor | None = None
|
| 160 |
+
hidden_states: torch.FloatTensor | None = None
|
| 161 |
+
sae_outputs: dict[str, torch.Tensor] | None = None
|
| 162 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass
|
| 166 |
+
class ESMCMaskedLMOutput(MaskedLMOutput):
|
| 167 |
+
"""
|
| 168 |
+
Args:
|
| 169 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 170 |
+
Masked language modelling loss. Returned when ``labels`` are provided.
|
| 171 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
|
| 172 |
+
Prediction scores of the language modelling head.
|
| 173 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
|
| 174 |
+
Final hidden states after layer normalisation.
|
| 175 |
+
hidden_states (`torch.FloatTensor`, *optional*):
|
| 176 |
+
Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
|
| 177 |
+
sae_outputs (`dict[str, torch.Tensor]`, *optional*):
|
| 178 |
+
SAE feature magnitudes keyed by SAE model name (sparse tensors).
|
| 179 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
| 180 |
+
Per-layer attention weights of shape
|
| 181 |
+
``(batch_size, num_heads, sequence_length, sequence_length)``.
|
| 182 |
+
Returned when ``output_attentions=True``.
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
loss: torch.FloatTensor | None = None
|
| 186 |
+
logits: torch.FloatTensor | None = None
|
| 187 |
+
last_hidden_state: torch.FloatTensor | None = None
|
| 188 |
+
hidden_states: torch.FloatTensor | None = None
|
| 189 |
+
sae_outputs: dict[str, torch.Tensor] | None = None
|
| 190 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@dataclass
|
| 194 |
+
class ESMCTokenClassifierOutput(TokenClassifierOutput):
|
| 195 |
+
"""
|
| 196 |
+
Args:
|
| 197 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 198 |
+
Token classification loss. Returned when ``labels`` are provided.
|
| 199 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
|
| 200 |
+
Classification scores (before SoftMax).
|
| 201 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
|
| 202 |
+
Final hidden states after layer normalisation.
|
| 203 |
+
hidden_states (`torch.FloatTensor`, *optional*):
|
| 204 |
+
Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
|
| 205 |
+
sae_outputs (`dict[str, torch.Tensor]`, *optional*):
|
| 206 |
+
SAE feature magnitudes keyed by SAE model name (sparse tensors).
|
| 207 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
| 208 |
+
Per-layer attention weights of shape
|
| 209 |
+
``(batch_size, num_heads, sequence_length, sequence_length)``.
|
| 210 |
+
Returned when ``output_attentions=True``.
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
loss: torch.FloatTensor | None = None
|
| 214 |
+
logits: torch.FloatTensor | None = None
|
| 215 |
+
last_hidden_state: torch.FloatTensor | None = None
|
| 216 |
+
hidden_states: torch.FloatTensor | None = None
|
| 217 |
+
sae_outputs: dict[str, torch.Tensor] | None = None
|
| 218 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@dataclass
|
| 222 |
+
class ESMCSequenceClassifierOutput(SequenceClassifierOutput):
|
| 223 |
+
"""
|
| 224 |
+
Args:
|
| 225 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
| 226 |
+
Sequence classification loss. Returned when ``labels`` are provided.
|
| 227 |
+
logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
|
| 228 |
+
Classification scores (before SoftMax).
|
| 229 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
|
| 230 |
+
Final hidden states after layer normalisation.
|
| 231 |
+
hidden_states (`torch.FloatTensor`, *optional*):
|
| 232 |
+
Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
|
| 233 |
+
sae_outputs (`dict[str, torch.Tensor]`, *optional*):
|
| 234 |
+
SAE feature magnitudes keyed by SAE model name (sparse tensors).
|
| 235 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*):
|
| 236 |
+
Per-layer attention weights of shape
|
| 237 |
+
``(batch_size, num_heads, sequence_length, sequence_length)``.
|
| 238 |
+
Returned when ``output_attentions=True``.
|
| 239 |
+
"""
|
| 240 |
+
|
| 241 |
+
loss: torch.FloatTensor | None = None
|
| 242 |
+
logits: torch.FloatTensor | None = None
|
| 243 |
+
last_hidden_state: torch.FloatTensor | None = None
|
| 244 |
+
hidden_states: torch.FloatTensor | None = None
|
| 245 |
+
sae_outputs: dict[str, torch.Tensor] | None = None
|
| 246 |
+
attentions: tuple[torch.FloatTensor, ...] | None = None
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
# Rotary position embedding helpers
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
| 255 |
+
if not interleaved:
|
| 256 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 257 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 258 |
+
x1, x2 = x[..., ::2], x[..., 1::2]
|
| 259 |
+
return torch.stack((-x2, x1), dim=-1).flatten(-2, -1)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _apply_rotary_emb_torch(
|
| 263 |
+
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
| 264 |
+
) -> torch.Tensor:
|
| 265 |
+
"""Apply rotary position embeddings (pure PyTorch, no Triton dependency).
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
x: ``(batch, seqlen, n_heads, head_dim)``
|
| 269 |
+
cos: ``(seqlen, rotary_dim / 2)``
|
| 270 |
+
sin: ``(seqlen, rotary_dim / 2)``
|
| 271 |
+
"""
|
| 272 |
+
ro_dim = cos.shape[-1] * 2
|
| 273 |
+
seqlen = x.size(1)
|
| 274 |
+
cos = cos[:seqlen].unsqueeze(1).repeat(1, 1, 2)
|
| 275 |
+
sin = sin[:seqlen].unsqueeze(1).repeat(1, 1, 2)
|
| 276 |
+
return torch.cat(
|
| 277 |
+
[
|
| 278 |
+
x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
|
| 279 |
+
x[..., ro_dim:],
|
| 280 |
+
],
|
| 281 |
+
dim=-1,
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class RotaryEmbedding(nn.Module):
|
| 286 |
+
"""Rotary position embeddings (RoPE) as described in `RoFormer`_.
|
| 287 |
+
|
| 288 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
dim: Size of a single attention head.
|
| 292 |
+
base: Frequency base for the sinusoidal positions.
|
| 293 |
+
interleaved: If ``True`` rotate adjacent pairs (GPT-J style) instead of
|
| 294 |
+
splitting the head dimension in half (GPT-NeoX style).
|
| 295 |
+
scaling_factor: Linear scaling factor applied to position indices.
|
| 296 |
+
pos_idx_in_fp32: Compute position indices in float32 to avoid bf16
|
| 297 |
+
rounding errors at large sequence lengths.
|
| 298 |
+
"""
|
| 299 |
+
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
dim: int,
|
| 303 |
+
base: float = 10000.0,
|
| 304 |
+
interleaved: bool = False,
|
| 305 |
+
scale_base: float | None = None,
|
| 306 |
+
scaling_factor: float = 1.0,
|
| 307 |
+
pos_idx_in_fp32: bool = True,
|
| 308 |
+
device=None,
|
| 309 |
+
):
|
| 310 |
+
super().__init__()
|
| 311 |
+
self.dim = dim
|
| 312 |
+
self.base = base
|
| 313 |
+
self.interleaved = interleaved
|
| 314 |
+
self.scale_base = scale_base
|
| 315 |
+
self.scaling_factor = scaling_factor
|
| 316 |
+
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
| 317 |
+
|
| 318 |
+
self._seq_len_cached = 0
|
| 319 |
+
self._cos_cached: torch.Tensor | None = None
|
| 320 |
+
self._sin_cached: torch.Tensor | None = None
|
| 321 |
+
self._cos_k_cached: torch.Tensor | None = None
|
| 322 |
+
self._sin_k_cached: torch.Tensor | None = None
|
| 323 |
+
|
| 324 |
+
self.reset_parameters(device=device)
|
| 325 |
+
|
| 326 |
+
def reset_parameters(self, device=None):
|
| 327 |
+
inv_freq = self._compute_inv_freq(device)
|
| 328 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 329 |
+
arange = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
| 330 |
+
scale = (
|
| 331 |
+
(arange + 0.4 * self.dim) / (1.4 * self.dim)
|
| 332 |
+
if self.scale_base is not None
|
| 333 |
+
else None
|
| 334 |
+
)
|
| 335 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 336 |
+
|
| 337 |
+
def _compute_inv_freq(self, device=None) -> torch.Tensor:
|
| 338 |
+
return 1.0 / (
|
| 339 |
+
self.base
|
| 340 |
+
** (
|
| 341 |
+
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
| 342 |
+
/ self.dim
|
| 343 |
+
)
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
def _update_cos_sin_cache(self, seqlen: int, device=None, dtype=None):
|
| 347 |
+
if self.inv_freq.is_meta:
|
| 348 |
+
self.reset_parameters(device=device)
|
| 349 |
+
if (
|
| 350 |
+
seqlen > self._seq_len_cached
|
| 351 |
+
or self._cos_cached is None
|
| 352 |
+
or self._cos_cached.device != device
|
| 353 |
+
or self._cos_cached.dtype != dtype
|
| 354 |
+
or (self.training and self._cos_cached.is_inference())
|
| 355 |
+
):
|
| 356 |
+
self._seq_len_cached = seqlen
|
| 357 |
+
if self.pos_idx_in_fp32:
|
| 358 |
+
t = (
|
| 359 |
+
torch.arange(seqlen, device=device, dtype=torch.float32)
|
| 360 |
+
/ self.scaling_factor
|
| 361 |
+
)
|
| 362 |
+
inv_freq = (
|
| 363 |
+
self.inv_freq.to(torch.float32)
|
| 364 |
+
if self.inv_freq.dtype != torch.float32
|
| 365 |
+
else self.inv_freq
|
| 366 |
+
)
|
| 367 |
+
else:
|
| 368 |
+
t = (
|
| 369 |
+
torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # type: ignore[call-overload]
|
| 370 |
+
/ self.scaling_factor
|
| 371 |
+
)
|
| 372 |
+
inv_freq = self.inv_freq
|
| 373 |
+
freqs = torch.outer(t, inv_freq) # type: ignore[arg-type]
|
| 374 |
+
|
| 375 |
+
if self.scale is None:
|
| 376 |
+
self._cos_cached = torch.cos(freqs).to(dtype)
|
| 377 |
+
self._sin_cached = torch.sin(freqs).to(dtype)
|
| 378 |
+
else:
|
| 379 |
+
_scale: torch.Tensor = self.scale # type: ignore[assignment]
|
| 380 |
+
power = (
|
| 381 |
+
torch.arange(seqlen, dtype=_scale.dtype, device=_scale.device)
|
| 382 |
+
- seqlen // 2
|
| 383 |
+
) / self.scale_base # type: ignore[operator]
|
| 384 |
+
scale = _scale.to(device=power.device) ** power.unsqueeze(-1)
|
| 385 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
| 386 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
| 387 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
| 388 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
| 389 |
+
|
| 390 |
+
def _apply(self, fn, recurse=True):
|
| 391 |
+
if self.inv_freq.is_meta:
|
| 392 |
+
self.reset_parameters(device="cpu")
|
| 393 |
+
result = super()._apply(fn, recurse=recurse)
|
| 394 |
+
# Recompute inv_freq on the new device: CPU vs CUDA ``pow`` differ by
|
| 395 |
+
# ~1 fp32 ULP, which compounds across attention layers. Keep this
|
| 396 |
+
# buffer fp32 even when the module is cast to bf16/fp16; otherwise the
|
| 397 |
+
# rounded RoPE frequencies drift from the internal ESMC path.
|
| 398 |
+
new_inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
|
| 399 |
+
self.register_buffer("inv_freq", new_inv_freq, persistent=False)
|
| 400 |
+
self._seq_len_cached = 0
|
| 401 |
+
self._cos_cached = None
|
| 402 |
+
self._sin_cached = None
|
| 403 |
+
self._cos_k_cached = None
|
| 404 |
+
self._sin_k_cached = None
|
| 405 |
+
return result
|
| 406 |
+
|
| 407 |
+
def forward(
|
| 408 |
+
self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
|
| 409 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 410 |
+
"""Apply RoPE to query and key tensors.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
q: ``(batch, seqlen, n_heads, head_dim)``
|
| 414 |
+
k: ``(batch, seqlen, n_heads, head_dim)``
|
| 415 |
+
seqlen_offset: Offset used in incremental decoding.
|
| 416 |
+
|
| 417 |
+
Returns:
|
| 418 |
+
Tuple of rotated ``(q, k)`` tensors with the same shape as the inputs.
|
| 419 |
+
"""
|
| 420 |
+
self._update_cos_sin_cache(
|
| 421 |
+
q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
|
| 422 |
+
)
|
| 423 |
+
assert self._cos_cached is not None and self._sin_cached is not None
|
| 424 |
+
|
| 425 |
+
if self.scale is not None:
|
| 426 |
+
raise NotImplementedError("XPos scaling is not supported in this path.")
|
| 427 |
+
|
| 428 |
+
cos = self._cos_cached[seqlen_offset:]
|
| 429 |
+
sin = self._sin_cached[seqlen_offset:]
|
| 430 |
+
|
| 431 |
+
if _flash_attn_rotary_available and q.device.type == "cuda":
|
| 432 |
+
q_rot = apply_triton_rotary(q, cos, sin, interleaved=self.interleaved) # type: ignore[misc]
|
| 433 |
+
k_rot = apply_triton_rotary(k, cos, sin, interleaved=self.interleaved) # type: ignore[misc]
|
| 434 |
+
else:
|
| 435 |
+
q_rot = _apply_rotary_emb_torch(q, cos, sin, self.interleaved)
|
| 436 |
+
k_rot = _apply_rotary_emb_torch(k, cos, sin, self.interleaved)
|
| 437 |
+
return q_rot, k_rot
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class _TritonRotaryEmbedding(RotaryEmbedding):
|
| 441 |
+
"""RoPE variant that delegates to the Flash-Attention Triton kernel.
|
| 442 |
+
|
| 443 |
+
Only used inside :class:`_FlashMultiHeadAttention` when Flash Attention 2
|
| 444 |
+
is available. The ``forward`` signature differs from :class:`RotaryEmbedding`
|
| 445 |
+
because Flash Attention packs Q, K, V together.
|
| 446 |
+
"""
|
| 447 |
+
|
| 448 |
+
def forward(
|
| 449 |
+
self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int
|
| 450 |
+
) -> torch.Tensor: # type: ignore[override]
|
| 451 |
+
"""Apply RoPE in-place to a packed ``(N, 3, n_heads, head_dim)`` tensor."""
|
| 452 |
+
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
| 453 |
+
assert self._cos_cached is not None and self._sin_cached is not None
|
| 454 |
+
assert apply_triton_rotary is not None
|
| 455 |
+
|
| 456 |
+
apply_triton_rotary(
|
| 457 |
+
qkv[:, 0],
|
| 458 |
+
self._cos_cached,
|
| 459 |
+
self._sin_cached,
|
| 460 |
+
cu_seqlens=cu_seqlens,
|
| 461 |
+
max_seqlen=max_seqlen,
|
| 462 |
+
inplace=True,
|
| 463 |
+
)
|
| 464 |
+
apply_triton_rotary(
|
| 465 |
+
qkv[:, 1],
|
| 466 |
+
self._cos_cached,
|
| 467 |
+
self._sin_cached,
|
| 468 |
+
cu_seqlens=cu_seqlens,
|
| 469 |
+
max_seqlen=max_seqlen,
|
| 470 |
+
inplace=True,
|
| 471 |
+
)
|
| 472 |
+
return qkv
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# ---------------------------------------------------------------------------
|
| 476 |
+
# Feed-forward network helpers
|
| 477 |
+
# ---------------------------------------------------------------------------
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def _swiglu_hidden_dim(expansion_ratio: float, d_model: int) -> int:
|
| 481 |
+
"""Round hidden dim to the nearest multiple of 256 after applying expansion_ratio."""
|
| 482 |
+
return int(((expansion_ratio * d_model) + 255) // 256 * 256)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class _SwiGLU(nn.Module):
|
| 486 |
+
"""SwiGLU activation: ``silu(x1) * x2`` where ``x`` is split along the last dim."""
|
| 487 |
+
|
| 488 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 489 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 490 |
+
return F.silu(x1) * x2
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class _PyTorchLayerNormLinear(nn.Module):
|
| 494 |
+
"""LayerNorm followed by a Linear projection, sharing the parameter
|
| 495 |
+
names ``layer_norm_weight``, ``layer_norm_bias`` and ``weight`` so the
|
| 496 |
+
state-dict layout matches the accelerated TE module loaded on GPU.
|
| 497 |
+
"""
|
| 498 |
+
|
| 499 |
+
def __init__(self, d_in: int, d_out: int, eps: float = 1e-5) -> None:
|
| 500 |
+
super().__init__()
|
| 501 |
+
self.d_in = d_in
|
| 502 |
+
self.eps = eps
|
| 503 |
+
self.layer_norm_weight = nn.Parameter(torch.ones(d_in))
|
| 504 |
+
self.layer_norm_bias = nn.Parameter(torch.zeros(d_in))
|
| 505 |
+
self.weight = nn.Parameter(torch.empty(d_out, d_in))
|
| 506 |
+
nn.init.normal_(self.weight, std=0.02)
|
| 507 |
+
|
| 508 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 509 |
+
x = F.layer_norm(
|
| 510 |
+
x, (self.d_in,), self.layer_norm_weight, self.layer_norm_bias, self.eps
|
| 511 |
+
)
|
| 512 |
+
return F.linear(x, self.weight)
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
class _PyTorchLayerNormMLP(nn.Module):
|
| 516 |
+
"""LayerNorm + SwiGLU MLP, sharing the parameter names
|
| 517 |
+
``layer_norm_weight``, ``layer_norm_bias``, ``fc1_weight``,
|
| 518 |
+
``fc2_weight`` so the state-dict layout matches the accelerated TE
|
| 519 |
+
module loaded on GPU.
|
| 520 |
+
"""
|
| 521 |
+
|
| 522 |
+
def __init__(
|
| 523 |
+
self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5
|
| 524 |
+
) -> None:
|
| 525 |
+
super().__init__()
|
| 526 |
+
self.hidden_size = hidden_size
|
| 527 |
+
self.ffn_hidden_size = ffn_hidden_size
|
| 528 |
+
self.eps = eps
|
| 529 |
+
self.layer_norm_weight = nn.Parameter(torch.ones(hidden_size))
|
| 530 |
+
self.layer_norm_bias = nn.Parameter(torch.zeros(hidden_size))
|
| 531 |
+
self.fc1_weight = nn.Parameter(torch.empty(2 * ffn_hidden_size, hidden_size))
|
| 532 |
+
self.fc2_weight = nn.Parameter(torch.empty(hidden_size, ffn_hidden_size))
|
| 533 |
+
nn.init.normal_(self.fc1_weight, std=0.02)
|
| 534 |
+
nn.init.normal_(self.fc2_weight, std=0.02)
|
| 535 |
+
|
| 536 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 537 |
+
x = F.layer_norm(
|
| 538 |
+
x,
|
| 539 |
+
(self.hidden_size,),
|
| 540 |
+
self.layer_norm_weight,
|
| 541 |
+
self.layer_norm_bias,
|
| 542 |
+
self.eps,
|
| 543 |
+
)
|
| 544 |
+
x = F.linear(x, self.fc1_weight)
|
| 545 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 546 |
+
x = F.silu(x1) * x2
|
| 547 |
+
return F.linear(x, self.fc2_weight)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def _swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Module:
|
| 551 |
+
"""LayerNorm + SwiGLU MLP. Uses Transformer Engine's fused LN+MLP when
|
| 552 |
+
available; otherwise returns the pure-PyTorch fallback with matching
|
| 553 |
+
state-dict layout."""
|
| 554 |
+
assert not bias, "ESMC was trained with bias=False; bias=True not supported"
|
| 555 |
+
hidden = _swiglu_hidden_dim(expansion_ratio, d_model)
|
| 556 |
+
if _te_available:
|
| 557 |
+
return te.LayerNormMLP( # type: ignore[union-attr]
|
| 558 |
+
hidden_size=d_model,
|
| 559 |
+
ffn_hidden_size=hidden,
|
| 560 |
+
bias=bias,
|
| 561 |
+
activation="swiglu",
|
| 562 |
+
init_method=None,
|
| 563 |
+
output_layer_init_method=None,
|
| 564 |
+
)
|
| 565 |
+
return _PyTorchLayerNormMLP(hidden_size=d_model, ffn_hidden_size=hidden)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def _make_attn_layernorm_qkv(d_model: int, bias: bool) -> nn.Module:
|
| 569 |
+
"""LayerNorm + fused QKV projection. Uses Transformer Engine when
|
| 570 |
+
available; pure-PyTorch fallback otherwise."""
|
| 571 |
+
assert not bias, "ESMC was trained with bias=False; bias=True not supported"
|
| 572 |
+
if _te_available:
|
| 573 |
+
return te.LayerNormLinear( # type: ignore[union-attr]
|
| 574 |
+
d_model, d_model * 3, bias=bias, init_method=None
|
| 575 |
+
)
|
| 576 |
+
return _PyTorchLayerNormLinear(d_model, d_model * 3)
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def _make_attn_out_proj(d_model: int, bias: bool) -> nn.Module:
|
| 580 |
+
"""Attention output projection. Uses Transformer Engine when available;
|
| 581 |
+
pure-PyTorch ``nn.Linear`` otherwise."""
|
| 582 |
+
if _te_available:
|
| 583 |
+
return te.Linear( # type: ignore[union-attr]
|
| 584 |
+
d_model, d_model, bias=bias, init_method=None
|
| 585 |
+
)
|
| 586 |
+
return nn.Linear(d_model, d_model, bias=bias)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def _gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Sequential:
|
| 590 |
+
hidden = int(expansion_ratio * d_model)
|
| 591 |
+
return nn.Sequential(
|
| 592 |
+
nn.LayerNorm(d_model),
|
| 593 |
+
nn.Linear(d_model, hidden, bias=bias),
|
| 594 |
+
nn.GELU(),
|
| 595 |
+
nn.Linear(hidden, d_model, bias=bias),
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
# ---------------------------------------------------------------------------
|
| 600 |
+
# Attention
|
| 601 |
+
# ---------------------------------------------------------------------------
|
| 602 |
+
|
| 603 |
+
|
| 604 |
+
def _scaled_dot_product_attention(
|
| 605 |
+
q: torch.Tensor,
|
| 606 |
+
k: torch.Tensor,
|
| 607 |
+
v: torch.Tensor,
|
| 608 |
+
*,
|
| 609 |
+
n_heads: int,
|
| 610 |
+
d_head: int,
|
| 611 |
+
seq_id: torch.Tensor | None,
|
| 612 |
+
) -> torch.Tensor:
|
| 613 |
+
"""Scaled dot-product attention with optional chain-aware mask.
|
| 614 |
+
|
| 615 |
+
Dispatches in order of preference:
|
| 616 |
+
1. xformers ``memory_efficient_attention`` — preferred fused kernel,
|
| 617 |
+
requires ``xformers``, no chain mask.
|
| 618 |
+
2. Flash Attention 2 (``flash_attn.flash_attn_func``) — secondary
|
| 619 |
+
fused kernel, requires ``flash-attn``, no chain mask, fp16 /
|
| 620 |
+
bf16 only.
|
| 621 |
+
3. PyTorch's ``F.scaled_dot_product_attention`` — last-resort path;
|
| 622 |
+
also handles the chain-aware mask when ``seq_id`` is present
|
| 623 |
+
and the fp32 path that Flash Attention 2 does not support.
|
| 624 |
+
"""
|
| 625 |
+
if seq_id is None and _xformers_available:
|
| 626 |
+
b, s, _ = q.shape
|
| 627 |
+
q4 = q.view(b, s, n_heads, d_head)
|
| 628 |
+
k4 = k.view(b, s, n_heads, d_head)
|
| 629 |
+
v4 = v.view(b, s, n_heads, d_head)
|
| 630 |
+
context = xops.memory_efficient_attention( # type: ignore[union-attr]
|
| 631 |
+
q4, k4, v4, attn_bias=None, scale=d_head**-0.5
|
| 632 |
+
)
|
| 633 |
+
return context.reshape(b, s, n_heads * d_head)
|
| 634 |
+
if (
|
| 635 |
+
seq_id is None
|
| 636 |
+
and _flash_attn_available
|
| 637 |
+
and q.dtype in (torch.float16, torch.bfloat16)
|
| 638 |
+
):
|
| 639 |
+
b, s, _ = q.shape
|
| 640 |
+
q4 = q.view(b, s, n_heads, d_head)
|
| 641 |
+
k4 = k.view(b, s, n_heads, d_head)
|
| 642 |
+
v4 = v.view(b, s, n_heads, d_head)
|
| 643 |
+
context = flash_attn_func( # type: ignore[misc]
|
| 644 |
+
q4, k4, v4, dropout_p=0.0, softmax_scale=d_head**-0.5
|
| 645 |
+
)
|
| 646 |
+
return context.reshape(b, s, n_heads * d_head) # type: ignore[union-attr]
|
| 647 |
+
b, s, _ = q.shape
|
| 648 |
+
q = q.view(b, s, n_heads, -1).transpose(1, 2)
|
| 649 |
+
k = k.view(b, s, n_heads, -1).transpose(1, 2)
|
| 650 |
+
v = v.view(b, s, n_heads, -1).transpose(1, 2)
|
| 651 |
+
if seq_id is not None:
|
| 652 |
+
mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)
|
| 653 |
+
context = F.scaled_dot_product_attention(q, k, v, mask)
|
| 654 |
+
else:
|
| 655 |
+
context = F.scaled_dot_product_attention(q, k, v)
|
| 656 |
+
_, h, _, d_out = context.shape
|
| 657 |
+
return context.transpose(1, 2).reshape(b, s, h * d_out)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class MultiHeadAttention(nn.Module):
|
| 661 |
+
"""Multi-head self-attention with QK LayerNorm and RoPE.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
d_model: Model hidden dimension.
|
| 665 |
+
n_heads: Number of attention heads.
|
| 666 |
+
bias: Whether to use bias in linear layers.
|
| 667 |
+
qk_layernorm: Whether to apply LayerNorm to queries and keys before
|
| 668 |
+
computing attention scores.
|
| 669 |
+
"""
|
| 670 |
+
|
| 671 |
+
def __init__(
|
| 672 |
+
self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
|
| 673 |
+
):
|
| 674 |
+
super().__init__()
|
| 675 |
+
self.d_model = d_model
|
| 676 |
+
self.n_heads = n_heads
|
| 677 |
+
self.d_head = d_model // n_heads
|
| 678 |
+
|
| 679 |
+
assert not bias, "ESMC was trained with bias=False; bias=True not supported"
|
| 680 |
+
self.layernorm_qkv = _make_attn_layernorm_qkv(d_model, bias)
|
| 681 |
+
self.out_proj = _make_attn_out_proj(d_model, bias)
|
| 682 |
+
|
| 683 |
+
if qk_layernorm:
|
| 684 |
+
self.q_ln = nn.LayerNorm(d_model, bias=bias)
|
| 685 |
+
self.k_ln = nn.LayerNorm(d_model, bias=bias)
|
| 686 |
+
else:
|
| 687 |
+
self.q_ln = nn.Identity()
|
| 688 |
+
self.k_ln = nn.Identity()
|
| 689 |
+
|
| 690 |
+
self.rotary = RotaryEmbedding(d_model // n_heads)
|
| 691 |
+
|
| 692 |
+
def _apply_rotary(
|
| 693 |
+
self, q: torch.Tensor, k: torch.Tensor
|
| 694 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 695 |
+
q = q.unflatten(-1, (self.n_heads, self.d_head))
|
| 696 |
+
k = k.unflatten(-1, (self.n_heads, self.d_head))
|
| 697 |
+
q, k = self.rotary(q, k)
|
| 698 |
+
q = q.flatten(-2, -1)
|
| 699 |
+
k = k.flatten(-2, -1)
|
| 700 |
+
return q, k
|
| 701 |
+
|
| 702 |
+
def forward(
|
| 703 |
+
self,
|
| 704 |
+
x: torch.Tensor,
|
| 705 |
+
seq_id: torch.Tensor | None,
|
| 706 |
+
output_attentions: bool = False,
|
| 707 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 708 |
+
"""Return ``(context, attn_weights)``.
|
| 709 |
+
|
| 710 |
+
``attn_weights`` is ``None`` unless ``output_attentions=True`` — the
|
| 711 |
+
fused SDPA backends (xformers, flash-attn 2, ``F.scaled_dot_product_attention``)
|
| 712 |
+
don't expose attention probabilities, so capturing them forces a
|
| 713 |
+
materialized ``softmax(Q @ K.T / sqrt(d)) @ V`` path with shape
|
| 714 |
+
``(B, H, L, L)``.
|
| 715 |
+
"""
|
| 716 |
+
qkv = self.layernorm_qkv(x)
|
| 717 |
+
q, k, v = torch.chunk(qkv, 3, dim=-1)
|
| 718 |
+
q = self.q_ln(q).to(q.dtype)
|
| 719 |
+
k = self.k_ln(k).to(q.dtype)
|
| 720 |
+
q, k = self._apply_rotary(q, k)
|
| 721 |
+
|
| 722 |
+
b, s, _ = q.shape
|
| 723 |
+
|
| 724 |
+
if output_attentions:
|
| 725 |
+
# Manual SDPA so attention probabilities are observable.
|
| 726 |
+
q4 = q.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
|
| 727 |
+
k4 = k.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
|
| 728 |
+
v4 = v.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
|
| 729 |
+
scale = self.d_head**-0.5
|
| 730 |
+
attn_scores = (q4 @ k4.transpose(-2, -1)) * scale
|
| 731 |
+
if seq_id is not None:
|
| 732 |
+
mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)
|
| 733 |
+
attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
|
| 734 |
+
attn_weights = torch.softmax(attn_scores, dim=-1)
|
| 735 |
+
context = (attn_weights @ v4).transpose(1, 2).reshape(b, s, -1)
|
| 736 |
+
return self.out_proj(context), attn_weights
|
| 737 |
+
|
| 738 |
+
context = _scaled_dot_product_attention(
|
| 739 |
+
q, k, v, n_heads=self.n_heads, d_head=self.d_head, seq_id=seq_id
|
| 740 |
+
)
|
| 741 |
+
return self.out_proj(context), None
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
class _FlashMultiHeadAttention(MultiHeadAttention):
|
| 745 |
+
"""Flash-Attention 2 variant of :class:`MultiHeadAttention`."""
|
| 746 |
+
|
| 747 |
+
def __init__(
|
| 748 |
+
self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
|
| 749 |
+
):
|
| 750 |
+
super().__init__(
|
| 751 |
+
d_model=d_model, n_heads=n_heads, bias=bias, qk_layernorm=qk_layernorm
|
| 752 |
+
)
|
| 753 |
+
self.rotary = _TritonRotaryEmbedding(d_model // n_heads)
|
| 754 |
+
|
| 755 |
+
def forward(
|
| 756 |
+
self,
|
| 757 |
+
x: torch.Tensor,
|
| 758 |
+
seq_id: torch.Tensor | None,
|
| 759 |
+
output_attentions: bool = False,
|
| 760 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 761 |
+
if output_attentions:
|
| 762 |
+
raise ValueError(
|
| 763 |
+
"output_attentions=True is not supported with "
|
| 764 |
+
"attn_implementation='flash_attention_2'. "
|
| 765 |
+
"Re-load the model with attn_implementation='sdpa' (or 'eager')."
|
| 766 |
+
)
|
| 767 |
+
assert seq_id is not None and seq_id.dtype == torch.bool
|
| 768 |
+
|
| 769 |
+
seqlens = seq_id.sum(dim=-1, dtype=torch.int32)
|
| 770 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
|
| 771 |
+
max_seqlen = int(seqlens.max().item())
|
| 772 |
+
|
| 773 |
+
qkv = self.layernorm_qkv(x)
|
| 774 |
+
q, k, v = torch.chunk(qkv, 3, dim=-1)
|
| 775 |
+
q = self.q_ln(q).to(q.dtype)
|
| 776 |
+
k = self.k_ln(k).to(q.dtype)
|
| 777 |
+
|
| 778 |
+
# ``q``/``k``/``v`` are 2D ``(T, D)`` here: the parent ``ESMCModel.forward``
|
| 779 |
+
# calls ``unpad_input`` before the transformer stack to produce the
|
| 780 |
+
# varlen-flat layout that ``flash_attn_varlen_qkvpacked_func`` requires.
|
| 781 |
+
T = q.shape[0]
|
| 782 |
+
qkv_packed = torch.stack([q, k, v], dim=1).view(T, 3, self.n_heads, self.d_head)
|
| 783 |
+
qkv_packed = self.rotary(qkv_packed, cu_seqlens, max_seqlen)
|
| 784 |
+
|
| 785 |
+
context = flash_attn_varlen_qkvpacked_func( # type: ignore[misc]
|
| 786 |
+
qkv_packed, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
|
| 787 |
+
)
|
| 788 |
+
n_out, h_out, d_out = context.shape # type: ignore[union-attr]
|
| 789 |
+
return (
|
| 790 |
+
self.out_proj(context.reshape(n_out, h_out * d_out)), # type: ignore[union-attr]
|
| 791 |
+
None,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
# ---------------------------------------------------------------------------
|
| 796 |
+
# Transformer blocks
|
| 797 |
+
# ---------------------------------------------------------------------------
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
class UnifiedTransformerBlock(nn.Module):
|
| 801 |
+
"""Single transformer block: pre-norm attention + pre-norm FFN with residual scaling.
|
| 802 |
+
|
| 803 |
+
Args:
|
| 804 |
+
d_model: Hidden dimension.
|
| 805 |
+
n_heads: Number of attention heads.
|
| 806 |
+
use_flash_attn: Use Flash Attention 2 kernel if available.
|
| 807 |
+
bias: Whether linear layers include bias terms.
|
| 808 |
+
expansion_ratio: Hidden-dim expansion ratio for the FFN.
|
| 809 |
+
residue_scaling_factor: Scales residual connections to stabilise deep
|
| 810 |
+
networks (``1 / sqrt(n_layers / 36)`` is the ESM3 scheme).
|
| 811 |
+
qk_layernorm: Whether to apply QK LayerNorm in attention.
|
| 812 |
+
ffn_type: Feed-forward activation: ``"swiglu"`` or ``"gelu"``.
|
| 813 |
+
"""
|
| 814 |
+
|
| 815 |
+
def __init__(
|
| 816 |
+
self,
|
| 817 |
+
d_model: int,
|
| 818 |
+
n_heads: int,
|
| 819 |
+
use_flash_attn: bool = False,
|
| 820 |
+
bias: bool = False,
|
| 821 |
+
expansion_ratio: float = 4.0,
|
| 822 |
+
residue_scaling_factor: float = 1.0,
|
| 823 |
+
qk_layernorm: bool = True,
|
| 824 |
+
ffn_type: str = "swiglu",
|
| 825 |
+
):
|
| 826 |
+
super().__init__()
|
| 827 |
+
|
| 828 |
+
attn_cls = _FlashMultiHeadAttention if use_flash_attn else MultiHeadAttention
|
| 829 |
+
self.attn = attn_cls(d_model, n_heads, bias=bias, qk_layernorm=qk_layernorm)
|
| 830 |
+
|
| 831 |
+
if ffn_type == "swiglu":
|
| 832 |
+
self.ffn = _swiglu_ln_ffn(d_model, expansion_ratio, bias)
|
| 833 |
+
elif ffn_type == "gelu":
|
| 834 |
+
self.ffn = _gelu_ln_ffn(d_model, expansion_ratio, bias)
|
| 835 |
+
else:
|
| 836 |
+
raise ValueError(
|
| 837 |
+
f"Unknown ffn_type: {ffn_type!r}. Choose 'swiglu' or 'gelu'."
|
| 838 |
+
)
|
| 839 |
+
|
| 840 |
+
self.scaling_factor = residue_scaling_factor
|
| 841 |
+
|
| 842 |
+
def forward(
|
| 843 |
+
self,
|
| 844 |
+
x: torch.Tensor,
|
| 845 |
+
sequence_id: torch.Tensor | None,
|
| 846 |
+
output_attentions: bool = False,
|
| 847 |
+
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
| 848 |
+
"""
|
| 849 |
+
Args:
|
| 850 |
+
x: ``(batch, seq_len, d_model)``
|
| 851 |
+
sequence_id: ``(batch, seq_len)`` chain-ID tensor used to restrict
|
| 852 |
+
attention to tokens within the same chain. SDPA blocks accept
|
| 853 |
+
an integer tensor (``-1`` marks padding); the flash-attn block
|
| 854 |
+
takes a ``bool`` padding mask — the caller selects which.
|
| 855 |
+
``None`` skips chain-aware masking entirely (fast path).
|
| 856 |
+
output_attentions: When ``True``, returns the per-head attention
|
| 857 |
+
weights for this block alongside the residual output.
|
| 858 |
+
|
| 859 |
+
Returns:
|
| 860 |
+
``(output, attn_weights_or_None)``. Shape of ``output`` is
|
| 861 |
+
``(batch, seq_len, d_model)``; ``attn_weights`` shape is
|
| 862 |
+
``(batch, num_heads, seq_len, seq_len)`` or ``None``.
|
| 863 |
+
"""
|
| 864 |
+
attn_out, attn_weights = self.attn(
|
| 865 |
+
x, sequence_id, output_attentions=output_attentions
|
| 866 |
+
)
|
| 867 |
+
x = x + attn_out / self.scaling_factor
|
| 868 |
+
x = x + self.ffn(x) / self.scaling_factor
|
| 869 |
+
return x, attn_weights
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
class TransformerStack(nn.Module):
|
| 873 |
+
"""Stack of :class:`UnifiedTransformerBlock` layers with a final LayerNorm.
|
| 874 |
+
|
| 875 |
+
Args:
|
| 876 |
+
d_model: Hidden dimension.
|
| 877 |
+
n_heads: Number of attention heads.
|
| 878 |
+
n_layers: Number of transformer blocks.
|
| 879 |
+
scale_residue: When ``True`` apply ESM3 residue scaling
|
| 880 |
+
``sqrt(n_layers / 36)`` to each block.
|
| 881 |
+
bias: Bias flag forwarded to every sub-module.
|
| 882 |
+
qk_layernorm: QK LayerNorm flag forwarded to every block.
|
| 883 |
+
ffn_type: FFN activation type (``"swiglu"`` or ``"gelu"``).
|
| 884 |
+
expansion_ratio: FFN expansion ratio.
|
| 885 |
+
use_flash_attn: Use Flash Attention 2 kernel when available.
|
| 886 |
+
"""
|
| 887 |
+
|
| 888 |
+
def __init__(
|
| 889 |
+
self,
|
| 890 |
+
d_model: int,
|
| 891 |
+
n_heads: int,
|
| 892 |
+
n_layers: int,
|
| 893 |
+
scale_residue: bool = True,
|
| 894 |
+
bias: bool = False,
|
| 895 |
+
qk_layernorm: bool = True,
|
| 896 |
+
ffn_type: str = "swiglu",
|
| 897 |
+
expansion_ratio: float = 8 / 3,
|
| 898 |
+
use_flash_attn: bool = False,
|
| 899 |
+
):
|
| 900 |
+
super().__init__()
|
| 901 |
+
self.blocks = nn.ModuleList(
|
| 902 |
+
[
|
| 903 |
+
UnifiedTransformerBlock(
|
| 904 |
+
d_model,
|
| 905 |
+
n_heads,
|
| 906 |
+
use_flash_attn=use_flash_attn,
|
| 907 |
+
residue_scaling_factor=math.sqrt(n_layers / 36)
|
| 908 |
+
if scale_residue
|
| 909 |
+
else 1.0,
|
| 910 |
+
expansion_ratio=expansion_ratio,
|
| 911 |
+
bias=bias,
|
| 912 |
+
qk_layernorm=qk_layernorm,
|
| 913 |
+
ffn_type=ffn_type,
|
| 914 |
+
)
|
| 915 |
+
for _ in range(n_layers)
|
| 916 |
+
]
|
| 917 |
+
)
|
| 918 |
+
self.norm = nn.LayerNorm(d_model, bias=False)
|
| 919 |
+
|
| 920 |
+
def forward(
|
| 921 |
+
self,
|
| 922 |
+
x: torch.Tensor,
|
| 923 |
+
sequence_id: torch.Tensor | None = None,
|
| 924 |
+
layers_to_collect: list[int] | None = None,
|
| 925 |
+
output_attentions: bool = False,
|
| 926 |
+
) -> tuple[
|
| 927 |
+
torch.Tensor,
|
| 928 |
+
torch.Tensor,
|
| 929 |
+
tuple[torch.Tensor, ...],
|
| 930 |
+
tuple[torch.Tensor, ...] | None,
|
| 931 |
+
]:
|
| 932 |
+
"""Run the full transformer stack.
|
| 933 |
+
|
| 934 |
+
Args:
|
| 935 |
+
x: ``(batch, seq_len, d_model)``
|
| 936 |
+
sequence_id: Optional chain-id tensor forwarded to each block.
|
| 937 |
+
layers_to_collect: Layer indices (0-based pre-block inputs plus
|
| 938 |
+
``n_layers`` for the post-norm output) whose hidden states
|
| 939 |
+
should be returned.
|
| 940 |
+
output_attentions: When ``True``, collects the per-block attention
|
| 941 |
+
weights and returns them as the fourth tuple element.
|
| 942 |
+
|
| 943 |
+
Returns:
|
| 944 |
+
``(post_norm, pre_norm, hidden_states, attentions)`` where
|
| 945 |
+
``hidden_states`` is a (possibly empty) tuple of tensors and
|
| 946 |
+
``attentions`` is a tuple of per-block ``(B, H, L, L)`` tensors
|
| 947 |
+
or ``None`` when ``output_attentions`` is ``False``.
|
| 948 |
+
"""
|
| 949 |
+
if layers_to_collect is None:
|
| 950 |
+
layers_to_collect = []
|
| 951 |
+
|
| 952 |
+
collected: list[torch.Tensor] = []
|
| 953 |
+
all_attentions: list[torch.Tensor] = []
|
| 954 |
+
for layer_idx, block in enumerate(self.blocks):
|
| 955 |
+
if layer_idx in layers_to_collect:
|
| 956 |
+
collected.append(x)
|
| 957 |
+
x, attn_weights = block(x, sequence_id, output_attentions=output_attentions)
|
| 958 |
+
if output_attentions and attn_weights is not None:
|
| 959 |
+
all_attentions.append(attn_weights)
|
| 960 |
+
|
| 961 |
+
norm_x = self.norm(x)
|
| 962 |
+
if len(self.blocks) in layers_to_collect:
|
| 963 |
+
collected.append(norm_x)
|
| 964 |
+
|
| 965 |
+
attentions = tuple(all_attentions) if output_attentions else None
|
| 966 |
+
return norm_x, x, tuple(collected), attentions
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
# ---------------------------------------------------------------------------
|
| 970 |
+
# Pre-trained model base class
|
| 971 |
+
# ---------------------------------------------------------------------------
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
@auto_docstring
|
| 975 |
+
class ESMCPreTrainedModel(PreTrainedModel):
|
| 976 |
+
"""Base class for ESMC models.
|
| 977 |
+
|
| 978 |
+
Handles weight initialisation and declares module-level capabilities.
|
| 979 |
+
"""
|
| 980 |
+
|
| 981 |
+
config_class = ESMCConfig
|
| 982 |
+
base_model_prefix = "esmc"
|
| 983 |
+
supports_gradient_checkpointing = False
|
| 984 |
+
_supports_sdpa = True
|
| 985 |
+
_supports_flash_attn = True
|
| 986 |
+
_supports_attention_backend = True
|
| 987 |
+
_no_split_modules = ["UnifiedTransformerBlock"]
|
| 988 |
+
_keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
|
| 989 |
+
|
| 990 |
+
def _init_weights(self, module: nn.Module):
|
| 991 |
+
std = self.config.initializer_range
|
| 992 |
+
if isinstance(module, nn.Linear):
|
| 993 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 994 |
+
if module.bias is not None:
|
| 995 |
+
module.bias.data.zero_()
|
| 996 |
+
elif isinstance(module, RotaryEmbedding):
|
| 997 |
+
module.reset_parameters(device=self.device)
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
# ---------------------------------------------------------------------------
|
| 1001 |
+
# Base encoder model
|
| 1002 |
+
# ---------------------------------------------------------------------------
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
@auto_docstring
|
| 1006 |
+
class ESMCModel(ESMCPreTrainedModel):
|
| 1007 |
+
"""The bare ESMC encoder outputting raw hidden states.
|
| 1008 |
+
|
| 1009 |
+
ESMC is a protein language model trained by EvolutionaryScale using a
|
| 1010 |
+
masked-token objective over amino acid sequences. The architecture is a
|
| 1011 |
+
standard Transformer encoder with RoPE positional embeddings, QK LayerNorm,
|
| 1012 |
+
and SwiGLU feed-forward networks.
|
| 1013 |
+
|
| 1014 |
+
Args:
|
| 1015 |
+
config: An :class:`ESMCConfig` instance.
|
| 1016 |
+
"""
|
| 1017 |
+
|
| 1018 |
+
def __init__(self, config: ESMCConfig):
|
| 1019 |
+
super().__init__(config)
|
| 1020 |
+
self._use_flash_attn = (
|
| 1021 |
+
_flash_attn_available and config._attn_implementation == "flash_attention_2"
|
| 1022 |
+
)
|
| 1023 |
+
self.embed = nn.Embedding(config.vocab_size, config.d_model)
|
| 1024 |
+
self.transformer = TransformerStack(
|
| 1025 |
+
config.d_model,
|
| 1026 |
+
config.n_heads,
|
| 1027 |
+
config.n_layers,
|
| 1028 |
+
use_flash_attn=self._use_flash_attn,
|
| 1029 |
+
)
|
| 1030 |
+
self._sae_models: nn.ModuleDict = nn.ModuleDict()
|
| 1031 |
+
self.post_init()
|
| 1032 |
+
|
| 1033 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 1034 |
+
return self.embed
|
| 1035 |
+
|
| 1036 |
+
def set_input_embeddings(self, value: nn.Embedding):
|
| 1037 |
+
self.embed = value
|
| 1038 |
+
|
| 1039 |
+
def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
|
| 1040 |
+
"""Register one or more SAEs obtained from an :class:`ESMCSAEModel`.
|
| 1041 |
+
|
| 1042 |
+
Each is keyed by ``f"layer{N}"`` (the backbone-layer index ``N`` the
|
| 1043 |
+
SAE is trained against, set by
|
| 1044 |
+
:meth:`ESMCSAEModel.initialize_layers`). Attaching two SAEs for the
|
| 1045 |
+
same backbone layer raises — only one SAE per layer can be active.
|
| 1046 |
+
|
| 1047 |
+
Example::
|
| 1048 |
+
|
| 1049 |
+
sae = ESMCSAEModel.from_pretrained(
|
| 1050 |
+
"biohub/esmc-600m-2024-12-sae-k64-codebook16384"
|
| 1051 |
+
)
|
| 1052 |
+
sae.initialize_layers([27, 33])
|
| 1053 |
+
model.add_sae_models([sae.layers["27"], sae.layers["33"]])
|
| 1054 |
+
"""
|
| 1055 |
+
for layer in sae_models:
|
| 1056 |
+
assert isinstance(layer, _ESMCSAELayer), (
|
| 1057 |
+
f"Expected an SAE layer (model.layers['<idx>']), got "
|
| 1058 |
+
f"{type(layer).__name__}."
|
| 1059 |
+
)
|
| 1060 |
+
key = f"layer{int(layer.layer)}"
|
| 1061 |
+
if key in self._sae_models:
|
| 1062 |
+
raise ValueError(
|
| 1063 |
+
f"An SAE is already registered at {key!r}. Only one SAE "
|
| 1064 |
+
"per backbone layer can be active — pick a different "
|
| 1065 |
+
"layer on one of them, or attach in a fresh model."
|
| 1066 |
+
)
|
| 1067 |
+
self._sae_models[key] = layer
|
| 1068 |
+
|
| 1069 |
+
_SAE_KEY_RE = re.compile(r"layer(\d+)")
|
| 1070 |
+
|
| 1071 |
+
def _get_sae_layer_num_requested(self, model_name: str) -> int:
|
| 1072 |
+
"""Recover the backbone-layer index from a key written by
|
| 1073 |
+
:meth:`add_sae_models` (``"layer{N}"`` → ``N``)."""
|
| 1074 |
+
match = self._SAE_KEY_RE.fullmatch(model_name)
|
| 1075 |
+
assert (
|
| 1076 |
+
match is not None
|
| 1077 |
+
), f"Unexpected SAE key {model_name!r}; expected 'layer{{N}}'."
|
| 1078 |
+
return int(match.group(1))
|
| 1079 |
+
|
| 1080 |
+
def _validate_sae_inputs(self, input_ids: torch.Tensor) -> None:
|
| 1081 |
+
assert torch.all(input_ids != self.config.mask_token_id), (
|
| 1082 |
+
"SAE inputs must not contain mask tokens. "
|
| 1083 |
+
"SAEs were trained on unmasked sequences."
|
| 1084 |
+
)
|
| 1085 |
+
|
| 1086 |
+
def _get_sae_outputs(
|
| 1087 |
+
self,
|
| 1088 |
+
hidden_states: torch.Tensor,
|
| 1089 |
+
layers_to_collect: list[int],
|
| 1090 |
+
token_mask: torch.Tensor,
|
| 1091 |
+
normalize_sae: bool = False,
|
| 1092 |
+
) -> dict[str, torch.Tensor]:
|
| 1093 |
+
"""Run all registered SAEs and return their feature magnitudes.
|
| 1094 |
+
|
| 1095 |
+
Args:
|
| 1096 |
+
hidden_states: Stacked tensor of shape
|
| 1097 |
+
``(len(layers_to_collect), batch, seq_len, d_model)``.
|
| 1098 |
+
layers_to_collect: The ESMC layer indices that were collected,
|
| 1099 |
+
in the same order as the first dim of ``hidden_states``.
|
| 1100 |
+
token_mask: Boolean mask ``(batch, seq_len)`` — ``True`` for
|
| 1101 |
+
real (non-padding) tokens.
|
| 1102 |
+
normalize_sae: When ``True``, scale features by ``idf / max``
|
| 1103 |
+
using the per-feature stats trained alongside each SAE.
|
| 1104 |
+
"""
|
| 1105 |
+
layer_to_idx = {layer: idx for idx, layer in enumerate(layers_to_collect)}
|
| 1106 |
+
sae_outputs: dict[str, torch.Tensor] = {}
|
| 1107 |
+
|
| 1108 |
+
for model_name, sae_module in self._sae_models.items():
|
| 1109 |
+
# `nn.ModuleDict` only stores `nn.Module`s at the type level;
|
| 1110 |
+
# ``add_sae_models`` enforces that each entry is an ``_ESMCSAELayer``.
|
| 1111 |
+
assert isinstance(sae_module, _ESMCSAELayer)
|
| 1112 |
+
layer: _ESMCSAELayer = sae_module
|
| 1113 |
+
requested_layer = self._get_sae_layer_num_requested(model_name)
|
| 1114 |
+
layer_idx = layer_to_idx[requested_layer]
|
| 1115 |
+
layer_states = hidden_states[layer_idx].clone().to(self.device)
|
| 1116 |
+
|
| 1117 |
+
sae_out = layer.get_sae_output(layer_states, token_mask)
|
| 1118 |
+
features = sae_out.feature_magnitudes.detach()
|
| 1119 |
+
|
| 1120 |
+
if normalize_sae:
|
| 1121 |
+
# ``register_buffer`` is typed as ``Tensor | Module`` on
|
| 1122 |
+
# ``nn.Module``; narrow here since these are Tensors.
|
| 1123 |
+
idf = cast(torch.Tensor, layer.idf)
|
| 1124 |
+
max_val = cast(torch.Tensor, layer.max)
|
| 1125 |
+
features = (features / max_val) * idf
|
| 1126 |
+
|
| 1127 |
+
sae_outputs[model_name] = features.to_sparse()
|
| 1128 |
+
|
| 1129 |
+
return sae_outputs
|
| 1130 |
+
|
| 1131 |
+
@can_return_tuple
|
| 1132 |
+
@auto_docstring
|
| 1133 |
+
def forward(
|
| 1134 |
+
self,
|
| 1135 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1136 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1137 |
+
sequence_id: Optional[torch.Tensor] = None,
|
| 1138 |
+
output_hidden_states: Optional[bool] = None,
|
| 1139 |
+
output_attentions: Optional[bool] = None,
|
| 1140 |
+
return_dict: Optional[bool] = None,
|
| 1141 |
+
compute_sae: bool = True,
|
| 1142 |
+
normalize_sae: bool = False,
|
| 1143 |
+
) -> tuple[torch.Tensor, ...] | ESMCOutput:
|
| 1144 |
+
r"""
|
| 1145 |
+
sequence_id (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1146 |
+
Integer chain-ID tensor for chain-aware attention masking. Tokens with the same
|
| 1147 |
+
non-negative integer value can attend to each other; tokens with different values
|
| 1148 |
+
cannot (cross-chain masking). Padding positions should be set to ``-1``.
|
| 1149 |
+
When provided, ``attention_mask`` is ignored. The ``flash_attention_2`` backend
|
| 1150 |
+
only supports single-chain inputs (all non-padding values must be ``0``); pass
|
| 1151 |
+
multi-chain ``sequence_id`` with ``attn_implementation='sdpa'`` (or ``'eager'``).
|
| 1152 |
+
output_attentions (`bool`, *optional*):
|
| 1153 |
+
Whether to return the per-block attention weights of shape
|
| 1154 |
+
``(batch_size, num_heads, sequence_length, sequence_length)``.
|
| 1155 |
+
Forces a manual-SDPA path inside :class:`MultiHeadAttention` so the
|
| 1156 |
+
attention probabilities are observable; raises on the
|
| 1157 |
+
``flash_attention_2`` path.
|
| 1158 |
+
compute_sae (`bool`, *optional*, defaults to ``True``):
|
| 1159 |
+
Whether to run any SAE models registered via :meth:`add_sae_models`.
|
| 1160 |
+
Has no effect when no SAEs are registered.
|
| 1161 |
+
normalize_sae (`bool`, *optional*, defaults to ``False``):
|
| 1162 |
+
When ``True``, scale SAE feature magnitudes by ``idf / max`` (only
|
| 1163 |
+
applied when the SAE's normalization buffers contain non-trivial values).
|
| 1164 |
+
|
| 1165 |
+
Examples:
|
| 1166 |
+
|
| 1167 |
+
```python
|
| 1168 |
+
>>> from transformers import AutoTokenizer, ESMCModel
|
| 1169 |
+
|
| 1170 |
+
>>> model = ESMCModel.from_pretrained("Biohub/ESMC-600M-2024-12")
|
| 1171 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-600M-2024-12")
|
| 1172 |
+
>>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt")
|
| 1173 |
+
>>> outputs = model(**inputs)
|
| 1174 |
+
>>> outputs.last_hidden_state.shape
|
| 1175 |
+
torch.Size([1, 12, 960])
|
| 1176 |
+
```
|
| 1177 |
+
"""
|
| 1178 |
+
output_hidden_states = (
|
| 1179 |
+
output_hidden_states
|
| 1180 |
+
if output_hidden_states is not None
|
| 1181 |
+
else self.config.output_hidden_states
|
| 1182 |
+
)
|
| 1183 |
+
output_attentions = (
|
| 1184 |
+
output_attentions
|
| 1185 |
+
if output_attentions is not None
|
| 1186 |
+
else self.config.output_attentions
|
| 1187 |
+
)
|
| 1188 |
+
return_dict = (
|
| 1189 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1190 |
+
)
|
| 1191 |
+
|
| 1192 |
+
output_sae = compute_sae and len(self._sae_models) > 0
|
| 1193 |
+
|
| 1194 |
+
# Determine which intermediate layers to collect. When SAEs are
|
| 1195 |
+
# registered we must collect at least the layers they target, even if
|
| 1196 |
+
# the caller did not ask for all hidden states.
|
| 1197 |
+
if output_hidden_states:
|
| 1198 |
+
layers_to_collect: list[int] = list(range(self.config.n_layers + 1))
|
| 1199 |
+
elif output_sae:
|
| 1200 |
+
layers_to_collect = sorted(
|
| 1201 |
+
{self._get_sae_layer_num_requested(name) for name in self._sae_models}
|
| 1202 |
+
)
|
| 1203 |
+
else:
|
| 1204 |
+
layers_to_collect = []
|
| 1205 |
+
|
| 1206 |
+
user_supplied_sequence_id = sequence_id is not None
|
| 1207 |
+
if sequence_id is not None:
|
| 1208 |
+
bool_mask = sequence_id >= 0
|
| 1209 |
+
else:
|
| 1210 |
+
if attention_mask is None:
|
| 1211 |
+
attention_mask = input_ids != self.config.pad_token_id
|
| 1212 |
+
assert attention_mask is not None
|
| 1213 |
+
bool_mask = attention_mask.bool()
|
| 1214 |
+
sequence_id = bool_mask.to(torch.long) - 1
|
| 1215 |
+
|
| 1216 |
+
x = self.embed(input_ids)
|
| 1217 |
+
b, l_ = x.shape[:2]
|
| 1218 |
+
|
| 1219 |
+
if self._use_flash_attn:
|
| 1220 |
+
if user_supplied_sequence_id and (sequence_id > 0).any():
|
| 1221 |
+
raise ValueError(
|
| 1222 |
+
"Multi-chain ``sequence_id`` (any value > 0) is not "
|
| 1223 |
+
"supported with attn_implementation='flash_attention_2'. "
|
| 1224 |
+
"Re-load the model with attn_implementation='sdpa' (or "
|
| 1225 |
+
"'eager') for chain-aware attention masking."
|
| 1226 |
+
)
|
| 1227 |
+
assert unpad_input is not None
|
| 1228 |
+
x, indices, *_ = unpad_input(x, bool_mask)
|
| 1229 |
+
else:
|
| 1230 |
+
indices = None
|
| 1231 |
+
|
| 1232 |
+
if self._use_flash_attn:
|
| 1233 |
+
trans_seq_id = bool_mask
|
| 1234 |
+
elif user_supplied_sequence_id:
|
| 1235 |
+
trans_seq_id = sequence_id
|
| 1236 |
+
elif bool_mask.all() and not output_attentions:
|
| 1237 |
+
# Fused SDPA fast path (xformers / flash) is correct only when the
|
| 1238 |
+
# mask is uniform; output_attentions forces the manual branch.
|
| 1239 |
+
trans_seq_id = None
|
| 1240 |
+
else:
|
| 1241 |
+
trans_seq_id = sequence_id
|
| 1242 |
+
last_hidden_state, _, collected, attentions = self.transformer(
|
| 1243 |
+
x,
|
| 1244 |
+
sequence_id=trans_seq_id,
|
| 1245 |
+
layers_to_collect=layers_to_collect,
|
| 1246 |
+
output_attentions=output_attentions,
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
if self._use_flash_attn:
|
| 1250 |
+
assert indices is not None and pad_input is not None
|
| 1251 |
+
last_hidden_state = pad_input(last_hidden_state, indices, b, l_)
|
| 1252 |
+
collected = [pad_input(h, indices, b, l_) for h in collected]
|
| 1253 |
+
|
| 1254 |
+
# Stack once; reused for both SAE and hidden-state output.
|
| 1255 |
+
collected_tensor: torch.Tensor | None = (
|
| 1256 |
+
torch.stack(collected, dim=0) if collected else None # type: ignore[arg-type]
|
| 1257 |
+
)
|
| 1258 |
+
|
| 1259 |
+
sae_outputs: dict[str, torch.Tensor] | None = None
|
| 1260 |
+
if output_sae and collected_tensor is not None:
|
| 1261 |
+
assert input_ids is not None
|
| 1262 |
+
self._validate_sae_inputs(input_ids)
|
| 1263 |
+
sae_outputs = self._get_sae_outputs(
|
| 1264 |
+
collected_tensor, layers_to_collect, bool_mask, normalize_sae
|
| 1265 |
+
)
|
| 1266 |
+
|
| 1267 |
+
hidden_states_tensor = collected_tensor if output_hidden_states else None
|
| 1268 |
+
|
| 1269 |
+
if not return_dict:
|
| 1270 |
+
return tuple(
|
| 1271 |
+
v
|
| 1272 |
+
for v in [
|
| 1273 |
+
last_hidden_state,
|
| 1274 |
+
hidden_states_tensor,
|
| 1275 |
+
sae_outputs,
|
| 1276 |
+
attentions,
|
| 1277 |
+
]
|
| 1278 |
+
if v is not None
|
| 1279 |
+
)
|
| 1280 |
+
|
| 1281 |
+
return ESMCOutput(
|
| 1282 |
+
last_hidden_state=last_hidden_state,
|
| 1283 |
+
hidden_states=hidden_states_tensor,
|
| 1284 |
+
sae_outputs=sae_outputs,
|
| 1285 |
+
attentions=attentions,
|
| 1286 |
+
)
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
# ---------------------------------------------------------------------------
|
| 1290 |
+
# LM head
|
| 1291 |
+
# ---------------------------------------------------------------------------
|
| 1292 |
+
|
| 1293 |
+
|
| 1294 |
+
def _esmc_lm_head(
|
| 1295 |
+
d_model: int, output_dim: int, hidden_dim: int | None = None
|
| 1296 |
+
) -> nn.Sequential:
|
| 1297 |
+
"""Linear → GELU → LayerNorm → Linear projection head for masked LM."""
|
| 1298 |
+
hidden_dim = hidden_dim if hidden_dim is not None else d_model
|
| 1299 |
+
return nn.Sequential(
|
| 1300 |
+
nn.Linear(d_model, hidden_dim),
|
| 1301 |
+
nn.GELU(),
|
| 1302 |
+
nn.LayerNorm(hidden_dim),
|
| 1303 |
+
nn.Linear(hidden_dim, output_dim),
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
|
| 1307 |
+
# ---------------------------------------------------------------------------
|
| 1308 |
+
# Masked language model
|
| 1309 |
+
# ---------------------------------------------------------------------------
|
| 1310 |
+
|
| 1311 |
+
|
| 1312 |
+
@auto_docstring
|
| 1313 |
+
class ESMCForMaskedLM(ESMCPreTrainedModel):
|
| 1314 |
+
"""ESMC with a masked language modelling head.
|
| 1315 |
+
|
| 1316 |
+
This is the primary pre-training objective of ESMC. The LM head consists
|
| 1317 |
+
of a single hidden layer with GELU activation followed by LayerNorm and a
|
| 1318 |
+
linear projection to ``vocab_size``.
|
| 1319 |
+
"""
|
| 1320 |
+
|
| 1321 |
+
def __init__(self, config: ESMCConfig):
|
| 1322 |
+
super().__init__(config)
|
| 1323 |
+
self.esmc = ESMCModel(config)
|
| 1324 |
+
self.lm_head = _esmc_lm_head(config.d_model, config.vocab_size)
|
| 1325 |
+
self.post_init()
|
| 1326 |
+
|
| 1327 |
+
def get_output_embeddings(self) -> nn.Linear:
|
| 1328 |
+
return self.lm_head[-1] # type: ignore[return-value]
|
| 1329 |
+
|
| 1330 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear):
|
| 1331 |
+
self.lm_head[-1] = new_embeddings
|
| 1332 |
+
|
| 1333 |
+
def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
|
| 1334 |
+
"""Proxy to :meth:`ESMCModel.add_sae_models`."""
|
| 1335 |
+
self.esmc.add_sae_models(sae_models)
|
| 1336 |
+
|
| 1337 |
+
@can_return_tuple
|
| 1338 |
+
@auto_docstring
|
| 1339 |
+
def forward(
|
| 1340 |
+
self,
|
| 1341 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1342 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1343 |
+
sequence_id: Optional[torch.Tensor] = None,
|
| 1344 |
+
output_hidden_states: Optional[bool] = None,
|
| 1345 |
+
output_attentions: Optional[bool] = None,
|
| 1346 |
+
return_dict: Optional[bool] = None,
|
| 1347 |
+
labels: Optional[torch.Tensor] = None,
|
| 1348 |
+
compute_sae: bool = True,
|
| 1349 |
+
normalize_sae: bool = False,
|
| 1350 |
+
) -> tuple[torch.Tensor, ...] | ESMCMaskedLMOutput:
|
| 1351 |
+
r"""
|
| 1352 |
+
sequence_id (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1353 |
+
Integer chain-ID tensor forwarded to the encoder for chain-aware
|
| 1354 |
+
attention masking. See :meth:`ESMCModel.forward` for the encoding.
|
| 1355 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1356 |
+
Labels for masked language modelling loss. Positions with label ``-100``
|
| 1357 |
+
are ignored. Other positions must be in ``[0, config.vocab_size)``.
|
| 1358 |
+
output_attentions (`bool`, *optional*):
|
| 1359 |
+
Whether to return per-block attention weights. Forwarded to the
|
| 1360 |
+
backbone; raises on the ``flash_attention_2`` path.
|
| 1361 |
+
compute_sae (`bool`, *optional*, defaults to ``True``):
|
| 1362 |
+
Whether to run registered SAE models. Has no effect when none are registered.
|
| 1363 |
+
normalize_sae (`bool`, *optional*, defaults to ``False``):
|
| 1364 |
+
When ``True``, scale SAE features by ``idf / max`` normalization buffers.
|
| 1365 |
+
|
| 1366 |
+
Examples:
|
| 1367 |
+
|
| 1368 |
+
```python
|
| 1369 |
+
>>> from transformers import AutoTokenizer, ESMCForMaskedLM
|
| 1370 |
+
>>> import torch
|
| 1371 |
+
|
| 1372 |
+
>>> model = ESMCForMaskedLM.from_pretrained("Biohub/ESMC-600M-2024-12")
|
| 1373 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-600M-2024-12")
|
| 1374 |
+
>>> inputs = tokenizer(["MLKNVQ<mask>LV"], return_tensors="pt")
|
| 1375 |
+
>>> outputs = model(**inputs)
|
| 1376 |
+
>>> outputs.logits.shape
|
| 1377 |
+
torch.Size([1, 11, 64])
|
| 1378 |
+
```
|
| 1379 |
+
"""
|
| 1380 |
+
return_dict = (
|
| 1381 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1382 |
+
)
|
| 1383 |
+
|
| 1384 |
+
encoder_outputs = self.esmc(
|
| 1385 |
+
input_ids=input_ids,
|
| 1386 |
+
attention_mask=attention_mask,
|
| 1387 |
+
sequence_id=sequence_id,
|
| 1388 |
+
output_hidden_states=output_hidden_states,
|
| 1389 |
+
output_attentions=output_attentions,
|
| 1390 |
+
return_dict=True,
|
| 1391 |
+
compute_sae=compute_sae,
|
| 1392 |
+
normalize_sae=normalize_sae,
|
| 1393 |
+
)
|
| 1394 |
+
|
| 1395 |
+
logits = self.lm_head(encoder_outputs.last_hidden_state)
|
| 1396 |
+
|
| 1397 |
+
loss: torch.Tensor | None = None
|
| 1398 |
+
if labels is not None:
|
| 1399 |
+
loss = CrossEntropyLoss(ignore_index=-100)(
|
| 1400 |
+
logits.view(-1, self.config.vocab_size), labels.view(-1)
|
| 1401 |
+
)
|
| 1402 |
+
|
| 1403 |
+
if not return_dict:
|
| 1404 |
+
return tuple(
|
| 1405 |
+
v
|
| 1406 |
+
for v in [
|
| 1407 |
+
loss,
|
| 1408 |
+
logits,
|
| 1409 |
+
encoder_outputs.last_hidden_state,
|
| 1410 |
+
encoder_outputs.hidden_states,
|
| 1411 |
+
encoder_outputs.sae_outputs,
|
| 1412 |
+
encoder_outputs.attentions,
|
| 1413 |
+
]
|
| 1414 |
+
if v is not None
|
| 1415 |
+
)
|
| 1416 |
+
|
| 1417 |
+
return ESMCMaskedLMOutput(
|
| 1418 |
+
loss=loss,
|
| 1419 |
+
logits=logits,
|
| 1420 |
+
last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1421 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1422 |
+
sae_outputs=encoder_outputs.sae_outputs,
|
| 1423 |
+
attentions=encoder_outputs.attentions,
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
# ---------------------------------------------------------------------------
|
| 1428 |
+
# Classification heads
|
| 1429 |
+
# ---------------------------------------------------------------------------
|
| 1430 |
+
|
| 1431 |
+
|
| 1432 |
+
class _ESMCClassificationHead(nn.Module):
|
| 1433 |
+
"""Dense classification head applied to the ``<cls>`` token representation."""
|
| 1434 |
+
|
| 1435 |
+
def __init__(self, config: ESMCConfig):
|
| 1436 |
+
super().__init__()
|
| 1437 |
+
self.dense = nn.Linear(config.d_model, config.d_model)
|
| 1438 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
| 1439 |
+
self.out_proj = nn.Linear(config.d_model, config.num_labels)
|
| 1440 |
+
|
| 1441 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 1442 |
+
x = hidden_states[:, 0, :] # <cls> token
|
| 1443 |
+
x = self.dropout(x)
|
| 1444 |
+
x = torch.tanh(self.dense(x))
|
| 1445 |
+
x = self.dropout(x)
|
| 1446 |
+
return self.out_proj(x)
|
| 1447 |
+
|
| 1448 |
+
|
| 1449 |
+
# ---------------------------------------------------------------------------
|
| 1450 |
+
# Sequence classification
|
| 1451 |
+
# ---------------------------------------------------------------------------
|
| 1452 |
+
|
| 1453 |
+
|
| 1454 |
+
@auto_docstring
|
| 1455 |
+
class ESMCForSequenceClassification(ESMCPreTrainedModel):
|
| 1456 |
+
"""ESMC with a sequence-level classification head.
|
| 1457 |
+
|
| 1458 |
+
A linear layer is applied to the ``<cls>`` token representation.
|
| 1459 |
+
Supports regression (``num_labels == 1``), single-label classification,
|
| 1460 |
+
and multi-label classification.
|
| 1461 |
+
"""
|
| 1462 |
+
|
| 1463 |
+
def __init__(self, config: ESMCConfig):
|
| 1464 |
+
super().__init__(config)
|
| 1465 |
+
self.num_labels = config.num_labels
|
| 1466 |
+
self.esmc = ESMCModel(config)
|
| 1467 |
+
self.classifier = _ESMCClassificationHead(config)
|
| 1468 |
+
self.post_init()
|
| 1469 |
+
|
| 1470 |
+
def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
|
| 1471 |
+
"""Proxy to :meth:`ESMCModel.add_sae_models`."""
|
| 1472 |
+
self.esmc.add_sae_models(sae_models)
|
| 1473 |
+
|
| 1474 |
+
@can_return_tuple
|
| 1475 |
+
@auto_docstring
|
| 1476 |
+
def forward(
|
| 1477 |
+
self,
|
| 1478 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1479 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1480 |
+
output_hidden_states: Optional[bool] = None,
|
| 1481 |
+
output_attentions: Optional[bool] = None,
|
| 1482 |
+
return_dict: Optional[bool] = None,
|
| 1483 |
+
labels: Optional[torch.Tensor] = None,
|
| 1484 |
+
compute_sae: bool = True,
|
| 1485 |
+
normalize_sae: bool = False,
|
| 1486 |
+
) -> tuple[torch.Tensor, ...] | ESMCSequenceClassifierOutput:
|
| 1487 |
+
r"""
|
| 1488 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 1489 |
+
Labels for sequence classification loss. Indices must be in
|
| 1490 |
+
``[0, config.num_labels - 1]``. For regression pass a float
|
| 1491 |
+
tensor of shape ``(batch_size,)``.
|
| 1492 |
+
output_attentions (`bool`, *optional*):
|
| 1493 |
+
Whether to return per-block attention weights. Forwarded to the
|
| 1494 |
+
backbone; raises on the ``flash_attention_2`` path.
|
| 1495 |
+
compute_sae (`bool`, *optional*, defaults to ``True``):
|
| 1496 |
+
Whether to run registered SAE models. Has no effect when none are registered.
|
| 1497 |
+
normalize_sae (`bool`, *optional*, defaults to ``False``):
|
| 1498 |
+
When ``True``, scale SAE features by ``idf / max`` normalization buffers.
|
| 1499 |
+
"""
|
| 1500 |
+
return_dict = (
|
| 1501 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1502 |
+
)
|
| 1503 |
+
|
| 1504 |
+
encoder_outputs = self.esmc(
|
| 1505 |
+
input_ids,
|
| 1506 |
+
attention_mask=attention_mask,
|
| 1507 |
+
output_hidden_states=output_hidden_states,
|
| 1508 |
+
output_attentions=output_attentions,
|
| 1509 |
+
return_dict=True,
|
| 1510 |
+
compute_sae=compute_sae,
|
| 1511 |
+
normalize_sae=normalize_sae,
|
| 1512 |
+
)
|
| 1513 |
+
logits = self.classifier(encoder_outputs.last_hidden_state)
|
| 1514 |
+
|
| 1515 |
+
loss: torch.Tensor | None = None
|
| 1516 |
+
if labels is not None:
|
| 1517 |
+
labels = labels.to(logits.device)
|
| 1518 |
+
|
| 1519 |
+
if self.config.problem_type is None:
|
| 1520 |
+
if self.num_labels == 1:
|
| 1521 |
+
self.config.problem_type = "regression"
|
| 1522 |
+
elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
|
| 1523 |
+
self.config.problem_type = "single_label_classification"
|
| 1524 |
+
else:
|
| 1525 |
+
self.config.problem_type = "multi_label_classification"
|
| 1526 |
+
|
| 1527 |
+
if self.config.problem_type == "regression":
|
| 1528 |
+
loss_fct = MSELoss()
|
| 1529 |
+
loss = loss_fct(
|
| 1530 |
+
logits.squeeze() if self.num_labels == 1 else logits,
|
| 1531 |
+
labels.squeeze() if self.num_labels == 1 else labels,
|
| 1532 |
+
)
|
| 1533 |
+
elif self.config.problem_type == "single_label_classification":
|
| 1534 |
+
loss = CrossEntropyLoss()(
|
| 1535 |
+
logits.view(-1, self.num_labels), labels.view(-1)
|
| 1536 |
+
)
|
| 1537 |
+
elif self.config.problem_type == "multi_label_classification":
|
| 1538 |
+
loss = BCEWithLogitsLoss()(logits, labels)
|
| 1539 |
+
|
| 1540 |
+
if not return_dict:
|
| 1541 |
+
return tuple(
|
| 1542 |
+
v
|
| 1543 |
+
for v in [
|
| 1544 |
+
loss,
|
| 1545 |
+
logits,
|
| 1546 |
+
encoder_outputs.last_hidden_state,
|
| 1547 |
+
encoder_outputs.hidden_states,
|
| 1548 |
+
encoder_outputs.sae_outputs,
|
| 1549 |
+
encoder_outputs.attentions,
|
| 1550 |
+
]
|
| 1551 |
+
if v is not None
|
| 1552 |
+
)
|
| 1553 |
+
|
| 1554 |
+
return ESMCSequenceClassifierOutput(
|
| 1555 |
+
loss=loss,
|
| 1556 |
+
logits=logits,
|
| 1557 |
+
last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1558 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1559 |
+
sae_outputs=encoder_outputs.sae_outputs,
|
| 1560 |
+
attentions=encoder_outputs.attentions,
|
| 1561 |
+
)
|
| 1562 |
+
|
| 1563 |
+
|
| 1564 |
+
# ---------------------------------------------------------------------------
|
| 1565 |
+
# Token classification
|
| 1566 |
+
# ---------------------------------------------------------------------------
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
@auto_docstring
|
| 1570 |
+
class ESMCForTokenClassification(ESMCPreTrainedModel):
|
| 1571 |
+
"""ESMC with a per-token classification head.
|
| 1572 |
+
|
| 1573 |
+
Useful for tasks such as secondary structure prediction, contact-map
|
| 1574 |
+
prediction, or per-residue labelling.
|
| 1575 |
+
"""
|
| 1576 |
+
|
| 1577 |
+
def __init__(self, config: ESMCConfig):
|
| 1578 |
+
super().__init__(config)
|
| 1579 |
+
self.num_labels = config.num_labels
|
| 1580 |
+
self.esmc = ESMCModel(config)
|
| 1581 |
+
self.dropout = nn.Dropout(config.classifier_dropout)
|
| 1582 |
+
self.classifier = nn.Linear(config.d_model, config.num_labels)
|
| 1583 |
+
self.post_init()
|
| 1584 |
+
|
| 1585 |
+
def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
|
| 1586 |
+
"""Proxy to :meth:`ESMCModel.add_sae_models`."""
|
| 1587 |
+
self.esmc.add_sae_models(sae_models)
|
| 1588 |
+
|
| 1589 |
+
@can_return_tuple
|
| 1590 |
+
@auto_docstring
|
| 1591 |
+
def forward(
|
| 1592 |
+
self,
|
| 1593 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 1594 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1595 |
+
output_hidden_states: Optional[bool] = None,
|
| 1596 |
+
output_attentions: Optional[bool] = None,
|
| 1597 |
+
return_dict: Optional[bool] = None,
|
| 1598 |
+
labels: Optional[torch.Tensor] = None,
|
| 1599 |
+
compute_sae: bool = True,
|
| 1600 |
+
normalize_sae: bool = False,
|
| 1601 |
+
) -> tuple[torch.Tensor, ...] | ESMCTokenClassifierOutput:
|
| 1602 |
+
r"""
|
| 1603 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1604 |
+
Per-token labels. Indices must be in ``[0, config.num_labels - 1]``.
|
| 1605 |
+
Positions with index ``-100`` are ignored in the loss.
|
| 1606 |
+
output_attentions (`bool`, *optional*):
|
| 1607 |
+
Whether to return per-block attention weights. Forwarded to the
|
| 1608 |
+
backbone; raises on the ``flash_attention_2`` path.
|
| 1609 |
+
compute_sae (`bool`, *optional*, defaults to ``True``):
|
| 1610 |
+
Whether to run registered SAE models. Has no effect when none are registered.
|
| 1611 |
+
normalize_sae (`bool`, *optional*, defaults to ``False``):
|
| 1612 |
+
When ``True``, scale SAE features by ``idf / max`` normalization buffers.
|
| 1613 |
+
"""
|
| 1614 |
+
return_dict = (
|
| 1615 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 1616 |
+
)
|
| 1617 |
+
|
| 1618 |
+
encoder_outputs = self.esmc(
|
| 1619 |
+
input_ids=input_ids,
|
| 1620 |
+
attention_mask=attention_mask,
|
| 1621 |
+
output_hidden_states=output_hidden_states,
|
| 1622 |
+
output_attentions=output_attentions,
|
| 1623 |
+
return_dict=True,
|
| 1624 |
+
compute_sae=compute_sae,
|
| 1625 |
+
normalize_sae=normalize_sae,
|
| 1626 |
+
)
|
| 1627 |
+
|
| 1628 |
+
sequence_output = self.dropout(encoder_outputs.last_hidden_state)
|
| 1629 |
+
logits = self.classifier(sequence_output)
|
| 1630 |
+
|
| 1631 |
+
loss: torch.Tensor | None = None
|
| 1632 |
+
if labels is not None:
|
| 1633 |
+
loss = CrossEntropyLoss(ignore_index=-100)(
|
| 1634 |
+
logits.view(-1, self.num_labels), labels.to(logits.device).view(-1)
|
| 1635 |
+
)
|
| 1636 |
+
|
| 1637 |
+
if not return_dict:
|
| 1638 |
+
return tuple(
|
| 1639 |
+
v
|
| 1640 |
+
for v in [
|
| 1641 |
+
loss,
|
| 1642 |
+
logits,
|
| 1643 |
+
encoder_outputs.last_hidden_state,
|
| 1644 |
+
encoder_outputs.hidden_states,
|
| 1645 |
+
encoder_outputs.sae_outputs,
|
| 1646 |
+
encoder_outputs.attentions,
|
| 1647 |
+
]
|
| 1648 |
+
if v is not None
|
| 1649 |
+
)
|
| 1650 |
+
|
| 1651 |
+
return ESMCTokenClassifierOutput(
|
| 1652 |
+
loss=loss,
|
| 1653 |
+
logits=logits,
|
| 1654 |
+
last_hidden_state=encoder_outputs.last_hidden_state,
|
| 1655 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 1656 |
+
sae_outputs=encoder_outputs.sae_outputs,
|
| 1657 |
+
attentions=encoder_outputs.attentions,
|
| 1658 |
+
)
|
| 1659 |
+
|
| 1660 |
+
|
| 1661 |
+
__all__ = [
|
| 1662 |
+
"ESMCModel",
|
| 1663 |
+
"ESMCForMaskedLM",
|
| 1664 |
+
"ESMCForSequenceClassification",
|
| 1665 |
+
"ESMCForTokenClassification",
|
| 1666 |
+
"ESMCPreTrainedModel",
|
| 1667 |
+
]
|
modeling_esmc_sae.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""PyTorch ESMC SAE (Sparse Autoencoder) model.
|
| 15 |
+
|
| 16 |
+
* :class:`ESMCSAEModel` — the published HF container, one repo per
|
| 17 |
+
``(backbone, codebook_dim, k)`` group. Each backbone layer ships as a
|
| 18 |
+
``layer_{i}.safetensors`` shard; ``from_pretrained`` downloads the whole
|
| 19 |
+
snapshot but loads no weights — callers materialize the layers they need
|
| 20 |
+
via :meth:`initialize_layers`. Single-layer repos auto-load so bare
|
| 21 |
+
``forward(x)`` works.
|
| 22 |
+
* :class:`_ESMCSAELayer` — internal ``nn.Module`` that holds the weights for
|
| 23 |
+
one ``(backbone, codebook_dim, k, layer)`` SAE. Not a published HF artifact;
|
| 24 |
+
obtained only via ``model.layers["<idx>"]``.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import os
|
| 30 |
+
from dataclasses import dataclass
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from typing import Optional
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
from safetensors.torch import load_file, save_file
|
| 38 |
+
|
| 39 |
+
from transformers.modeling_outputs import ModelOutput
|
| 40 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 41 |
+
from transformers.utils import auto_docstring
|
| 42 |
+
from .configuration_esmc_sae import ESMCSAEConfig, ESMCSAEParams
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
|
| 46 |
+
@auto_docstring(
|
| 47 |
+
custom_intro="""
|
| 48 |
+
Output type of [`ESMCSAEModel`].
|
| 49 |
+
"""
|
| 50 |
+
)
|
| 51 |
+
class ESMCSAEOutput(ModelOutput):
|
| 52 |
+
feature_magnitudes: torch.Tensor
|
| 53 |
+
reconstruction_loss: Optional[torch.Tensor] = None
|
| 54 |
+
|
| 55 |
+
def to_sparse(self) -> None:
|
| 56 |
+
self.feature_magnitudes = self.feature_magnitudes.to_sparse()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class _ESMCSAELayer(nn.Module):
|
| 60 |
+
"""One backbone layer's SAE — internal building block of :class:`ESMCSAEModel`.
|
| 61 |
+
|
| 62 |
+
Not exposed via ``AutoModel`` and not loadable on its own. Obtain one
|
| 63 |
+
via ``model.layers["<layer_idx>"]`` after calling ``initialize_layers``.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, params: ESMCSAEParams):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.params = params
|
| 69 |
+
|
| 70 |
+
self.W_enc = nn.Parameter(torch.empty(params.d_model, params.codebook_dim))
|
| 71 |
+
self.W_dec = nn.Parameter(torch.empty(params.codebook_dim, params.d_model))
|
| 72 |
+
self.b_dec = nn.Parameter(torch.zeros(params.d_model))
|
| 73 |
+
# Per-feature normalization stats. Trained alongside the SAE for some
|
| 74 |
+
# variants; for variants that don't ship them, leaving these as ones
|
| 75 |
+
# makes ``_get_sae_outputs``'s ``features / max * idf`` a no-op.
|
| 76 |
+
self.register_buffer("idf", torch.ones(params.codebook_dim))
|
| 77 |
+
self.register_buffer("max", torch.ones(params.codebook_dim))
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def layer(self) -> int:
|
| 81 |
+
"""Backbone-layer index this SAE is trained against."""
|
| 82 |
+
return self.params.layer
|
| 83 |
+
|
| 84 |
+
def forward(self, x: torch.Tensor, **_kwargs: object) -> ESMCSAEOutput:
|
| 85 |
+
del _kwargs
|
| 86 |
+
x = self._zscore_normalize_representation(x)
|
| 87 |
+
|
| 88 |
+
x_with_pre_encoder_bias = x - self.b_dec
|
| 89 |
+
preactivations = F.relu(x_with_pre_encoder_bias @ self.W_enc)
|
| 90 |
+
|
| 91 |
+
topk = torch.topk(preactivations, self.params.k, dim=-1)
|
| 92 |
+
feature_magnitudes = torch.zeros_like(preactivations).scatter(
|
| 93 |
+
-1, topk.indices, topk.values
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
|
| 97 |
+
|
| 98 |
+
reconstruction_loss = (reconstructed - x).pow(2).mean(dim=-1)
|
| 99 |
+
|
| 100 |
+
return ESMCSAEOutput(
|
| 101 |
+
feature_magnitudes=feature_magnitudes,
|
| 102 |
+
reconstruction_loss=reconstruction_loss,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def get_sae_output(
|
| 106 |
+
self, layer_states: torch.Tensor, token_mask: torch.Tensor
|
| 107 |
+
) -> ESMCSAEOutput:
|
| 108 |
+
_, _, v_len = layer_states.shape
|
| 109 |
+
nonpad_states = layer_states[token_mask].view(-1, v_len)
|
| 110 |
+
return self(nonpad_states)
|
| 111 |
+
|
| 112 |
+
def _zscore_normalize_representation(self, x: torch.Tensor) -> torch.Tensor:
|
| 113 |
+
x_mean = x.mean(dim=-1, keepdim=True)
|
| 114 |
+
x = x - x_mean
|
| 115 |
+
x_std = x.std(dim=-1, keepdim=True)
|
| 116 |
+
return x / (x_std + 1e-5)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@auto_docstring
|
| 120 |
+
class ESMCSAEPreTrainedModel(PreTrainedModel):
|
| 121 |
+
config_class = ESMCSAEConfig
|
| 122 |
+
base_model_prefix = "esmc_sae"
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@auto_docstring(
|
| 126 |
+
custom_intro="""
|
| 127 |
+
HF container holding one SAE per backbone layer, all sharing the same
|
| 128 |
+
``(d_model, codebook_dim, k)``.
|
| 129 |
+
|
| 130 |
+
``from_pretrained`` downloads the entire repo (every ``layer_{i}.safetensors``)
|
| 131 |
+
into the local HF cache but does **not** load any weights into memory.
|
| 132 |
+
Callers materialize the layers they actually need by calling
|
| 133 |
+
:meth:`initialize_layers`. The full set is available on disk after the
|
| 134 |
+
first call, so subsequent layer switches read from the local cache without
|
| 135 |
+
re-downloading.
|
| 136 |
+
|
| 137 |
+
Examples::
|
| 138 |
+
|
| 139 |
+
model = ESMCSAEModel.from_pretrained(
|
| 140 |
+
"biohub/esmc-6b-2024-12-sae-k64-codebook16384"
|
| 141 |
+
)
|
| 142 |
+
model.initialize_layers([60]) # ~2.5 GB into memory
|
| 143 |
+
out = model(layer_states, layer=60) # forward through layer 60
|
| 144 |
+
model.initialize_layers([45]) # add layer 45 (cached locally)
|
| 145 |
+
model.release_layer(60) # free layer 60
|
| 146 |
+
"""
|
| 147 |
+
)
|
| 148 |
+
class ESMCSAEModel(ESMCSAEPreTrainedModel):
|
| 149 |
+
def __init__(self, config: ESMCSAEConfig):
|
| 150 |
+
super().__init__(config)
|
| 151 |
+
# Layers are populated lazily by ``initialize_layers``; the container
|
| 152 |
+
# starts empty so ``from_pretrained`` doesn't materialize hundreds of
|
| 153 |
+
# GB of unused parameters.
|
| 154 |
+
self.layers = nn.ModuleDict()
|
| 155 |
+
# Zero-element buffer that rides along with ``.to(device/dtype)``.
|
| 156 |
+
# ``initialize_layers`` reads its current device/dtype so SAEs added
|
| 157 |
+
# after ``model.to("cuda")`` land on CUDA without re-passing ``device=``.
|
| 158 |
+
self.register_buffer("_device_marker", torch.empty(0), persistent=False)
|
| 159 |
+
self._snapshot_dir: Optional[str] = None
|
| 160 |
+
self.post_init()
|
| 161 |
+
|
| 162 |
+
@classmethod
|
| 163 |
+
def from_pretrained( # type: ignore[override]
|
| 164 |
+
cls, pretrained_model_name_or_path: str | os.PathLike, *model_args, **kwargs
|
| 165 |
+
) -> "ESMCSAEModel":
|
| 166 |
+
"""Download (or reuse cached) the full repo and return the model.
|
| 167 |
+
|
| 168 |
+
By default no weights are read into memory and the caller must invoke
|
| 169 |
+
:meth:`initialize_layers` before running :meth:`forward`. The single
|
| 170 |
+
exception is when the repo ships exactly one layer: that layer is
|
| 171 |
+
auto-loaded (honoring ``torch_dtype`` / ``device`` if passed) so the
|
| 172 |
+
bare ``forward(x)`` call just works.
|
| 173 |
+
|
| 174 |
+
Honored kwargs: ``revision``, ``cache_dir``, ``token``,
|
| 175 |
+
``allow_patterns``, ``local_files_only``, ``force_download`` (forwarded
|
| 176 |
+
to ``snapshot_download``); ``torch_dtype`` and ``device`` (used by the
|
| 177 |
+
single-layer auto-load path; otherwise pass them to
|
| 178 |
+
:meth:`initialize_layers`). Behavioral kwargs that imply work we do
|
| 179 |
+
not perform (``device_map``, ``low_cpu_mem_usage``,
|
| 180 |
+
``quantization_config``, ``attn_implementation``) raise so the user
|
| 181 |
+
isn't silently misled. Other HF housekeeping kwargs (``config``,
|
| 182 |
+
``trust_remote_code``, ``adapter_kwargs``, …) are accepted and
|
| 183 |
+
ignored — they only matter for the standard loader, which we bypass.
|
| 184 |
+
"""
|
| 185 |
+
del model_args
|
| 186 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
| 187 |
+
device = kwargs.pop("device", None)
|
| 188 |
+
local_dir = _resolve_snapshot_dir(pretrained_model_name_or_path, kwargs)
|
| 189 |
+
unsupported = {
|
| 190 |
+
"device_map",
|
| 191 |
+
"low_cpu_mem_usage",
|
| 192 |
+
"quantization_config",
|
| 193 |
+
"attn_implementation",
|
| 194 |
+
"max_memory",
|
| 195 |
+
"offload_folder",
|
| 196 |
+
"offload_state_dict",
|
| 197 |
+
} & kwargs.keys()
|
| 198 |
+
if unsupported:
|
| 199 |
+
raise TypeError(
|
| 200 |
+
f"Unsupported kwargs to ESMCSAEModel.from_pretrained: "
|
| 201 |
+
f"{sorted(unsupported)}. The standard HF loader is bypassed —"
|
| 202 |
+
" call initialize_layers(..., device=, dtype=) instead."
|
| 203 |
+
)
|
| 204 |
+
config = ESMCSAEConfig.from_pretrained(local_dir)
|
| 205 |
+
model = cls(config)
|
| 206 |
+
model._snapshot_dir = str(local_dir)
|
| 207 |
+
if device is not None:
|
| 208 |
+
model.to(device)
|
| 209 |
+
if torch_dtype is not None:
|
| 210 |
+
model.to(torch_dtype)
|
| 211 |
+
if len(config.available_layers) == 1:
|
| 212 |
+
model.initialize_layers(list(config.available_layers))
|
| 213 |
+
return model
|
| 214 |
+
|
| 215 |
+
def initialize_layers(
|
| 216 |
+
self,
|
| 217 |
+
layers: list[int],
|
| 218 |
+
*,
|
| 219 |
+
device: torch.device | str | None = None,
|
| 220 |
+
dtype: torch.dtype | None = None,
|
| 221 |
+
) -> None:
|
| 222 |
+
"""Load the requested layers from the local snapshot into memory.
|
| 223 |
+
|
| 224 |
+
Layers already present in :attr:`self.layers` are skipped — calling
|
| 225 |
+
``initialize_layers([23])`` twice is idempotent. ``device`` / ``dtype``
|
| 226 |
+
default to wherever the model itself lives (via the ``_device_marker``
|
| 227 |
+
buffer that moves with ``.to(...)``), so the common pattern of
|
| 228 |
+
``model.to("cuda"); model.initialize_layers([7])`` Just Works.
|
| 229 |
+
"""
|
| 230 |
+
assert self._snapshot_dir is not None, (
|
| 231 |
+
"ESMCSAEModel has no snapshot directory — call "
|
| 232 |
+
"from_pretrained first, or set _snapshot_dir manually."
|
| 233 |
+
)
|
| 234 |
+
if device is None:
|
| 235 |
+
device = self._device_marker.device
|
| 236 |
+
if dtype is None:
|
| 237 |
+
dtype = self._device_marker.dtype
|
| 238 |
+
snapshot_dir = Path(self._snapshot_dir)
|
| 239 |
+
available = set(self.config.available_layers)
|
| 240 |
+
for layer_idx in layers:
|
| 241 |
+
key = str(layer_idx)
|
| 242 |
+
if key in self.layers:
|
| 243 |
+
continue
|
| 244 |
+
if layer_idx not in available:
|
| 245 |
+
raise KeyError(
|
| 246 |
+
f"Layer {layer_idx} is not in this repo. "
|
| 247 |
+
f"available_layers={sorted(available)}"
|
| 248 |
+
)
|
| 249 |
+
shard = snapshot_dir / f"layer_{layer_idx}.safetensors"
|
| 250 |
+
if not shard.exists():
|
| 251 |
+
raise FileNotFoundError(
|
| 252 |
+
f"Missing layer file {shard} — config lists layer "
|
| 253 |
+
f"{layer_idx} as available but the shard is not on disk."
|
| 254 |
+
)
|
| 255 |
+
params = ESMCSAEParams(
|
| 256 |
+
d_model=self.config.d_model,
|
| 257 |
+
codebook_dim=self.config.codebook_dim,
|
| 258 |
+
k=self.config.k,
|
| 259 |
+
layer=layer_idx,
|
| 260 |
+
)
|
| 261 |
+
# Build on the meta device so we don't allocate weights that
|
| 262 |
+
# ``load_state_dict`` would immediately overwrite.
|
| 263 |
+
with torch.device("meta"):
|
| 264 |
+
layer = _ESMCSAELayer(params)
|
| 265 |
+
layer.to_empty(device=device)
|
| 266 |
+
layer.load_state_dict(load_file(str(shard)))
|
| 267 |
+
layer.to(dtype=dtype)
|
| 268 |
+
self.layers[key] = layer
|
| 269 |
+
|
| 270 |
+
def release_layer(self, layer: int) -> None:
|
| 271 |
+
"""Drop the named layer from memory. No-op if not loaded."""
|
| 272 |
+
key = str(layer)
|
| 273 |
+
if key in self.layers:
|
| 274 |
+
del self.layers[key]
|
| 275 |
+
|
| 276 |
+
def loaded_layers(self) -> list[int]:
|
| 277 |
+
"""Sorted list of layer indices currently materialized in memory."""
|
| 278 |
+
return sorted(int(k) for k in self.layers.keys())
|
| 279 |
+
|
| 280 |
+
def forward(
|
| 281 |
+
self, x: torch.Tensor, layer: int | None = None, **kwargs: object
|
| 282 |
+
) -> ESMCSAEOutput:
|
| 283 |
+
if layer is None:
|
| 284 |
+
if len(self.layers) == 1:
|
| 285 |
+
# Unambiguous: exactly one layer loaded → use it.
|
| 286 |
+
((_only_key, only_layer),) = self.layers.items()
|
| 287 |
+
return only_layer(x, **kwargs)
|
| 288 |
+
if len(self.layers) == 0:
|
| 289 |
+
raise RuntimeError(
|
| 290 |
+
"No layers loaded — call "
|
| 291 |
+
f"initialize_layers([...]) first. "
|
| 292 |
+
f"available_layers={self.config.available_layers}"
|
| 293 |
+
)
|
| 294 |
+
raise RuntimeError(
|
| 295 |
+
"Multiple layers are loaded — please select one via "
|
| 296 |
+
f"forward(x, layer=<idx>). Loaded layers: {self.loaded_layers()}"
|
| 297 |
+
)
|
| 298 |
+
key = str(layer)
|
| 299 |
+
if key not in self.layers:
|
| 300 |
+
raise KeyError(
|
| 301 |
+
f"Layer {layer} is not loaded. Call "
|
| 302 |
+
f"initialize_layers([{layer}]) first. Loaded layers: "
|
| 303 |
+
f"{self.loaded_layers()}"
|
| 304 |
+
)
|
| 305 |
+
return self.layers[key](x, **kwargs)
|
| 306 |
+
|
| 307 |
+
def save_pretrained( # type: ignore[override]
|
| 308 |
+
self, save_directory: str | os.PathLike, *args, **kwargs
|
| 309 |
+
) -> None:
|
| 310 |
+
"""Write ``config.json`` plus one ``layer_{i}.safetensors`` per loaded layer.
|
| 311 |
+
|
| 312 |
+
Only layers currently in :attr:`self.layers` are written.
|
| 313 |
+
``available_layers`` in the saved config is synced to what's actually
|
| 314 |
+
on disk so a ``release_layer`` + ``save_pretrained`` round-trip never
|
| 315 |
+
advertises a layer whose shard is missing.
|
| 316 |
+
"""
|
| 317 |
+
del args, kwargs
|
| 318 |
+
save_directory = Path(save_directory)
|
| 319 |
+
save_directory.mkdir(parents=True, exist_ok=True)
|
| 320 |
+
# Sync available_layers to what we're about to write — never advertise
|
| 321 |
+
# a layer that isn't on disk in this repo.
|
| 322 |
+
self.config.available_layers = self.loaded_layers()
|
| 323 |
+
self.config.save_pretrained(str(save_directory))
|
| 324 |
+
for key, layer in self.layers.items():
|
| 325 |
+
shard = save_directory / f"layer_{key}.safetensors"
|
| 326 |
+
save_file(
|
| 327 |
+
{
|
| 328 |
+
k: v.detach().cpu().contiguous()
|
| 329 |
+
for k, v in layer.state_dict().items()
|
| 330 |
+
},
|
| 331 |
+
str(shard),
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _resolve_snapshot_dir(
|
| 336 |
+
pretrained_model_name_or_path: str | os.PathLike, kwargs: dict
|
| 337 |
+
) -> str:
|
| 338 |
+
"""Local dir → return as-is; hub id → ``snapshot_download`` it.
|
| 339 |
+
|
| 340 |
+
A directory only counts as "local" if it actually contains ``config.json``,
|
| 341 |
+
so a stale subdir named like a hub id (``./biohub/esmc-...``)
|
| 342 |
+
doesn't accidentally shadow the hub fetch.
|
| 343 |
+
|
| 344 |
+
Pops the standard ``snapshot_download`` keyword args from ``kwargs`` so
|
| 345 |
+
callers can forward them via ``from_pretrained``.
|
| 346 |
+
"""
|
| 347 |
+
path = Path(pretrained_model_name_or_path)
|
| 348 |
+
if path.is_dir() and (path / "config.json").exists():
|
| 349 |
+
return str(path)
|
| 350 |
+
from huggingface_hub import snapshot_download
|
| 351 |
+
|
| 352 |
+
return snapshot_download(
|
| 353 |
+
repo_id=str(pretrained_model_name_or_path),
|
| 354 |
+
revision=kwargs.pop("revision", None),
|
| 355 |
+
cache_dir=kwargs.pop("cache_dir", None),
|
| 356 |
+
token=kwargs.pop("token", None),
|
| 357 |
+
allow_patterns=kwargs.pop("allow_patterns", None),
|
| 358 |
+
local_files_only=kwargs.pop("local_files_only", False),
|
| 359 |
+
force_download=kwargs.pop("force_download", False),
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
__all__ = ["ESMCSAEModel", "ESMCSAEOutput", "ESMCSAEPreTrainedModel"]
|
modeling_esmfold2.py
ADDED
|
@@ -0,0 +1,1288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch ESMFold2 model — the standard released architecture.
|
| 2 |
+
|
| 3 |
+
Quickstart::
|
| 4 |
+
|
| 5 |
+
from transformers import ESMFold2Model
|
| 6 |
+
|
| 7 |
+
model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval()
|
| 8 |
+
open("ubq.pdb", "w").write(model.infer_protein_as_pdb("MQIFVKTLTGKT..."))
|
| 9 |
+
|
| 10 |
+
For multi-chain / ligand / MSA inputs see ``ESMFold2InputBuilder`` in the
|
| 11 |
+
companion ``esm`` package.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import importlib
|
| 15 |
+
import math
|
| 16 |
+
import sys
|
| 17 |
+
from contextlib import contextmanager
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, cast
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
from torch import Tensor
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
te = importlib.import_module("transformer_engine.pytorch")
|
| 28 |
+
te_recipe = importlib.import_module("transformer_engine.common.recipe")
|
| 29 |
+
DelayedScaling = te_recipe.DelayedScaling
|
| 30 |
+
Format = te_recipe.Format
|
| 31 |
+
|
| 32 |
+
TE_AVAILABLE = True
|
| 33 |
+
except ImportError:
|
| 34 |
+
te = None # type: ignore[assignment]
|
| 35 |
+
DelayedScaling = None # type: ignore[assignment]
|
| 36 |
+
Format = None # type: ignore[assignment]
|
| 37 |
+
TE_AVAILABLE = False
|
| 38 |
+
|
| 39 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 40 |
+
from .configuration_esmc import ESMCConfig as _FastPLMSESMCConfig
|
| 41 |
+
from .configuration_esmc_sae import ESMCSAEConfig as _FastPLMSESMCSAEConfig
|
| 42 |
+
from .configuration_esmfold2 import ESMFold2Config
|
| 43 |
+
from .modeling_esmc import ESMCModel as _FastPLMSESMCModel
|
| 44 |
+
from .modeling_esmc_sae import _ESMCSAELayer as _FastPLMSESMCSAELayer
|
| 45 |
+
from .modeling_esmfold2_common import (
|
| 46 |
+
CHAR_VOCAB_SIZE,
|
| 47 |
+
MAX_ATOMIC_NUMBER,
|
| 48 |
+
NUM_RES_TYPES,
|
| 49 |
+
DiffusionStructureHead,
|
| 50 |
+
FoldingTrunk,
|
| 51 |
+
InputsEmbedder,
|
| 52 |
+
LanguageModelShim,
|
| 53 |
+
MSAPairWeightedAveraging,
|
| 54 |
+
OuterProductMean,
|
| 55 |
+
ResIdxAsymIdSymIdEntityIdEncoding,
|
| 56 |
+
RowAttentionPooling,
|
| 57 |
+
SwiGLUMLP,
|
| 58 |
+
TriangleMultiplicativeUpdate,
|
| 59 |
+
_categorical_mean,
|
| 60 |
+
_compute_intra_token_idx,
|
| 61 |
+
compute_lm_hidden_states,
|
| 62 |
+
gather_rep_atom_coords,
|
| 63 |
+
gather_token_to_atom,
|
| 64 |
+
)
|
| 65 |
+
from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
|
| 66 |
+
from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
|
| 67 |
+
from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
|
| 68 |
+
from .esmfold2_conformers import load_ccd as _fastplms_esmfold2_load_ccd
|
| 69 |
+
from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL as _FASTPLMS_ESMFOLD2_ELEMENT_NUMBER_TO_SYMBOL
|
| 70 |
+
from .esmfold2_constants_esm3 import CHAIN_BREAK_STR as _FASTPLMS_ESMFOLD2_CHAIN_BREAK_STR
|
| 71 |
+
from .esmfold2_input_builder import StructurePredictionInput as _FastPLMSESMFold2StructurePredictionInput
|
| 72 |
+
from .esmfold2_metrics import compute_rmsd as _fastplms_esmfold2_compute_rmsd
|
| 73 |
+
from .esmfold2_misc import slice_any_object as _fastplms_esmfold2_slice_any_object
|
| 74 |
+
from .esmfold2_mmcif_parsing import MmcifWrapper as _FastPLMSESMFold2MmcifWrapper
|
| 75 |
+
from .esmfold2_molecular_complex import MolecularComplex as _FastPLMSESMFold2MolecularComplex
|
| 76 |
+
from .esmfold2_msa import MSA as _FastPLMSESMFold2MSA
|
| 77 |
+
from .esmfold2_msa_filter_sequences import greedy_select_indices as _fastplms_esmfold2_greedy_select_indices
|
| 78 |
+
from .esmfold2_normalize_coordinates import normalize_coordinates as _fastplms_esmfold2_normalize_coordinates
|
| 79 |
+
from .esmfold2_output import build_molecular_complex_from_features as _fastplms_esmfold2_build_molecular_complex_from_features
|
| 80 |
+
from .esmfold2_paired_msa import construct_paired_msa as _fastplms_esmfold2_construct_paired_msa
|
| 81 |
+
from .esmfold2_parsing import FastaEntry as _FastPLMSESMFold2FastaEntry
|
| 82 |
+
from .esmfold2_predicted_aligned_error import compute_tm as _fastplms_esmfold2_compute_tm
|
| 83 |
+
from .esmfold2_prepare_input import prepare_esmfold2_input as _fastplms_esmfold2_prepare_esmfold2_input
|
| 84 |
+
from .esmfold2_processor import ESMFold2InputBuilder as _FastPLMSESMFold2InputBuilder
|
| 85 |
+
from .esmfold2_protein_chain import ProteinChain as _FastPLMSESMFold2ProteinChain
|
| 86 |
+
from .esmfold2_protein_complex import ProteinComplex as _FastPLMSESMFold2ProteinComplex
|
| 87 |
+
from .esmfold2_protein_structure import index_by_atom_name as _fastplms_esmfold2_index_by_atom_name
|
| 88 |
+
from .esmfold2_residue_constants import restypes as _FASTPLMS_ESMFOLD2_RESTYPES
|
| 89 |
+
from .esmfold2_sequential_dataclass import SequentialDataclass as _FastPLMSESMFold2SequentialDataclass
|
| 90 |
+
from .esmfold2_system import run_subprocess_with_errorcheck as _fastplms_esmfold2_run_subprocess_with_errorcheck
|
| 91 |
+
from .esmfold2_types import ProteinInput as _FastPLMSESMFold2ProteinInput
|
| 92 |
+
from .esmfold2_utils_types import PathOrBuffer as _FastPLMSESMFold2PathOrBuffer
|
| 93 |
+
|
| 94 |
+
_EPS = 1e-6
|
| 95 |
+
_NONPOLYMER_ID = 4
|
| 96 |
+
|
| 97 |
+
# Default for the triangle / OPM / pair-transition L² ops. Caps peak memory
|
| 98 |
+
# so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438;
|
| 99 |
+
# chunk=64 leaves headroom for the largest foldbench targets). Override via
|
| 100 |
+
# ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
|
| 101 |
+
# short L but OOM-prone past ~600).
|
| 102 |
+
_DEFAULT_CHUNK_SIZE = 64
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _ensure_vendored_esm_alias() -> None:
|
| 106 |
+
package = __package__
|
| 107 |
+
assert package is not None
|
| 108 |
+
vendored_esm = importlib.import_module(f"{package}.esm")
|
| 109 |
+
sys.modules["esm"] = vendored_esm
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class PairTransition(nn.Module):
|
| 113 |
+
"""LayerNorm + SwiGLU feed-forward residual block on the pair representation."""
|
| 114 |
+
|
| 115 |
+
def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.norm = nn.LayerNorm(d_model)
|
| 118 |
+
self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)
|
| 119 |
+
self._chunk_size: int | None = _DEFAULT_CHUNK_SIZE
|
| 120 |
+
|
| 121 |
+
def set_chunk_size(self, chunk_size: int | None) -> None:
|
| 122 |
+
self._chunk_size = chunk_size
|
| 123 |
+
|
| 124 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 125 |
+
if self._chunk_size is None or x.shape[1] <= self._chunk_size:
|
| 126 |
+
return self.ffn(self.norm(x))
|
| 127 |
+
out: list[Tensor] = []
|
| 128 |
+
for s in range(0, x.shape[1], self._chunk_size):
|
| 129 |
+
e = min(s + self._chunk_size, x.shape[1])
|
| 130 |
+
sl = x[:, s:e]
|
| 131 |
+
out.append(self.ffn(self.norm(sl)))
|
| 132 |
+
return torch.cat(out, dim=1)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class ConfidenceHead(nn.Module):
|
| 136 |
+
"""Predicts pLDDT, PAE, PDE, resolved-atom probability and distogram bins."""
|
| 137 |
+
|
| 138 |
+
boundaries: Tensor
|
| 139 |
+
|
| 140 |
+
def __init__(self, config: "ESMFold2Config") -> None:
|
| 141 |
+
super().__init__()
|
| 142 |
+
ch = config.confidence_head
|
| 143 |
+
d_single = config.d_single
|
| 144 |
+
d_pair = config.d_pair
|
| 145 |
+
d_inputs = config.inputs.d_inputs
|
| 146 |
+
|
| 147 |
+
boundaries = torch.linspace(ch.min_dist, ch.max_dist, ch.distogram_bins - 1)
|
| 148 |
+
self.register_buffer("boundaries", boundaries)
|
| 149 |
+
self.dist_bin_pairwise_embed = nn.Embedding(ch.distogram_bins, d_pair)
|
| 150 |
+
|
| 151 |
+
self.s_norm = nn.LayerNorm(d_single)
|
| 152 |
+
self.s_inputs_to_single = nn.Linear(d_inputs, d_single, bias=False)
|
| 153 |
+
self.s_to_z = nn.Linear(d_inputs, d_pair, bias=False)
|
| 154 |
+
self.s_to_z_transpose = nn.Linear(d_inputs, d_pair, bias=False)
|
| 155 |
+
self.s_to_z_prod_in1 = nn.Linear(d_inputs, d_pair, bias=False)
|
| 156 |
+
self.s_to_z_prod_in2 = nn.Linear(d_inputs, d_pair, bias=False)
|
| 157 |
+
self.s_to_z_prod_out = nn.Linear(d_pair, d_pair, bias=False)
|
| 158 |
+
self.s_input_to_s = nn.Linear(d_inputs, d_single, bias=False)
|
| 159 |
+
self.s_inputs_norm = nn.LayerNorm(d_inputs)
|
| 160 |
+
self.z_norm = nn.LayerNorm(d_pair)
|
| 161 |
+
|
| 162 |
+
self.row_attention_pooling = RowAttentionPooling(
|
| 163 |
+
d_pair=d_pair, d_single=d_single
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
pf = ch.folding_trunk
|
| 167 |
+
self.folding_trunk = FoldingTrunk(
|
| 168 |
+
n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
# Heads.
|
| 172 |
+
self.plddt_ln = nn.LayerNorm(d_single)
|
| 173 |
+
max_atoms_per_token = 23
|
| 174 |
+
self.plddt_weight = nn.Parameter(
|
| 175 |
+
torch.zeros(max_atoms_per_token, d_single, ch.num_plddt_bins)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
self.pae_ln = nn.LayerNorm(d_pair)
|
| 179 |
+
self.pae_head = nn.Linear(d_pair, ch.num_pae_bins, bias=False)
|
| 180 |
+
|
| 181 |
+
self.pde_ln = nn.LayerNorm(d_pair)
|
| 182 |
+
self.pde_head = nn.Linear(d_pair, ch.num_pde_bins, bias=False)
|
| 183 |
+
|
| 184 |
+
self.resolved_ln = nn.LayerNorm(d_single)
|
| 185 |
+
# 2 = resolved logits ([unresolved, resolved]).
|
| 186 |
+
self.resolved_weight = nn.Parameter(
|
| 187 |
+
torch.zeros(max_atoms_per_token, d_single, 2)
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def set_kernel_backend(self, backend: str | None) -> None:
|
| 191 |
+
self.folding_trunk.set_kernel_backend(backend)
|
| 192 |
+
|
| 193 |
+
def set_chunk_size(self, chunk_size: int | None) -> None:
|
| 194 |
+
self.folding_trunk.set_chunk_size(chunk_size)
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def _repeat_batch(x: Tensor, num_diffusion_samples: int) -> Tensor:
|
| 198 |
+
return (
|
| 199 |
+
x
|
| 200 |
+
if num_diffusion_samples == 1
|
| 201 |
+
else x.repeat_interleave(num_diffusion_samples, 0)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def _flatten_sample_axis(x: Tensor) -> Tensor:
|
| 206 |
+
if x.ndim == 4:
|
| 207 |
+
b, mult, n, c = x.shape
|
| 208 |
+
return x.reshape(b * mult, n, c)
|
| 209 |
+
return x
|
| 210 |
+
|
| 211 |
+
def forward(
|
| 212 |
+
self,
|
| 213 |
+
s_inputs: Tensor,
|
| 214 |
+
z: Tensor,
|
| 215 |
+
x_pred: Tensor,
|
| 216 |
+
distogram_atom_idx: Tensor,
|
| 217 |
+
token_attention_mask: Tensor,
|
| 218 |
+
atom_to_token: Tensor,
|
| 219 |
+
atom_attention_mask: Tensor,
|
| 220 |
+
asym_id: Tensor,
|
| 221 |
+
mol_type: Tensor,
|
| 222 |
+
num_diffusion_samples: int = 1,
|
| 223 |
+
relative_position_encoding: Tensor | None = None,
|
| 224 |
+
token_bonds_encoding: Tensor | None = None,
|
| 225 |
+
) -> dict[str, Tensor]:
|
| 226 |
+
s_inputs_normed = self.s_inputs_norm(s_inputs)
|
| 227 |
+
|
| 228 |
+
z_base = self.z_norm(z)
|
| 229 |
+
if relative_position_encoding is not None:
|
| 230 |
+
z_base = z_base + relative_position_encoding
|
| 231 |
+
if token_bonds_encoding is not None:
|
| 232 |
+
z_base = z_base + token_bonds_encoding
|
| 233 |
+
z_base = z_base + self.s_to_z(s_inputs_normed).unsqueeze(2)
|
| 234 |
+
z_base = z_base + self.s_to_z_transpose(s_inputs_normed).unsqueeze(1)
|
| 235 |
+
z_base = z_base + self.s_to_z_prod_out(
|
| 236 |
+
self.s_to_z_prod_in1(s_inputs_normed)[:, :, None, :]
|
| 237 |
+
* self.s_to_z_prod_in2(s_inputs_normed)[:, None, :, :]
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
pair = self._repeat_batch(z_base, num_diffusion_samples)
|
| 241 |
+
x_pred_flat = self._flatten_sample_axis(x_pred)
|
| 242 |
+
atom_to_token_m = self._repeat_batch(atom_to_token, num_diffusion_samples)
|
| 243 |
+
atom_mask_m = self._repeat_batch(atom_attention_mask, num_diffusion_samples)
|
| 244 |
+
rep_idx_m = self._repeat_batch(distogram_atom_idx, num_diffusion_samples).long()
|
| 245 |
+
mask = self._repeat_batch(token_attention_mask, num_diffusion_samples)
|
| 246 |
+
Bm = pair.shape[0]
|
| 247 |
+
|
| 248 |
+
rep_coords = gather_rep_atom_coords(x_pred_flat, rep_idx_m)
|
| 249 |
+
rep_distances = torch.cdist(
|
| 250 |
+
rep_coords, rep_coords, compute_mode="donot_use_mm_for_euclid_dist"
|
| 251 |
+
)
|
| 252 |
+
distogram_bins = (
|
| 253 |
+
(rep_distances.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
|
| 254 |
+
)
|
| 255 |
+
pair = pair + self.dist_bin_pairwise_embed(distogram_bins)
|
| 256 |
+
|
| 257 |
+
pair_mask = mask[:, :, None].float() * mask[:, None, :].float()
|
| 258 |
+
|
| 259 |
+
# FoldingTrunk handles the bf16 cast internally during inference so
|
| 260 |
+
# each block's fused trimul engages. In-place residual avoids an
|
| 261 |
+
# extra fp32 pair allocation.
|
| 262 |
+
with torch.amp.autocast("cuda", enabled=pair.is_cuda, dtype=torch.bfloat16):
|
| 263 |
+
pair_delta = self.folding_trunk(pair, pair_attention_mask=pair_mask)
|
| 264 |
+
pair.add_(pair_delta.float())
|
| 265 |
+
del pair_delta
|
| 266 |
+
single = self.row_attention_pooling(pair, mask)
|
| 267 |
+
|
| 268 |
+
atom_mask_f = atom_mask_m.float()
|
| 269 |
+
s_at_atoms = gather_token_to_atom(single, atom_to_token_m)
|
| 270 |
+
s_at_atoms_ln = self.plddt_ln(s_at_atoms)
|
| 271 |
+
|
| 272 |
+
intra_idx = _compute_intra_token_idx(atom_to_token_m)
|
| 273 |
+
intra_idx = intra_idx.clamp(max=self.plddt_weight.shape[0] - 1)
|
| 274 |
+
w_plddt = self.plddt_weight[intra_idx]
|
| 275 |
+
plddt_logits = torch.einsum("...c,...cb->...b", s_at_atoms_ln, w_plddt)
|
| 276 |
+
plddt_per_atom = _categorical_mean(plddt_logits, start=0.0, end=1.0)
|
| 277 |
+
|
| 278 |
+
L = single.shape[1]
|
| 279 |
+
plddt_sum = torch.zeros(Bm, L, device=single.device, dtype=plddt_per_atom.dtype)
|
| 280 |
+
atom_count = torch.zeros(
|
| 281 |
+
Bm, L, device=single.device, dtype=plddt_per_atom.dtype
|
| 282 |
+
)
|
| 283 |
+
atom_mask_t = atom_mask_f.to(plddt_per_atom.dtype)
|
| 284 |
+
plddt_sum.scatter_add_(1, atom_to_token_m, plddt_per_atom * atom_mask_t)
|
| 285 |
+
atom_count.scatter_add_(1, atom_to_token_m, atom_mask_t)
|
| 286 |
+
plddt = plddt_sum / atom_count.clamp(min=1e-6)
|
| 287 |
+
|
| 288 |
+
complex_plddt = (plddt_per_atom * atom_mask_f).sum(dim=-1) / (
|
| 289 |
+
atom_mask_f.sum(dim=-1) + _EPS
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
expanded_type = self._repeat_batch(mol_type, num_diffusion_samples)
|
| 293 |
+
expanded_asym = self._repeat_batch(asym_id, num_diffusion_samples)
|
| 294 |
+
is_ligand = (expanded_type == _NONPOLYMER_ID).float()
|
| 295 |
+
inter_chain = (
|
| 296 |
+
expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
|
| 297 |
+
).float()
|
| 298 |
+
near_contact = (rep_distances < 8).float()
|
| 299 |
+
interface_per_token = (
|
| 300 |
+
near_contact * inter_chain * (1.0 - is_ligand).unsqueeze(-1)
|
| 301 |
+
).amax(dim=-1)
|
| 302 |
+
iplddt_weight = torch.where(
|
| 303 |
+
is_ligand.bool(),
|
| 304 |
+
torch.full_like(interface_per_token, 2.0),
|
| 305 |
+
interface_per_token,
|
| 306 |
+
)
|
| 307 |
+
iplddt_weight_atoms = gather_token_to_atom(
|
| 308 |
+
iplddt_weight.unsqueeze(-1), atom_to_token_m
|
| 309 |
+
).squeeze(-1)
|
| 310 |
+
atom_iplddt_w = atom_mask_f * iplddt_weight_atoms
|
| 311 |
+
complex_iplddt = (plddt_per_atom * atom_iplddt_w).sum(dim=-1) / (
|
| 312 |
+
atom_iplddt_w.sum(dim=-1) + _EPS
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
plddt_ca = plddt_per_atom.gather(1, rep_idx_m)
|
| 316 |
+
|
| 317 |
+
# PAE
|
| 318 |
+
pae_logits = self.pae_head(self.pae_ln(pair))
|
| 319 |
+
pae = _categorical_mean(pae_logits, start=0.0, end=32.0).detach()
|
| 320 |
+
|
| 321 |
+
# PDE
|
| 322 |
+
pde_logits = self.pde_head(self.pde_ln(pair))
|
| 323 |
+
pde = _categorical_mean(pde_logits, start=0.0, end=32.0).detach()
|
| 324 |
+
|
| 325 |
+
# Resolved (per-atom binary).
|
| 326 |
+
s_at_atoms_res = self.resolved_ln(s_at_atoms)
|
| 327 |
+
w_res = self.resolved_weight[intra_idx]
|
| 328 |
+
resolved_logits = torch.einsum("...c,...cb->...b", s_at_atoms_res, w_res)
|
| 329 |
+
|
| 330 |
+
# pTM / ipTM from pae_logits.
|
| 331 |
+
n_bins = pae_logits.shape[-1]
|
| 332 |
+
bin_width = 32.0 / n_bins
|
| 333 |
+
bin_centers = torch.arange(
|
| 334 |
+
0.5 * bin_width, 32.0, bin_width, device=pae_logits.device
|
| 335 |
+
)
|
| 336 |
+
mask_f = mask.float()
|
| 337 |
+
N_res = mask_f.sum(dim=-1, keepdim=True)
|
| 338 |
+
d0 = 1.24 * (N_res.clamp(min=19) - 15) ** (1 / 3) - 1.8
|
| 339 |
+
tm_per_bin = 1 / (1 + (bin_centers / d0) ** 2)
|
| 340 |
+
pae_probs = F.softmax(pae_logits, dim=-1)
|
| 341 |
+
tm_expected = (pae_probs * tm_per_bin[:, None, None, :]).sum(dim=-1)
|
| 342 |
+
|
| 343 |
+
pair_mask_2d = mask_f.unsqueeze(-1) * mask_f.unsqueeze(-2)
|
| 344 |
+
ptm_per_row = (tm_expected * pair_mask_2d).sum(dim=-1) / (
|
| 345 |
+
pair_mask_2d.sum(dim=-1) + _EPS
|
| 346 |
+
)
|
| 347 |
+
ptm = ptm_per_row.max(dim=-1).values
|
| 348 |
+
|
| 349 |
+
inter_chain_mask = (
|
| 350 |
+
expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
|
| 351 |
+
).float() * pair_mask_2d
|
| 352 |
+
iptm_per_row = (tm_expected * inter_chain_mask).sum(dim=-1) / (
|
| 353 |
+
inter_chain_mask.sum(dim=-1) + _EPS
|
| 354 |
+
)
|
| 355 |
+
iptm = iptm_per_row.max(dim=-1).values
|
| 356 |
+
|
| 357 |
+
max_chain_id = int(expanded_asym.max().item()) if Bm > 0 else 0
|
| 358 |
+
n_chains = max_chain_id + 1
|
| 359 |
+
pair_chains_iptm = torch.zeros(
|
| 360 |
+
Bm, n_chains, n_chains, device=tm_expected.device, dtype=tm_expected.dtype
|
| 361 |
+
)
|
| 362 |
+
for c1 in range(n_chains):
|
| 363 |
+
chain_c1 = (expanded_asym == c1).float() * mask_f
|
| 364 |
+
if chain_c1.sum() == 0:
|
| 365 |
+
continue
|
| 366 |
+
for c2 in range(n_chains):
|
| 367 |
+
chain_c2 = (expanded_asym == c2).float() * mask_f
|
| 368 |
+
pair_m = chain_c1.unsqueeze(-1) * chain_c2.unsqueeze(-2)
|
| 369 |
+
denom = pair_m.sum(dim=(-1, -2)) + _EPS
|
| 370 |
+
pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(
|
| 371 |
+
dim=(-1, -2)
|
| 372 |
+
) / denom
|
| 373 |
+
|
| 374 |
+
return {
|
| 375 |
+
"plddt_logits": plddt_logits,
|
| 376 |
+
"plddt": plddt.detach(),
|
| 377 |
+
"plddt_per_atom": plddt_per_atom.detach(),
|
| 378 |
+
"plddt_ca": plddt_ca.detach(),
|
| 379 |
+
"complex_plddt": complex_plddt.detach(),
|
| 380 |
+
"complex_iplddt": complex_iplddt.detach(),
|
| 381 |
+
"pae_logits": pae_logits,
|
| 382 |
+
"pae": pae,
|
| 383 |
+
"pde_logits": pde_logits,
|
| 384 |
+
"pde": pde,
|
| 385 |
+
"resolved_logits": resolved_logits,
|
| 386 |
+
"ptm": ptm.detach(),
|
| 387 |
+
"iptm": iptm.detach(),
|
| 388 |
+
"pair_chains_iptm": pair_chains_iptm.detach(),
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _inverse_softplus(value: float) -> float:
|
| 393 |
+
return value + math.log(-math.expm1(-value))
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def _convert_te_modules_to_fp8_inplace(module: nn.Module) -> None:
|
| 397 |
+
"""Re-init each TE module via quantized_model_init so weights live as fp8.
|
| 398 |
+
|
| 399 |
+
Must be called inside torch.no_grad(); covers nn.Linear, te.Linear,
|
| 400 |
+
te.LayerNormLinear, te.LayerNormMLP — the last two hold 99% of ESMC weight.
|
| 401 |
+
"""
|
| 402 |
+
if not TE_AVAILABLE:
|
| 403 |
+
raise RuntimeError("transformer_engine is not available; cannot use fp8.")
|
| 404 |
+
quantized_model_init = importlib.import_module(
|
| 405 |
+
"transformer_engine.pytorch"
|
| 406 |
+
).quantized_model_init
|
| 407 |
+
|
| 408 |
+
def _walk(mod: nn.Module) -> None:
|
| 409 |
+
for name, child in list(mod.named_children()):
|
| 410 |
+
replaced = False
|
| 411 |
+
if isinstance(child, nn.Linear):
|
| 412 |
+
in_f, out_f = child.in_features, child.out_features
|
| 413 |
+
has_bias = child.bias is not None
|
| 414 |
+
device = child.weight.device
|
| 415 |
+
dtype = child.weight.dtype
|
| 416 |
+
w = child.weight.data
|
| 417 |
+
b = child.bias.data if has_bias else None
|
| 418 |
+
setattr(mod, name, nn.Identity())
|
| 419 |
+
del child
|
| 420 |
+
torch.cuda.empty_cache()
|
| 421 |
+
with quantized_model_init(enabled=True):
|
| 422 |
+
new_mod = te.Linear( # type: ignore[union-attr]
|
| 423 |
+
in_f, out_f, bias=has_bias, params_dtype=dtype
|
| 424 |
+
).to(device)
|
| 425 |
+
new_mod.weight.quantize_(w) # type: ignore[attr-defined,operator]
|
| 426 |
+
if has_bias:
|
| 427 |
+
assert b is not None
|
| 428 |
+
new_mod.bias.data.copy_(b) # type: ignore[union-attr]
|
| 429 |
+
del w, b
|
| 430 |
+
replaced = True
|
| 431 |
+
elif isinstance(child, te.Linear): # type: ignore[union-attr]
|
| 432 |
+
# te.Linear with bf16 weight → re-init inside quantized_model_init for fp8.
|
| 433 |
+
in_f, out_f = child.in_features, child.out_features
|
| 434 |
+
has_bias = child.bias is not None
|
| 435 |
+
device = child.weight.device
|
| 436 |
+
dtype = (
|
| 437 |
+
child.weight.dtype
|
| 438 |
+
if not hasattr(child.weight, "_data")
|
| 439 |
+
else torch.bfloat16
|
| 440 |
+
)
|
| 441 |
+
state = {k: v.detach().clone() for k, v in child.state_dict().items()}
|
| 442 |
+
setattr(mod, name, nn.Identity())
|
| 443 |
+
del child
|
| 444 |
+
torch.cuda.empty_cache()
|
| 445 |
+
with quantized_model_init(enabled=True):
|
| 446 |
+
new_mod = te.Linear( # type: ignore[union-attr]
|
| 447 |
+
in_f,
|
| 448 |
+
out_f,
|
| 449 |
+
bias=has_bias,
|
| 450 |
+
params_dtype=dtype, # type: ignore[arg-type]
|
| 451 |
+
).to(device) # type: ignore[arg-type]
|
| 452 |
+
new_mod.load_state_dict(state, strict=False)
|
| 453 |
+
replaced = True
|
| 454 |
+
elif (
|
| 455 |
+
hasattr(te, "LayerNormLinear") and isinstance(child, te.LayerNormLinear) # type: ignore[union-attr]
|
| 456 |
+
):
|
| 457 |
+
state = {k: v.detach().clone() for k, v in child.state_dict().items()}
|
| 458 |
+
hidden_size = child.in_features
|
| 459 |
+
out_features = child.out_features
|
| 460 |
+
has_bias = child.use_bias
|
| 461 |
+
device = next(child.parameters()).device
|
| 462 |
+
setattr(mod, name, nn.Identity())
|
| 463 |
+
del child
|
| 464 |
+
torch.cuda.empty_cache()
|
| 465 |
+
with quantized_model_init(enabled=True):
|
| 466 |
+
new_mod = te.LayerNormLinear( # type: ignore[union-attr]
|
| 467 |
+
hidden_size,
|
| 468 |
+
out_features,
|
| 469 |
+
bias=has_bias,
|
| 470 |
+
params_dtype=torch.bfloat16,
|
| 471 |
+
).to(device)
|
| 472 |
+
new_mod.load_state_dict(state, strict=False)
|
| 473 |
+
replaced = True
|
| 474 |
+
elif (
|
| 475 |
+
hasattr(te, "LayerNormMLP") and isinstance(child, te.LayerNormMLP) # type: ignore[union-attr]
|
| 476 |
+
):
|
| 477 |
+
state = {k: v.detach().clone() for k, v in child.state_dict().items()}
|
| 478 |
+
fc1_weight: Tensor = child.fc1_weight # type: ignore[attr-defined]
|
| 479 |
+
hidden_size = int(fc1_weight.shape[1])
|
| 480 |
+
# fc1 packed as (2*ffn_hidden_size, hidden_size) for swiglu.
|
| 481 |
+
ffn_hidden_size = int(fc1_weight.shape[0]) // 2
|
| 482 |
+
has_bias = (
|
| 483 |
+
getattr(child, "fc1_bias", None) is not None
|
| 484 |
+
and child.fc1_bias is not None # type: ignore[attr-defined]
|
| 485 |
+
)
|
| 486 |
+
device = fc1_weight.device
|
| 487 |
+
setattr(mod, name, nn.Identity())
|
| 488 |
+
del child
|
| 489 |
+
torch.cuda.empty_cache()
|
| 490 |
+
with quantized_model_init(enabled=True):
|
| 491 |
+
new_mod = te.LayerNormMLP( # type: ignore[union-attr]
|
| 492 |
+
hidden_size=hidden_size,
|
| 493 |
+
ffn_hidden_size=ffn_hidden_size,
|
| 494 |
+
bias=has_bias,
|
| 495 |
+
activation="swiglu",
|
| 496 |
+
params_dtype=torch.bfloat16,
|
| 497 |
+
).to(device) # type: ignore[arg-type]
|
| 498 |
+
new_mod.load_state_dict(state, strict=False)
|
| 499 |
+
replaced = True
|
| 500 |
+
|
| 501 |
+
if replaced:
|
| 502 |
+
# Freeze via .eval()+.requires_grad_(False); per-param ops would unwrap Float8Tensor.
|
| 503 |
+
new_mod.eval().requires_grad_(False)
|
| 504 |
+
setattr(mod, name, new_mod)
|
| 505 |
+
torch.cuda.empty_cache()
|
| 506 |
+
else:
|
| 507 |
+
_walk(child)
|
| 508 |
+
|
| 509 |
+
_walk(module)
|
| 510 |
+
torch.cuda.empty_cache()
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
@contextmanager
|
| 514 |
+
def _lm_precision_context(fp8: bool):
|
| 515 |
+
"""bf16 autocast (+ optional TE fp8 autocast) around the LM forward.
|
| 516 |
+
|
| 517 |
+
te.autocast keeps te.Linear outputs bf16 instead of the fp32 default
|
| 518 |
+
(~425 MB at L=1024 in the hidden-state cache).
|
| 519 |
+
"""
|
| 520 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 521 |
+
if fp8 and TE_AVAILABLE:
|
| 522 |
+
fp8_recipe = DelayedScaling( # type: ignore[misc]
|
| 523 |
+
fp8_format=Format.HYBRID, # type: ignore[union-attr]
|
| 524 |
+
amax_history_len=1,
|
| 525 |
+
amax_compute_algo="most_recent",
|
| 526 |
+
)
|
| 527 |
+
with te.autocast(enabled=True, recipe=fp8_recipe): # type: ignore[union-attr]
|
| 528 |
+
yield
|
| 529 |
+
else:
|
| 530 |
+
yield
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class ESMFold2Model(PreTrainedModel):
|
| 534 |
+
"""ESMFold2 — all-atom structure prediction with an ESMC PLM backbone.
|
| 535 |
+
|
| 536 |
+
This is the standard released ESMFold2 architecture (uses a linear-
|
| 537 |
+
recurrent trunk, internally referred to as "parcae").
|
| 538 |
+
|
| 539 |
+
Forward kwargs that callers commonly override:
|
| 540 |
+
|
| 541 |
+
* ``num_loops`` (default ``config.num_loops``): trunk refinement
|
| 542 |
+
loops.
|
| 543 |
+
* ``num_diffusion_samples`` (default ``config.num_diffusion_samples``):
|
| 544 |
+
parallel structure samples; the confidence head re-runs once per
|
| 545 |
+
sample, so memory scales linearly. Pass ``1`` for cheap inference.
|
| 546 |
+
* ``num_sampling_steps`` (default ``config.structure_head.inference_num_steps``):
|
| 547 |
+
diffusion ODE solver steps. Lower for speed, higher for quality.
|
| 548 |
+
|
| 549 |
+
Memory / perf knobs:
|
| 550 |
+
|
| 551 |
+
* ``model.set_chunk_size(int|None)``: caps L² ops (triangle / OPM /
|
| 552 |
+
pair transition) at this token-axis chunk. Default 64 — fits
|
| 553 |
+
L≈2k on an 80 GB GPU. Pass ``None`` for faster inference at L<600.
|
| 554 |
+
* ``model.set_kernel_backend(None | "fused" | "cuequivariance")``:
|
| 555 |
+
select kernel backend (None = reference path).
|
| 556 |
+
"""
|
| 557 |
+
|
| 558 |
+
config_class = ESMFold2Config
|
| 559 |
+
_keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
|
| 560 |
+
|
| 561 |
+
def __init__(self, config: ESMFold2Config) -> None:
|
| 562 |
+
super().__init__(config)
|
| 563 |
+
d_inputs = config.inputs.d_inputs
|
| 564 |
+
d_pair = config.d_pair
|
| 565 |
+
|
| 566 |
+
self.inputs_embedder = InputsEmbedder(config)
|
| 567 |
+
self.z_init_1 = nn.Linear(d_inputs, d_pair, bias=False)
|
| 568 |
+
self.z_init_2 = nn.Linear(d_inputs, d_pair, bias=False)
|
| 569 |
+
self.rel_pos = ResIdxAsymIdSymIdEntityIdEncoding(
|
| 570 |
+
n_relative_residx_bins=config.n_relative_residx_bins,
|
| 571 |
+
n_relative_chain_bins=config.n_relative_chain_bins,
|
| 572 |
+
d_pair=d_pair,
|
| 573 |
+
)
|
| 574 |
+
self.token_bonds = nn.Linear(1, d_pair, bias=False)
|
| 575 |
+
self.language_model = LanguageModelShim(
|
| 576 |
+
d_z=d_pair, d_model=config.lm_d_model, num_layers=config.lm_num_layers
|
| 577 |
+
)
|
| 578 |
+
self._esmc: nn.Module | None = None
|
| 579 |
+
self._esmc_fp8: bool = False # set by load_esmc(fp8=True)
|
| 580 |
+
self._esmfold2_input_builder: Any | None = None
|
| 581 |
+
|
| 582 |
+
pf = config.folding_trunk
|
| 583 |
+
self.folding_trunk = FoldingTrunk(
|
| 584 |
+
n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
|
| 585 |
+
)
|
| 586 |
+
if config.lm_encoder.enabled:
|
| 587 |
+
self.lm_encoder: FoldingTrunk | None = FoldingTrunk(
|
| 588 |
+
n_layers=config.lm_encoder.n_layers, d_pair=d_pair, expansion_ratio=4
|
| 589 |
+
)
|
| 590 |
+
else:
|
| 591 |
+
self.lm_encoder = None
|
| 592 |
+
|
| 593 |
+
self.parcae_input_norm = nn.LayerNorm(d_pair)
|
| 594 |
+
self.parcae_log_a = nn.Parameter(torch.zeros(d_pair))
|
| 595 |
+
parcae_decay_init = math.sqrt(1.0 / 5.0)
|
| 596 |
+
parcae_delta_init = -math.log(parcae_decay_init)
|
| 597 |
+
self.parcae_log_delta = nn.Parameter(
|
| 598 |
+
torch.full(
|
| 599 |
+
(d_pair,), _inverse_softplus(parcae_delta_init), dtype=torch.float32
|
| 600 |
+
)
|
| 601 |
+
)
|
| 602 |
+
self.parcae_b_cont = nn.Parameter(torch.eye(d_pair))
|
| 603 |
+
self.parcae_readout = nn.Linear(d_pair, d_pair, bias=False)
|
| 604 |
+
nn.init.eye_(self.parcae_readout.weight)
|
| 605 |
+
self.parcae_coda = FoldingTrunk(
|
| 606 |
+
n_layers=config.parcae.coda_n_layers, d_pair=d_pair, expansion_ratio=4
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# Heads --------------------------------------------------------------
|
| 610 |
+
self.structure_head = DiffusionStructureHead(config)
|
| 611 |
+
self.distogram_head = nn.Linear(
|
| 612 |
+
d_pair, config.structure_head.distogram_bins, bias=True
|
| 613 |
+
)
|
| 614 |
+
self.confidence_head = ConfidenceHead(config)
|
| 615 |
+
|
| 616 |
+
msa_cfg = config.msa_encoder
|
| 617 |
+
self.msa_encoder = None
|
| 618 |
+
if msa_cfg.enabled:
|
| 619 |
+
self.msa_encoder = MSAEncoder(
|
| 620 |
+
d_msa=msa_cfg.d_msa,
|
| 621 |
+
d_pair=d_pair,
|
| 622 |
+
d_inputs=d_inputs,
|
| 623 |
+
d_hidden=msa_cfg.d_hidden,
|
| 624 |
+
n_layers=msa_cfg.n_layers,
|
| 625 |
+
n_heads_msa=msa_cfg.n_heads_msa,
|
| 626 |
+
msa_head_width=msa_cfg.msa_head_width,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
self.post_init()
|
| 630 |
+
|
| 631 |
+
def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
|
| 632 |
+
"""Load the ESMC LM.
|
| 633 |
+
|
| 634 |
+
``precision``: ``"bf16"`` (default), ``"fp32"``, or ``"fp8"``.
|
| 635 |
+
``"fp8"`` requires H100 + TransformerEngine ≥ 2.x and quantizes
|
| 636 |
+
every TE module's weights to fp8 storage.
|
| 637 |
+
"""
|
| 638 |
+
from .modeling_esmc import ESMCModel
|
| 639 |
+
|
| 640 |
+
dtype_map = {
|
| 641 |
+
"bf16": torch.bfloat16,
|
| 642 |
+
"fp32": torch.float32,
|
| 643 |
+
"fp8": torch.bfloat16, # underlying weights stay bf16, TE re-quantizes to fp8
|
| 644 |
+
}
|
| 645 |
+
if precision not in dtype_map:
|
| 646 |
+
raise ValueError(
|
| 647 |
+
f"precision must be one of {list(dtype_map)}, got {precision!r}"
|
| 648 |
+
)
|
| 649 |
+
dtype = dtype_map[precision]
|
| 650 |
+
|
| 651 |
+
esmc = (
|
| 652 |
+
ESMCModel.from_pretrained(esmc_model_path)
|
| 653 |
+
.to(device=self.device, dtype=dtype)
|
| 654 |
+
.eval()
|
| 655 |
+
)
|
| 656 |
+
for p in esmc.parameters():
|
| 657 |
+
p.requires_grad_(False)
|
| 658 |
+
|
| 659 |
+
if precision == "fp8":
|
| 660 |
+
if not TE_AVAILABLE:
|
| 661 |
+
raise RuntimeError(
|
| 662 |
+
"transformer_engine is not available; cannot use fp8."
|
| 663 |
+
)
|
| 664 |
+
with torch.no_grad():
|
| 665 |
+
_convert_te_modules_to_fp8_inplace(esmc)
|
| 666 |
+
self._esmc_fp8 = True
|
| 667 |
+
else:
|
| 668 |
+
self._esmc_fp8 = False
|
| 669 |
+
|
| 670 |
+
self._esmc = esmc
|
| 671 |
+
|
| 672 |
+
@classmethod
|
| 673 |
+
def from_pretrained(
|
| 674 |
+
cls, pretrained_model_name_or_path, *args, load_esmc: bool = True, **kwargs
|
| 675 |
+
):
|
| 676 |
+
if cls is ESMFold2Model and "config" not in kwargs:
|
| 677 |
+
config = ESMFold2Config.from_pretrained(
|
| 678 |
+
pretrained_model_name_or_path, **kwargs
|
| 679 |
+
)
|
| 680 |
+
if config.type == "experimental":
|
| 681 |
+
experimental_module = importlib.import_module(
|
| 682 |
+
f"{__package__}.modeling_esmfold2_experimental"
|
| 683 |
+
)
|
| 684 |
+
return experimental_module.ESMFold2ExperimentalModel.from_pretrained(
|
| 685 |
+
pretrained_model_name_or_path,
|
| 686 |
+
*args,
|
| 687 |
+
config=config,
|
| 688 |
+
load_esmc=load_esmc,
|
| 689 |
+
**kwargs,
|
| 690 |
+
)
|
| 691 |
+
kwargs["config"] = config
|
| 692 |
+
# Pop the precision knob before forwarding to the HF loader.
|
| 693 |
+
esmc_precision = kwargs.pop("esmc_precision", "bf16")
|
| 694 |
+
model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
| 695 |
+
if load_esmc:
|
| 696 |
+
model.load_esmc(model.config.esmc_id, precision=esmc_precision)
|
| 697 |
+
return model
|
| 698 |
+
|
| 699 |
+
def set_kernel_backend(self, backend: str | None) -> None:
|
| 700 |
+
"""Select kernel backend.
|
| 701 |
+
|
| 702 |
+
Args:
|
| 703 |
+
backend: ``None`` (reference path), ``"fused"`` (vendored Triton
|
| 704 |
+
kernels), or ``"cuequivariance"`` (cuequivariance kernels
|
| 705 |
+
where applicable; vanilla python fallback otherwise).
|
| 706 |
+
"""
|
| 707 |
+
self.folding_trunk.set_kernel_backend(backend)
|
| 708 |
+
if self.lm_encoder is not None:
|
| 709 |
+
self.lm_encoder.set_kernel_backend(backend)
|
| 710 |
+
self.parcae_coda.set_kernel_backend(backend)
|
| 711 |
+
self.confidence_head.set_kernel_backend(backend)
|
| 712 |
+
self.structure_head.set_kernel_backend(backend)
|
| 713 |
+
|
| 714 |
+
def apply_torch_compile(
|
| 715 |
+
self, mode: str = "fixed_seqlen", dynamic: bool | None = None
|
| 716 |
+
) -> None:
|
| 717 |
+
"""Compile L²-heavy blocks. ``mode='fixed_seqlen'`` recompiles per L; ``'dynamic_seqlen'`` compiles once.
|
| 718 |
+
|
| 719 |
+
Does NOT stack with our Triton kernels — call ``set_kernel_backend(None)``
|
| 720 |
+
before compiling.
|
| 721 |
+
"""
|
| 722 |
+
import torch._dynamo
|
| 723 |
+
|
| 724 |
+
torch._dynamo.config.cache_size_limit = 512 # type: ignore[attr-defined]
|
| 725 |
+
torch._dynamo.config.accumulated_cache_size_limit = 512 # type: ignore[attr-defined]
|
| 726 |
+
# capture_scalar_outputs avoids graph breaks at .item() in atom-attention path.
|
| 727 |
+
torch._dynamo.config.capture_scalar_outputs = True # type: ignore[attr-defined]
|
| 728 |
+
|
| 729 |
+
if dynamic is None:
|
| 730 |
+
dynamic = mode == "dynamic_seqlen"
|
| 731 |
+
kwargs: dict = {"dynamic": dynamic}
|
| 732 |
+
|
| 733 |
+
from .modeling_esmfold2_common import (
|
| 734 |
+
DiffusionModule,
|
| 735 |
+
DiffusionTransformer,
|
| 736 |
+
PairUpdateBlock,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
compile_targets = (
|
| 740 |
+
PairUpdateBlock,
|
| 741 |
+
DiffusionTransformer,
|
| 742 |
+
DiffusionModule,
|
| 743 |
+
MSAEncoderBlock,
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def _maybe_compile(module: nn.Module) -> None:
|
| 747 |
+
if isinstance(module, compile_targets):
|
| 748 |
+
module.forward = torch.compile(module.forward, **kwargs) # type: ignore[assignment]
|
| 749 |
+
|
| 750 |
+
self.apply(_maybe_compile)
|
| 751 |
+
|
| 752 |
+
def set_chunk_size(self, chunk_size: int | None) -> None:
|
| 753 |
+
self.folding_trunk.set_chunk_size(chunk_size)
|
| 754 |
+
if self.lm_encoder is not None:
|
| 755 |
+
self.lm_encoder.set_chunk_size(chunk_size)
|
| 756 |
+
self.parcae_coda.set_chunk_size(chunk_size)
|
| 757 |
+
self.confidence_head.set_chunk_size(chunk_size)
|
| 758 |
+
if self.msa_encoder is not None:
|
| 759 |
+
self.msa_encoder.set_chunk_size(chunk_size)
|
| 760 |
+
|
| 761 |
+
def _compute_lm_hidden_states(
|
| 762 |
+
self,
|
| 763 |
+
input_ids: Tensor,
|
| 764 |
+
asym_id: Tensor,
|
| 765 |
+
residue_index: Tensor,
|
| 766 |
+
mol_type: Tensor,
|
| 767 |
+
tok_mask: Tensor,
|
| 768 |
+
) -> Tensor:
|
| 769 |
+
assert self._esmc is not None
|
| 770 |
+
# fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
|
| 771 |
+
pad_to = 8 if self._esmc_fp8 else None
|
| 772 |
+
with _lm_precision_context(self._esmc_fp8):
|
| 773 |
+
return compute_lm_hidden_states(
|
| 774 |
+
self._esmc,
|
| 775 |
+
input_ids,
|
| 776 |
+
asym_id,
|
| 777 |
+
residue_index,
|
| 778 |
+
mol_type,
|
| 779 |
+
tok_mask,
|
| 780 |
+
pad_to_multiple=pad_to,
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
|
| 784 |
+
delta = F.softplus(self.parcae_log_delta)
|
| 785 |
+
a = torch.exp(-delta * torch.exp(self.parcae_log_a))
|
| 786 |
+
b = delta[:, None] * self.parcae_b_cont
|
| 787 |
+
return a, b
|
| 788 |
+
|
| 789 |
+
def _init_pair_state(self, ref: Tensor) -> Tensor:
|
| 790 |
+
std = math.sqrt(2.0 / (5.0 * ref.shape[-1]))
|
| 791 |
+
state = torch.empty_like(ref, dtype=torch.float32)
|
| 792 |
+
nn.init.trunc_normal_(state, mean=0.0, std=std, a=-3 * std, b=3 * std)
|
| 793 |
+
return state.to(dtype=ref.dtype)
|
| 794 |
+
|
| 795 |
+
def _run_one_loop(
|
| 796 |
+
self,
|
| 797 |
+
z: Tensor,
|
| 798 |
+
z_init: Tensor,
|
| 799 |
+
lm_z: Tensor | None,
|
| 800 |
+
_msa_kwargs: dict | None,
|
| 801 |
+
pair_mask: Tensor,
|
| 802 |
+
a: Tensor,
|
| 803 |
+
b_mat: Tensor,
|
| 804 |
+
total_steps: int,
|
| 805 |
+
) -> Tensor:
|
| 806 |
+
# Helper method (not inline) so per-iter locals free on return —
|
| 807 |
+
# otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
|
| 808 |
+
# training=True forces dropout under eval(), matching the per-loop
|
| 809 |
+
# dropout strategy used at train time.
|
| 810 |
+
lm_cfg = self.config.lm_encoder
|
| 811 |
+
_per_loop_lm_dropout = (
|
| 812 |
+
lm_z is not None
|
| 813 |
+
and getattr(lm_cfg, "per_loop_lm_dropout", False)
|
| 814 |
+
and getattr(lm_cfg, "lm_dropout", 0.0) > 0.0
|
| 815 |
+
)
|
| 816 |
+
_lm_dropout_p = getattr(lm_cfg, "lm_dropout", 0.0)
|
| 817 |
+
|
| 818 |
+
for _ in range(total_steps):
|
| 819 |
+
if _per_loop_lm_dropout:
|
| 820 |
+
assert lm_z is not None # narrowed by _per_loop_lm_dropout
|
| 821 |
+
lm_z_i: Tensor | None = F.dropout(lm_z, p=_lm_dropout_p, training=True)
|
| 822 |
+
else:
|
| 823 |
+
lm_z_i = lm_z
|
| 824 |
+
|
| 825 |
+
refined_lm_z: Tensor | None = None
|
| 826 |
+
if lm_z_i is not None and self.lm_encoder is not None:
|
| 827 |
+
refined_lm_z = self.lm_encoder(
|
| 828 |
+
lm_z_i.to(z_init.dtype), pair_attention_mask=pair_mask
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
z_inject_pair = z_init
|
| 832 |
+
if lm_z_i is not None and self.lm_encoder is None:
|
| 833 |
+
z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
|
| 834 |
+
|
| 835 |
+
if self.msa_encoder is not None and _msa_kwargs is not None:
|
| 836 |
+
msa_pair = self.msa_encoder(x_pair=z_inject_pair, **_msa_kwargs).to(
|
| 837 |
+
z_inject_pair.dtype
|
| 838 |
+
)
|
| 839 |
+
z_inject_pair = (
|
| 840 |
+
msa_pair
|
| 841 |
+
if self.config.msa_encoder_overwrite
|
| 842 |
+
else (z_inject_pair + msa_pair)
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
if refined_lm_z is not None:
|
| 846 |
+
z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
|
| 847 |
+
|
| 848 |
+
injected_pair = self.parcae_input_norm(z_inject_pair)
|
| 849 |
+
z = a * z + F.linear(injected_pair.to(z.dtype), b_mat)
|
| 850 |
+
z = self.folding_trunk(z, pair_attention_mask=pair_mask)
|
| 851 |
+
|
| 852 |
+
return z
|
| 853 |
+
|
| 854 |
+
@torch.inference_mode()
|
| 855 |
+
def forward(
|
| 856 |
+
self,
|
| 857 |
+
token_index: Tensor,
|
| 858 |
+
residue_index: Tensor,
|
| 859 |
+
asym_id: Tensor,
|
| 860 |
+
sym_id: Tensor,
|
| 861 |
+
entity_id: Tensor,
|
| 862 |
+
mol_type: Tensor,
|
| 863 |
+
res_type: Tensor,
|
| 864 |
+
token_bonds: Tensor,
|
| 865 |
+
token_attention_mask: Tensor,
|
| 866 |
+
ref_pos: Tensor,
|
| 867 |
+
ref_element: Tensor,
|
| 868 |
+
ref_charge: Tensor,
|
| 869 |
+
ref_atom_name_chars: Tensor,
|
| 870 |
+
ref_space_uid: Tensor,
|
| 871 |
+
atom_attention_mask: Tensor,
|
| 872 |
+
atom_to_token: Tensor,
|
| 873 |
+
distogram_atom_idx: Tensor,
|
| 874 |
+
deletion_mean: Tensor | None = None,
|
| 875 |
+
msa: Tensor | None = None,
|
| 876 |
+
has_deletion: Tensor | None = None,
|
| 877 |
+
deletion_value: Tensor | None = None,
|
| 878 |
+
msa_attention_mask: Tensor | None = None,
|
| 879 |
+
input_ids: Tensor | None = None,
|
| 880 |
+
lm_hidden_states: Tensor | None = None,
|
| 881 |
+
num_loops: int | None = None,
|
| 882 |
+
num_diffusion_samples: int | None = None,
|
| 883 |
+
num_sampling_steps: int | None = None,
|
| 884 |
+
**kwargs,
|
| 885 |
+
) -> dict[str, Tensor]:
|
| 886 |
+
tok_mask = token_attention_mask
|
| 887 |
+
atm_mask = atom_attention_mask
|
| 888 |
+
disto_idx = distogram_atom_idx
|
| 889 |
+
|
| 890 |
+
n_loops: int = num_loops if num_loops is not None else self.config.num_loops
|
| 891 |
+
n_samples: int = (
|
| 892 |
+
num_diffusion_samples
|
| 893 |
+
if num_diffusion_samples is not None
|
| 894 |
+
else self.config.num_diffusion_samples
|
| 895 |
+
)
|
| 896 |
+
total_steps = max(1, n_loops + 1)
|
| 897 |
+
|
| 898 |
+
if res_type.dim() == 2:
|
| 899 |
+
res_type_oh = F.one_hot(res_type.long(), num_classes=NUM_RES_TYPES).float()
|
| 900 |
+
res_type_oh = res_type_oh * tok_mask.unsqueeze(-1).float()
|
| 901 |
+
else:
|
| 902 |
+
res_type_oh = res_type.float()
|
| 903 |
+
|
| 904 |
+
if msa is not None:
|
| 905 |
+
msa_oh_profile = F.one_hot(msa.long(), num_classes=NUM_RES_TYPES).float()
|
| 906 |
+
if msa_attention_mask is not None:
|
| 907 |
+
mask_f = msa_attention_mask.float().unsqueeze(-1)
|
| 908 |
+
msa_oh_profile = msa_oh_profile * mask_f
|
| 909 |
+
valid_seq_count = msa_attention_mask.float().sum(dim=1).clamp(min=1)
|
| 910 |
+
profile = msa_oh_profile.sum(dim=1) / valid_seq_count.unsqueeze(-1)
|
| 911 |
+
else:
|
| 912 |
+
profile = msa_oh_profile.mean(dim=1)
|
| 913 |
+
else:
|
| 914 |
+
profile = res_type_oh
|
| 915 |
+
|
| 916 |
+
if deletion_mean is None:
|
| 917 |
+
deletion_mean = torch.zeros(
|
| 918 |
+
res_type.shape[0], res_type.shape[1], device=res_type.device
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
ref_element_oh = F.one_hot(
|
| 922 |
+
ref_element.long(), num_classes=MAX_ATOMIC_NUMBER
|
| 923 |
+
).float()
|
| 924 |
+
ref_atom_name_chars_oh = F.one_hot(
|
| 925 |
+
ref_atom_name_chars.long(), num_classes=CHAR_VOCAB_SIZE
|
| 926 |
+
).float()
|
| 927 |
+
# Bias-free downstream Linears require zeroed padding.
|
| 928 |
+
atm_mask_f = atm_mask.float()
|
| 929 |
+
ref_element_oh = ref_element_oh * atm_mask_f.unsqueeze(-1)
|
| 930 |
+
ref_atom_name_chars_oh = ref_atom_name_chars_oh * atm_mask_f.unsqueeze(
|
| 931 |
+
-1
|
| 932 |
+
).unsqueeze(-1)
|
| 933 |
+
atom_to_token = atom_to_token * atm_mask.long()
|
| 934 |
+
|
| 935 |
+
use_amp = ref_pos.device.type == "cuda"
|
| 936 |
+
with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
|
| 937 |
+
x_inputs = self.inputs_embedder(
|
| 938 |
+
aatype=res_type_oh,
|
| 939 |
+
profile=profile.float(),
|
| 940 |
+
deletion_mean=deletion_mean.float(),
|
| 941 |
+
ref_pos=ref_pos,
|
| 942 |
+
atom_attention_mask=atm_mask,
|
| 943 |
+
ref_space_uid=ref_space_uid,
|
| 944 |
+
ref_charge=ref_charge,
|
| 945 |
+
ref_element=ref_element_oh,
|
| 946 |
+
ref_atom_name_chars=ref_atom_name_chars_oh,
|
| 947 |
+
atom_to_token=atom_to_token,
|
| 948 |
+
)
|
| 949 |
+
|
| 950 |
+
z_init = self.z_init_1(x_inputs).unsqueeze(2) + self.z_init_2(
|
| 951 |
+
x_inputs
|
| 952 |
+
).unsqueeze(1)
|
| 953 |
+
|
| 954 |
+
relative_position_encoding = self.rel_pos(
|
| 955 |
+
residue_index=residue_index,
|
| 956 |
+
asym_id=asym_id,
|
| 957 |
+
sym_id=sym_id,
|
| 958 |
+
entity_id=entity_id,
|
| 959 |
+
token_index=token_index,
|
| 960 |
+
)
|
| 961 |
+
token_bonds_encoding = self.token_bonds(token_bonds.float())
|
| 962 |
+
z_init = z_init + relative_position_encoding + token_bonds_encoding
|
| 963 |
+
|
| 964 |
+
if (
|
| 965 |
+
lm_hidden_states is None
|
| 966 |
+
and input_ids is not None
|
| 967 |
+
and self._esmc is not None
|
| 968 |
+
):
|
| 969 |
+
lm_hidden_states = self._compute_lm_hidden_states(
|
| 970 |
+
input_ids, asym_id, residue_index, mol_type, tok_mask
|
| 971 |
+
)
|
| 972 |
+
lm_z: Tensor | None = None
|
| 973 |
+
if lm_hidden_states is not None:
|
| 974 |
+
lm_z = self.language_model(lm_hidden_states.detach())
|
| 975 |
+
del lm_hidden_states
|
| 976 |
+
|
| 977 |
+
pair_mask = tok_mask[:, :, None].float() * tok_mask[:, None, :].float()
|
| 978 |
+
|
| 979 |
+
z = self._init_pair_state(z_init)
|
| 980 |
+
|
| 981 |
+
a, b = self._discretized_dynamics()
|
| 982 |
+
a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
|
| 983 |
+
b_mat = b.to(device=z.device, dtype=z.dtype)
|
| 984 |
+
|
| 985 |
+
_msa_kwargs: dict | None = None
|
| 986 |
+
if self.msa_encoder is not None and msa is not None:
|
| 987 |
+
B_msa, M, L_msa = msa.shape
|
| 988 |
+
msa_oh = F.one_hot(
|
| 989 |
+
msa.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
|
| 990 |
+
).float()
|
| 991 |
+
msa_attn = (
|
| 992 |
+
msa_attention_mask.permute(0, 2, 1).float()
|
| 993 |
+
if msa_attention_mask is not None
|
| 994 |
+
else tok_mask[:, :, None].expand(-1, -1, M).float()
|
| 995 |
+
)
|
| 996 |
+
# Bias-free MSAEncoder.embed requires zeroed padding.
|
| 997 |
+
msa_oh = msa_oh * msa_attn.unsqueeze(-1)
|
| 998 |
+
hd = (
|
| 999 |
+
has_deletion.permute(0, 2, 1).float()
|
| 1000 |
+
if has_deletion is not None
|
| 1001 |
+
else torch.zeros(B_msa, L_msa, M, device=msa.device)
|
| 1002 |
+
)
|
| 1003 |
+
dv = (
|
| 1004 |
+
deletion_value.permute(0, 2, 1).float()
|
| 1005 |
+
if deletion_value is not None
|
| 1006 |
+
else torch.zeros(B_msa, L_msa, M, device=msa.device)
|
| 1007 |
+
)
|
| 1008 |
+
_msa_kwargs = dict(
|
| 1009 |
+
x_inputs=x_inputs,
|
| 1010 |
+
msa_oh=msa_oh,
|
| 1011 |
+
has_deletion=hd,
|
| 1012 |
+
deletion_value=dv,
|
| 1013 |
+
msa_attention_mask=msa_attn,
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
# Method call (not inline loop) frees per-iter L²×c_z locals.
|
| 1017 |
+
z = self._run_one_loop(
|
| 1018 |
+
z=z,
|
| 1019 |
+
z_init=z_init,
|
| 1020 |
+
lm_z=lm_z,
|
| 1021 |
+
_msa_kwargs=_msa_kwargs,
|
| 1022 |
+
pair_mask=pair_mask,
|
| 1023 |
+
a=a,
|
| 1024 |
+
b_mat=b_mat,
|
| 1025 |
+
total_steps=total_steps,
|
| 1026 |
+
)
|
| 1027 |
+
del z_init, lm_z, _msa_kwargs, a, b_mat
|
| 1028 |
+
|
| 1029 |
+
z = self.parcae_readout(z)
|
| 1030 |
+
z = self.parcae_coda(z, pair_attention_mask=pair_mask)
|
| 1031 |
+
|
| 1032 |
+
z = z.float()
|
| 1033 |
+
distogram_logits = self.distogram_head(z + z.transpose(-2, -3))
|
| 1034 |
+
|
| 1035 |
+
structure_output = self.structure_head.sample(
|
| 1036 |
+
z_trunk=z,
|
| 1037 |
+
s_inputs=x_inputs,
|
| 1038 |
+
s_trunk=None,
|
| 1039 |
+
relative_position_encoding=relative_position_encoding,
|
| 1040 |
+
ref_pos=ref_pos,
|
| 1041 |
+
ref_charge=ref_charge,
|
| 1042 |
+
ref_mask=atm_mask,
|
| 1043 |
+
ref_element=ref_element_oh,
|
| 1044 |
+
ref_atom_name_chars=ref_atom_name_chars_oh,
|
| 1045 |
+
ref_space_uid=ref_space_uid,
|
| 1046 |
+
tok_idx=atom_to_token,
|
| 1047 |
+
asym_id=asym_id,
|
| 1048 |
+
residue_index=residue_index,
|
| 1049 |
+
entity_id=entity_id,
|
| 1050 |
+
token_index=token_index,
|
| 1051 |
+
sym_id=sym_id,
|
| 1052 |
+
token_attention_mask=tok_mask,
|
| 1053 |
+
num_diffusion_samples=n_samples,
|
| 1054 |
+
num_sampling_steps=num_sampling_steps,
|
| 1055 |
+
return_atom_repr=False,
|
| 1056 |
+
denoising_early_exit_rmsd=None,
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
+
sample_coords = structure_output["sample_atom_coords"]
|
| 1060 |
+
assert sample_coords is not None
|
| 1061 |
+
output: dict[str, Tensor] = {"distogram_logits": distogram_logits}
|
| 1062 |
+
output["sample_atom_coords"] = sample_coords
|
| 1063 |
+
|
| 1064 |
+
confidence_output = self.confidence_head(
|
| 1065 |
+
s_inputs=x_inputs.detach(),
|
| 1066 |
+
z=z.detach().float(),
|
| 1067 |
+
x_pred=sample_coords.detach(),
|
| 1068 |
+
distogram_atom_idx=disto_idx,
|
| 1069 |
+
token_attention_mask=tok_mask,
|
| 1070 |
+
atom_to_token=atom_to_token,
|
| 1071 |
+
atom_attention_mask=atm_mask,
|
| 1072 |
+
asym_id=asym_id,
|
| 1073 |
+
mol_type=mol_type,
|
| 1074 |
+
num_diffusion_samples=n_samples,
|
| 1075 |
+
relative_position_encoding=relative_position_encoding.detach(),
|
| 1076 |
+
token_bonds_encoding=token_bonds_encoding.detach(),
|
| 1077 |
+
)
|
| 1078 |
+
output.update(confidence_output)
|
| 1079 |
+
output["atom_pad_mask"] = (
|
| 1080 |
+
atm_mask.unsqueeze(0) if atm_mask.dim() == 1 else atm_mask
|
| 1081 |
+
)
|
| 1082 |
+
output["residue_index"] = residue_index
|
| 1083 |
+
output["entity_id"] = entity_id
|
| 1084 |
+
return output
|
| 1085 |
+
|
| 1086 |
+
@torch.no_grad()
|
| 1087 |
+
def infer_protein(self, seq: str, **forward_kwargs) -> dict:
|
| 1088 |
+
from .protein_utils import prepare_protein_features
|
| 1089 |
+
|
| 1090 |
+
features = prepare_protein_features(seq)
|
| 1091 |
+
features = {k: v.to(self.device) for k, v in features.items()}
|
| 1092 |
+
return self(**features, **forward_kwargs)
|
| 1093 |
+
|
| 1094 |
+
@property
|
| 1095 |
+
def input_builder(self):
|
| 1096 |
+
if self._esmfold2_input_builder is None:
|
| 1097 |
+
from .esmfold2_processor import ESMFold2InputBuilder
|
| 1098 |
+
|
| 1099 |
+
self._esmfold2_input_builder = ESMFold2InputBuilder()
|
| 1100 |
+
return self._esmfold2_input_builder
|
| 1101 |
+
|
| 1102 |
+
@property
|
| 1103 |
+
def input_types(self):
|
| 1104 |
+
from . import esmfold2_types
|
| 1105 |
+
|
| 1106 |
+
return esmfold2_types
|
| 1107 |
+
|
| 1108 |
+
def prepare_structure_input(self, input, seed: int | None = None):
|
| 1109 |
+
return self.input_builder.prepare_input(input, seed=seed, device=self.device)
|
| 1110 |
+
|
| 1111 |
+
def fold(
|
| 1112 |
+
self,
|
| 1113 |
+
input,
|
| 1114 |
+
*,
|
| 1115 |
+
num_loops: int = 3,
|
| 1116 |
+
num_sampling_steps: int = 50,
|
| 1117 |
+
num_diffusion_samples: int = 1,
|
| 1118 |
+
seed: int | None = None,
|
| 1119 |
+
noise_scale: float | None = None,
|
| 1120 |
+
step_scale: float | None = None,
|
| 1121 |
+
max_inference_sigma: int | None = None,
|
| 1122 |
+
early_exit: bool = False,
|
| 1123 |
+
complex_id: str = "pred",
|
| 1124 |
+
):
|
| 1125 |
+
return self.input_builder.fold(
|
| 1126 |
+
self,
|
| 1127 |
+
input,
|
| 1128 |
+
num_loops=num_loops,
|
| 1129 |
+
num_sampling_steps=num_sampling_steps,
|
| 1130 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 1131 |
+
seed=seed,
|
| 1132 |
+
noise_scale=noise_scale,
|
| 1133 |
+
step_scale=step_scale,
|
| 1134 |
+
max_inference_sigma=max_inference_sigma,
|
| 1135 |
+
early_exit=early_exit,
|
| 1136 |
+
complex_id=complex_id,
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
def fold_protein(
|
| 1140 |
+
self,
|
| 1141 |
+
sequence: str,
|
| 1142 |
+
*,
|
| 1143 |
+
chain_id: str = "A",
|
| 1144 |
+
num_loops: int = 3,
|
| 1145 |
+
num_sampling_steps: int = 50,
|
| 1146 |
+
num_diffusion_samples: int = 1,
|
| 1147 |
+
seed: int | None = None,
|
| 1148 |
+
complex_id: str = "pred",
|
| 1149 |
+
):
|
| 1150 |
+
from .esmfold2_types import ProteinInput, StructurePredictionInput
|
| 1151 |
+
|
| 1152 |
+
input = StructurePredictionInput(
|
| 1153 |
+
sequences=[ProteinInput(id=chain_id, sequence=sequence)]
|
| 1154 |
+
)
|
| 1155 |
+
return self.fold(
|
| 1156 |
+
input,
|
| 1157 |
+
num_loops=num_loops,
|
| 1158 |
+
num_sampling_steps=num_sampling_steps,
|
| 1159 |
+
num_diffusion_samples=num_diffusion_samples,
|
| 1160 |
+
seed=seed,
|
| 1161 |
+
complex_id=complex_id,
|
| 1162 |
+
)
|
| 1163 |
+
|
| 1164 |
+
@staticmethod
|
| 1165 |
+
def result_to_cif(result) -> str:
|
| 1166 |
+
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
|
| 1167 |
+
return result.complex.to_mmcif()
|
| 1168 |
+
|
| 1169 |
+
@staticmethod
|
| 1170 |
+
def result_to_pdb(result) -> str:
|
| 1171 |
+
assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
|
| 1172 |
+
return result.complex.to_protein_complex().to_pdb_string()
|
| 1173 |
+
|
| 1174 |
+
def save_as_cif(self, result, output_path: str | Path) -> None:
|
| 1175 |
+
Path(output_path).write_text(self.result_to_cif(result))
|
| 1176 |
+
|
| 1177 |
+
def save_as_pdb(self, result, output_path: str | Path) -> None:
|
| 1178 |
+
Path(output_path).write_text(self.result_to_pdb(result))
|
| 1179 |
+
|
| 1180 |
+
def infer_protein_as_cif(self, seq: str, **forward_kwargs) -> str:
|
| 1181 |
+
return self.result_to_cif(self.fold_protein(seq, **forward_kwargs))
|
| 1182 |
+
|
| 1183 |
+
def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str:
|
| 1184 |
+
return self.result_to_pdb(self.fold_protein(seq, **forward_kwargs))
|
| 1185 |
+
|
| 1186 |
+
|
| 1187 |
+
class MSAEncoderBlock(nn.Module):
|
| 1188 |
+
"""One MSA encoder block: OPM into pair, MSA pair-weighted averaging, triangle update."""
|
| 1189 |
+
|
| 1190 |
+
def __init__(
|
| 1191 |
+
self,
|
| 1192 |
+
d_msa: int,
|
| 1193 |
+
d_pair: int,
|
| 1194 |
+
d_hidden: int,
|
| 1195 |
+
n_heads_msa: int,
|
| 1196 |
+
msa_head_width: int,
|
| 1197 |
+
is_final_block: bool = False,
|
| 1198 |
+
) -> None:
|
| 1199 |
+
super().__init__()
|
| 1200 |
+
self.is_final_block = is_final_block
|
| 1201 |
+
self.outer_product_mean = OuterProductMean(d_msa, d_hidden, d_pair)
|
| 1202 |
+
if not is_final_block:
|
| 1203 |
+
self.msa_pair_weighted_averaging = MSAPairWeightedAveraging(
|
| 1204 |
+
d_msa, d_pair, n_heads_msa, msa_head_width
|
| 1205 |
+
)
|
| 1206 |
+
self.msa_transition = PairTransition(d_msa, expansion_ratio=4)
|
| 1207 |
+
self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
|
| 1208 |
+
self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
|
| 1209 |
+
self.pair_transition = PairTransition(d_pair, expansion_ratio=4)
|
| 1210 |
+
|
| 1211 |
+
def set_chunk_size(self, chunk_size: int | None) -> None:
|
| 1212 |
+
self.outer_product_mean.set_chunk_size(chunk_size)
|
| 1213 |
+
self.tri_mul_out.set_chunk_size(chunk_size)
|
| 1214 |
+
self.tri_mul_in.set_chunk_size(chunk_size)
|
| 1215 |
+
if not self.is_final_block:
|
| 1216 |
+
self.msa_transition.set_chunk_size(chunk_size)
|
| 1217 |
+
self.pair_transition.set_chunk_size(chunk_size)
|
| 1218 |
+
|
| 1219 |
+
def forward(
|
| 1220 |
+
self,
|
| 1221 |
+
m: Tensor,
|
| 1222 |
+
pair: Tensor,
|
| 1223 |
+
msa_attention_mask: Tensor,
|
| 1224 |
+
pair_attention_mask: Tensor,
|
| 1225 |
+
) -> tuple[Tensor, Tensor]:
|
| 1226 |
+
pair = pair + self.outer_product_mean(m, msa_attention_mask)
|
| 1227 |
+
if not self.is_final_block:
|
| 1228 |
+
m = m + self.msa_pair_weighted_averaging(m, pair, pair_attention_mask)
|
| 1229 |
+
m = m + self.msa_transition(m)
|
| 1230 |
+
pair = pair + self.tri_mul_out(pair, mask=pair_attention_mask)
|
| 1231 |
+
pair = pair + self.tri_mul_in(pair, mask=pair_attention_mask)
|
| 1232 |
+
pair = pair + self.pair_transition(pair)
|
| 1233 |
+
return m, pair
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
class MSAEncoder(nn.Module):
|
| 1237 |
+
"""Stack of [`MSAEncoderBlock`] layers that conditions the pair on an MSA."""
|
| 1238 |
+
|
| 1239 |
+
def __init__(
|
| 1240 |
+
self,
|
| 1241 |
+
d_msa: int,
|
| 1242 |
+
d_pair: int,
|
| 1243 |
+
d_inputs: int,
|
| 1244 |
+
d_hidden: int = 32,
|
| 1245 |
+
n_layers: int = 4,
|
| 1246 |
+
n_heads_msa: int = 8,
|
| 1247 |
+
msa_head_width: int = 16,
|
| 1248 |
+
) -> None:
|
| 1249 |
+
super().__init__()
|
| 1250 |
+
self.embed = nn.Linear(35, d_msa, bias=False)
|
| 1251 |
+
self.project_inputs = nn.Linear(d_inputs, d_msa, bias=False)
|
| 1252 |
+
self.blocks = nn.ModuleList(
|
| 1253 |
+
[
|
| 1254 |
+
MSAEncoderBlock(
|
| 1255 |
+
d_msa=d_msa,
|
| 1256 |
+
d_pair=d_pair,
|
| 1257 |
+
d_hidden=d_hidden,
|
| 1258 |
+
n_heads_msa=n_heads_msa,
|
| 1259 |
+
msa_head_width=msa_head_width,
|
| 1260 |
+
is_final_block=(i == n_layers - 1),
|
| 1261 |
+
)
|
| 1262 |
+
for i in range(n_layers)
|
| 1263 |
+
]
|
| 1264 |
+
)
|
| 1265 |
+
|
| 1266 |
+
def set_chunk_size(self, chunk_size: int | None) -> None:
|
| 1267 |
+
for block in self.blocks:
|
| 1268 |
+
cast(MSAEncoderBlock, block).set_chunk_size(chunk_size)
|
| 1269 |
+
|
| 1270 |
+
def forward(
|
| 1271 |
+
self,
|
| 1272 |
+
x_pair: Tensor,
|
| 1273 |
+
x_inputs: Tensor,
|
| 1274 |
+
msa_oh: Tensor,
|
| 1275 |
+
has_deletion: Tensor,
|
| 1276 |
+
deletion_value: Tensor,
|
| 1277 |
+
msa_attention_mask: Tensor,
|
| 1278 |
+
) -> Tensor:
|
| 1279 |
+
# All inputs are pre-transposed to [B, L, M, ...] before calling.
|
| 1280 |
+
m_feat = torch.cat(
|
| 1281 |
+
[msa_oh, has_deletion.unsqueeze(-1), deletion_value.unsqueeze(-1)], dim=-1
|
| 1282 |
+
)
|
| 1283 |
+
m = self.embed(m_feat) + self.project_inputs(x_inputs).unsqueeze(2)
|
| 1284 |
+
tok_mask = msa_attention_mask[:, :, 0].bool()
|
| 1285 |
+
pair_attention_mask = tok_mask.unsqueeze(2) & tok_mask.unsqueeze(1)
|
| 1286 |
+
for block in self.blocks:
|
| 1287 |
+
m, x_pair = block(m, x_pair, msa_attention_mask, pair_attention_mask)
|
| 1288 |
+
return x_pair
|
modeling_esmfold2_common.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
protein_utils.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 Biohub. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Self-contained protein featurization for ESMFold2 inference.
|
| 16 |
+
|
| 17 |
+
Lets ``ESMFold2ExperimentalModel.infer_protein_as_pdb`` fold a protein sequence
|
| 18 |
+
ESMFold-style without the ``esm`` companion package. The featurization
|
| 19 |
+
mirrors ``ESMFold2InputBuilder.prepare_input`` for the protein-only path —
|
| 20 |
+
``test_prepare_protein_features.py`` enforces tensor-exact parity.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import math
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
from torch import Tensor
|
| 29 |
+
|
| 30 |
+
MOL_TYPE_PROTEIN = 0
|
| 31 |
+
PROTEIN_UNK_RES_TYPE = 22
|
| 32 |
+
MSA_GAP_TOKEN_ID = 1
|
| 33 |
+
|
| 34 |
+
PROTEIN_RESIDUE_TO_RES_TYPE: dict[str, int] = {
|
| 35 |
+
"ALA": 2,
|
| 36 |
+
"ARG": 3,
|
| 37 |
+
"ASN": 4,
|
| 38 |
+
"ASP": 5,
|
| 39 |
+
"CYS": 6,
|
| 40 |
+
"GLN": 7,
|
| 41 |
+
"GLU": 8,
|
| 42 |
+
"GLY": 9,
|
| 43 |
+
"HIS": 10,
|
| 44 |
+
"ILE": 11,
|
| 45 |
+
"LEU": 12,
|
| 46 |
+
"LYS": 13,
|
| 47 |
+
"MET": 14,
|
| 48 |
+
"PHE": 15,
|
| 49 |
+
"PRO": 16,
|
| 50 |
+
"SER": 17,
|
| 51 |
+
"THR": 18,
|
| 52 |
+
"TRP": 19,
|
| 53 |
+
"TYR": 20,
|
| 54 |
+
"VAL": 21,
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
PROTEIN_1TO3: dict[str, str] = {
|
| 58 |
+
"A": "ALA",
|
| 59 |
+
"R": "ARG",
|
| 60 |
+
"N": "ASN",
|
| 61 |
+
"D": "ASP",
|
| 62 |
+
"C": "CYS",
|
| 63 |
+
"Q": "GLN",
|
| 64 |
+
"E": "GLU",
|
| 65 |
+
"G": "GLY",
|
| 66 |
+
"H": "HIS",
|
| 67 |
+
"I": "ILE",
|
| 68 |
+
"L": "LEU",
|
| 69 |
+
"K": "LYS",
|
| 70 |
+
"M": "MET",
|
| 71 |
+
"F": "PHE",
|
| 72 |
+
"P": "PRO",
|
| 73 |
+
"S": "SER",
|
| 74 |
+
"T": "THR",
|
| 75 |
+
"W": "TRP",
|
| 76 |
+
"Y": "TYR",
|
| 77 |
+
"V": "VAL",
|
| 78 |
+
"X": "UNK",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
ESM_PROTEIN_VOCAB: dict[str, int] = {
|
| 82 |
+
"L": 4,
|
| 83 |
+
"A": 5,
|
| 84 |
+
"G": 6,
|
| 85 |
+
"V": 7,
|
| 86 |
+
"S": 8,
|
| 87 |
+
"E": 9,
|
| 88 |
+
"R": 10,
|
| 89 |
+
"T": 11,
|
| 90 |
+
"I": 12,
|
| 91 |
+
"D": 13,
|
| 92 |
+
"P": 14,
|
| 93 |
+
"K": 15,
|
| 94 |
+
"Q": 16,
|
| 95 |
+
"N": 17,
|
| 96 |
+
"F": 18,
|
| 97 |
+
"Y": 19,
|
| 98 |
+
"M": 20,
|
| 99 |
+
"H": 21,
|
| 100 |
+
"W": 22,
|
| 101 |
+
"C": 23,
|
| 102 |
+
"X": 3,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
# Heavy atoms per canonical residue, in training-time order.
|
| 106 |
+
PROTEIN_HEAVY_ATOMS: dict[str, list[str]] = {
|
| 107 |
+
"ALA": ["N", "CA", "C", "O", "CB"],
|
| 108 |
+
"ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
|
| 109 |
+
"ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
|
| 110 |
+
"ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
|
| 111 |
+
"CYS": ["N", "CA", "C", "O", "CB", "SG"],
|
| 112 |
+
"GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
|
| 113 |
+
"GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
|
| 114 |
+
"GLY": ["N", "CA", "C", "O"],
|
| 115 |
+
"HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
|
| 116 |
+
"ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
|
| 117 |
+
"LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
|
| 118 |
+
"LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
|
| 119 |
+
"MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
|
| 120 |
+
"PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
|
| 121 |
+
"PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
|
| 122 |
+
"SER": ["N", "CA", "C", "O", "CB", "OG"],
|
| 123 |
+
"THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
|
| 124 |
+
"TRP": [
|
| 125 |
+
"N",
|
| 126 |
+
"CA",
|
| 127 |
+
"C",
|
| 128 |
+
"O",
|
| 129 |
+
"CB",
|
| 130 |
+
"CG",
|
| 131 |
+
"CD1",
|
| 132 |
+
"CD2",
|
| 133 |
+
"NE1",
|
| 134 |
+
"CE2",
|
| 135 |
+
"CE3",
|
| 136 |
+
"CZ2",
|
| 137 |
+
"CZ3",
|
| 138 |
+
"CH2",
|
| 139 |
+
],
|
| 140 |
+
"TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
|
| 141 |
+
"VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
|
| 142 |
+
"UNK": ["N", "CA", "C", "O"],
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
PROTEIN_REF_POS: dict[str, dict[str, tuple[float, float, float]]] = {
|
| 146 |
+
"ALA": {
|
| 147 |
+
"N": (-0.01003183238208294, -1.2073018550872803, -1.0555061101913452),
|
| 148 |
+
"CA": (-0.04190138354897499, 0.17447763681411743, -0.5729365348815918),
|
| 149 |
+
"C": (1.2127548456192017, 0.4737588167190552, 0.19521640241146088),
|
| 150 |
+
"O": (1.9390329122543335, 1.4484562873840332, -0.13759790360927582),
|
| 151 |
+
"CB": (-1.276943325996399, 0.4288230538368225, 0.29937705397605896),
|
| 152 |
+
},
|
| 153 |
+
"ARG": {
|
| 154 |
+
"N": (-2.0170421600341797, 0.6717798113822937, -1.1794233322143555),
|
| 155 |
+
"CA": (-2.0503084659576416, -0.5735036730766296, -0.4097220301628113),
|
| 156 |
+
"C": (-3.469440460205078, -1.0612813234329224, -0.2755832374095917),
|
| 157 |
+
"O": (-3.8218462467193604, -2.1369943618774414, -0.8294969797134399),
|
| 158 |
+
"CB": (-1.4193516969680786, -0.3735991418361664, 0.9852858781814575),
|
| 159 |
+
"CG": (0.11878877878189087, -0.3112654983997345, 0.963895857334137),
|
| 160 |
+
"CD": (0.6643245816230774, 1.0068185329437256, 0.3963329493999481),
|
| 161 |
+
"NE": (2.1090238094329834, 1.0977025032043457, 0.6120952367782593),
|
| 162 |
+
"CZ": (3.098905324935913, 0.3215920031070709, -0.09047172218561172),
|
| 163 |
+
"NH1": (4.461230278015137, 0.3844667971134186, 0.34141138195991516),
|
| 164 |
+
"NH2": (2.7856509685516357, -0.4166366159915924, -1.1148239374160767),
|
| 165 |
+
},
|
| 166 |
+
"ASN": {
|
| 167 |
+
"N": (-0.7595629096031189, 0.7503494620323181, 1.1369825601577759),
|
| 168 |
+
"CA": (-0.76087886095047, 0.23876343667507172, -0.23573364317417145),
|
| 169 |
+
"C": (-1.9211044311523438, -0.6982439160346985, -0.42196929454803467),
|
| 170 |
+
"O": (-2.677666187286377, -0.5753439664840698, -1.4223182201385498),
|
| 171 |
+
"CB": (0.5504899024963379, -0.5078350305557251, -0.5390339493751526),
|
| 172 |
+
"CG": (1.7250099182128906, 0.4264017939567566, -0.5778228640556335),
|
| 173 |
+
"OD1": (1.9470350742340088, 1.1086392402648926, -1.613560438156128),
|
| 174 |
+
"ND2": (2.57365345954895, 0.5730618834495544, 0.5608599781990051),
|
| 175 |
+
},
|
| 176 |
+
"ASP": {
|
| 177 |
+
"N": (-1.8452696800231934, -1.2169504165649414, 0.19437327980995178),
|
| 178 |
+
"CA": (-0.6379959583282471, -0.41974392533302307, 0.41681644320487976),
|
| 179 |
+
"C": (-0.9431572556495667, 1.0356197357177734, 0.18555717170238495),
|
| 180 |
+
"O": (-1.5183608531951904, 1.4045922756195068, -0.8739855885505676),
|
| 181 |
+
"CB": (0.48594576120376587, -0.8970447778701782, -0.5209363698959351),
|
| 182 |
+
"CG": (1.780342936515808, -0.19918935000896454, -0.2310730367898941),
|
| 183 |
+
"OD1": (2.5202910900115967, -0.6044584512710571, 0.7049641013145447),
|
| 184 |
+
"OD2": (2.1454880237579346, 0.9208861589431763, -0.9712985157966614),
|
| 185 |
+
},
|
| 186 |
+
"CYS": {
|
| 187 |
+
"N": (0.0469963513314724, 1.190075159072876, -1.1607273817062378),
|
| 188 |
+
"CA": (0.11344368755817413, -0.09400428831577301, -0.45952197909355164),
|
| 189 |
+
"C": (-1.2652032375335693, -0.6832379698753357, -0.3594406247138977),
|
| 190 |
+
"O": (-1.4631439447402954, -1.8851220607757568, -0.6826791763305664),
|
| 191 |
+
"CB": (0.6919880509376526, 0.09034398198127747, 0.952482283115387),
|
| 192 |
+
"SG": (2.4619927406311035, 0.5235707759857178, 0.9020372629165649),
|
| 193 |
+
},
|
| 194 |
+
"GLN": {
|
| 195 |
+
"N": (-2.370004653930664, -0.9637529850006104, -0.7942749261856079),
|
| 196 |
+
"CA": (-1.370002269744873, -0.6000258922576904, 0.2103111445903778),
|
| 197 |
+
"C": (-1.7545503377914429, 0.7091967463493347, 0.8433493971824646),
|
| 198 |
+
"O": (-1.8520662784576416, 0.7999289631843567, 2.0964975357055664),
|
| 199 |
+
"CB": (0.02040259726345539, -0.5004461407661438, -0.44764479994773865),
|
| 200 |
+
"CG": (1.1377512216567993, -0.28680720925331116, 0.582992434501648),
|
| 201 |
+
"CD": (2.4745187759399414, -0.24800164997577667, -0.09364881366491318),
|
| 202 |
+
"OE1": (3.1685523986816406, -1.2966246604919434, -0.1717153936624527),
|
| 203 |
+
"NE2": (2.947425603866577, 0.9601329565048218, -0.6888364553451538),
|
| 204 |
+
},
|
| 205 |
+
"GLU": {
|
| 206 |
+
"N": (-1.5850872993469238, -1.337684154510498, 0.9490851163864136),
|
| 207 |
+
"CA": (-1.0560977458953857, 0.027459044009447098, 1.0306966304779053),
|
| 208 |
+
"C": (-1.7741456031799316, 0.9664392471313477, 0.09259600937366486),
|
| 209 |
+
"O": (-1.9012441635131836, 2.181349992752075, 0.402479350566864),
|
| 210 |
+
"CB": (0.4706551432609558, 0.048803869634866714, 0.8114414811134338),
|
| 211 |
+
"CG": (0.9133604764938354, -0.4219329059123993, -0.5830985307693481),
|
| 212 |
+
"CD": (2.398822069168091, -0.3097084164619446, -0.7210537791252136),
|
| 213 |
+
"OE1": (3.1389315128326416, -1.274524450302124, -0.39029765129089355),
|
| 214 |
+
"OE2": (2.9647817611694336, 0.8781346082687378, -1.1732689142227173),
|
| 215 |
+
},
|
| 216 |
+
"GLY": {
|
| 217 |
+
"N": (-1.3942985534667969, -0.39875128865242004, -0.3370324671268463),
|
| 218 |
+
"CA": (-0.39974430203437805, 0.5488945245742798, 0.15242962539196014),
|
| 219 |
+
"C": (0.9440054893493652, -0.10314033925533295, 0.19859643280506134),
|
| 220 |
+
"O": (1.3352899551391602, -0.669218122959137, 1.2541258335113525),
|
| 221 |
+
},
|
| 222 |
+
"HIS": {
|
| 223 |
+
"N": (-1.4532867670059204, -1.0689626932144165, 0.881072461605072),
|
| 224 |
+
"CA": (-1.3396095037460327, 0.24797579646110535, 0.24960045516490936),
|
| 225 |
+
"C": (-2.675257921218872, 0.6571555733680725, -0.30441102385520935),
|
| 226 |
+
"O": (-3.1311378479003906, 1.8079776763916016, -0.06785715371370316),
|
| 227 |
+
"CB": (-0.3041955828666687, 0.21721023321151733, -0.8885309100151062),
|
| 228 |
+
"CG": (1.0887513160705566, 0.028941065073013306, -0.36419469118118286),
|
| 229 |
+
"ND1": (1.840459942817688, 1.0411773920059204, 0.29804590344429016),
|
| 230 |
+
"CD2": (1.780855417251587, -1.1011489629745483, -0.3814258575439453),
|
| 231 |
+
"CE1": (2.9566943645477295, 0.4924798905849457, 0.6477115750312805),
|
| 232 |
+
"NE2": (3.0280203819274902, -0.8751969337463379, 0.26084381341934204),
|
| 233 |
+
},
|
| 234 |
+
"ILE": {
|
| 235 |
+
"N": (-0.7167549729347229, -1.5426139831542969, -0.9983330368995667),
|
| 236 |
+
"CA": (-1.0636085271835327, -0.35169270634651184, -0.21393552422523499),
|
| 237 |
+
"C": (-1.3896740674972534, 0.8142145276069641, -1.1164065599441528),
|
| 238 |
+
"O": (-1.2377792596817017, 0.7302915453910828, -2.3656840324401855),
|
| 239 |
+
"CB": (0.061667006462812424, 0.01599610224366188, 0.8057394623756409),
|
| 240 |
+
"CG1": (1.502519965171814, -0.08899776637554169, 0.24154816567897797),
|
| 241 |
+
"CG2": (-0.053174979984760284, -0.8521055579185486, 2.0702083110809326),
|
| 242 |
+
"CD1": (1.7929610013961792, 0.899773120880127, -0.8863027691841125),
|
| 243 |
+
},
|
| 244 |
+
"LEU": {
|
| 245 |
+
"N": (1.9657520055770874, -1.9763224124908447, -0.18391533195972443),
|
| 246 |
+
"CA": (1.3077669143676758, -0.6677430868148804, -0.19492436945438385),
|
| 247 |
+
"C": (1.9905058145523071, 0.24182087182998657, 0.7879968285560608),
|
| 248 |
+
"O": (2.06896710395813, -0.07880014181137085, 2.0048046112060547),
|
| 249 |
+
"CB": (-0.20306941866874695, -0.8093230128288269, 0.11243502795696259),
|
| 250 |
+
"CG": (-0.9916267395019531, 0.5234957337379456, 0.06723011285066605),
|
| 251 |
+
"CD1": (-2.4228057861328125, 0.29949337244033813, 0.573042094707489),
|
| 252 |
+
"CD2": (-1.0282856225967407, 1.1250264644622803, -1.346014380455017),
|
| 253 |
+
},
|
| 254 |
+
"LYS": {
|
| 255 |
+
"N": (2.4221372604370117, -0.6473312377929688, 0.6370573043823242),
|
| 256 |
+
"CA": (2.0314927101135254, 0.2786507308483124, -0.4298512041568756),
|
| 257 |
+
"C": (2.7168593406677246, 1.595757246017456, -0.20924785733222961),
|
| 258 |
+
"O": (3.397681713104248, 2.116427421569824, -1.1332510709762573),
|
| 259 |
+
"CB": (0.5018402934074402, 0.4873858690261841, -0.49062973260879517),
|
| 260 |
+
"CG": (-0.25062066316604614, -0.7894009947776794, -0.9055535793304443),
|
| 261 |
+
"CD": (-1.769762635231018, -0.5552700161933899, -1.040329933166504),
|
| 262 |
+
"CE": (-2.576533555984497, -1.0221366882324219, 0.18493641912937164),
|
| 263 |
+
"NZ": (-2.269151210784912, -0.24293844401836395, 1.3849012851715088),
|
| 264 |
+
},
|
| 265 |
+
"MET": {
|
| 266 |
+
"N": (1.8903918266296387, -1.5252995491027832, -0.42638593912124634),
|
| 267 |
+
"CA": (1.2630571126937866, -0.24417810142040253, -0.7626462578773499),
|
| 268 |
+
"C": (2.30391001701355, 0.8367712497711182, -0.7254616618156433),
|
| 269 |
+
"O": (2.465414524078369, 1.5928632020950317, -1.7207728624343872),
|
| 270 |
+
"CB": (0.10567972809076309, 0.10861825942993164, 0.19741646945476532),
|
| 271 |
+
"CG": (-1.0658042430877686, -0.8736631274223328, 0.08811883628368378),
|
| 272 |
+
"SD": (-2.4557132720947266, -0.3332225978374481, 1.1461700201034546),
|
| 273 |
+
"CE": (-3.265165090560913, 0.7033554911613464, -0.11588376015424728),
|
| 274 |
+
},
|
| 275 |
+
"PHE": {
|
| 276 |
+
"N": (-2.8484435081481934, -1.525790810585022, 0.01789816841483116),
|
| 277 |
+
"CA": (-1.591969609260559, -0.8545162677764893, 0.35214468836784363),
|
| 278 |
+
"C": (-1.8900631666183472, 0.45833414793014526, 1.0232222080230713),
|
| 279 |
+
"O": (-1.3424992561340332, 0.74432373046875, 2.121629476547241),
|
| 280 |
+
"CB": (-0.760358452796936, -0.6342853307723999, -0.9257160425186157),
|
| 281 |
+
"CG": (0.604112982749939, -0.07200468331575394, -0.6148118376731873),
|
| 282 |
+
"CD1": (0.8468314409255981, 1.2480632066726685, -0.7146694660186768),
|
| 283 |
+
"CD2": (1.6827683448791504, -0.9758077263832092, -0.1423054188489914),
|
| 284 |
+
"CE1": (2.1801748275756836, 1.7875733375549316, -0.3744623064994812),
|
| 285 |
+
"CE2": (2.888307809829712, -0.48277512192726135, 0.16804970800876617),
|
| 286 |
+
"CZ": (3.149812936782837, 0.9656873941421509, 0.04440271109342575),
|
| 287 |
+
},
|
| 288 |
+
"PRO": {
|
| 289 |
+
"N": (-0.836250364780426, -0.9899801015853882, 0.5561304688453674),
|
| 290 |
+
"CA": (0.32722190022468567, -0.6164458394050598, -0.25072571635246277),
|
| 291 |
+
"C": (1.6121541261672974, -1.1711241006851196, 0.31082412600517273),
|
| 292 |
+
"O": (1.6127740144729614, -2.2771971225738525, 0.9156193733215332),
|
| 293 |
+
"CB": (0.3248198926448822, 0.9028244018554688, -0.33368146419525146),
|
| 294 |
+
"CG": (-1.1425083875656128, 1.2730128765106201, -0.2590600252151489),
|
| 295 |
+
"CD": (-1.8495968580245972, 0.026575811207294464, 0.2681289613246918),
|
| 296 |
+
},
|
| 297 |
+
"SER": {
|
| 298 |
+
"N": (0.674650251865387, 1.5018702745437622, -0.5367295145988464),
|
| 299 |
+
"CA": (0.00013792862591799349, 0.4966467022895813, 0.28510504961013794),
|
| 300 |
+
"C": (0.9941009879112244, -0.5374617576599121, 0.73505038022995),
|
| 301 |
+
"O": (1.0545241832733154, -0.8683545589447021, 1.9495396614074707),
|
| 302 |
+
"CB": (-1.1279288530349731, -0.1659376323223114, -0.5160963535308838),
|
| 303 |
+
"OG": (-1.8135979175567627, -1.085249662399292, 0.28947514295578003),
|
| 304 |
+
},
|
| 305 |
+
"THR": {
|
| 306 |
+
"N": (-1.325830340385437, -1.3728225231170654, 0.6882233023643494),
|
| 307 |
+
"CA": (-0.5433306097984314, -0.16364754736423492, 0.41697052121162415),
|
| 308 |
+
"C": (-1.294381856918335, 0.7077372074127197, -0.5549946427345276),
|
| 309 |
+
"O": (-1.6939635276794434, 0.23654410243034363, -1.6540418863296509),
|
| 310 |
+
"CB": (0.853203296661377, -0.5363803505897522, -0.14109353721141815),
|
| 311 |
+
"OG1": (1.5220820903778076, -1.379003643989563, 0.7635167837142944),
|
| 312 |
+
"CG2": (1.7225933074951172, 0.7054727077484131, -0.3651331067085266),
|
| 313 |
+
},
|
| 314 |
+
"TRP": {
|
| 315 |
+
"N": (3.686030864715576, 0.7599999904632568, 0.496155709028244),
|
| 316 |
+
"CA": (2.384092092514038, 0.09079249948263168, 0.5325262546539307),
|
| 317 |
+
"C": (2.1113572120666504, -0.6121063232421875, -0.7733646035194397),
|
| 318 |
+
"O": (1.796526312828064, -1.8323148488998413, -0.7775964140892029),
|
| 319 |
+
"CB": (1.281521201133728, 1.1139036417007446, 0.8559791445732117),
|
| 320 |
+
"CG": (-0.04292375594377518, 0.44645074009895325, 1.0942792892456055),
|
| 321 |
+
"CD1": (-0.42329534888267517, -0.15470874309539795, 2.2227554321289062),
|
| 322 |
+
"CD2": (-1.1023900508880615, 0.2158389836549759, 0.11529432237148285),
|
| 323 |
+
"NE1": (-1.7030320167541504, -0.7665823101997375, 2.0595016479492188),
|
| 324 |
+
"CE2": (-2.045644998550415, -0.4881173074245453, 0.710669219493866),
|
| 325 |
+
"CE3": (-1.2173502445220947, 0.6102271676063538, -1.300106406211853),
|
| 326 |
+
"CZ2": (-3.256009340286255, -0.9164394736289978, -0.00984987337142229),
|
| 327 |
+
"CZ3": (-2.315925121307373, 0.2306906282901764, -1.9776310920715332),
|
| 328 |
+
"CH2": (-3.3817875385284424, -0.5677337646484375, -1.3032053709030151),
|
| 329 |
+
},
|
| 330 |
+
"TYR": {
|
| 331 |
+
"N": (-1.7900604009628296, -0.8409399390220642, 1.3180142641067505),
|
| 332 |
+
"CA": (-1.913882851600647, 0.23552845418453217, 0.330669641494751),
|
| 333 |
+
"C": (-3.347280740737915, 0.3588399887084961, -0.09830684959888458),
|
| 334 |
+
"O": (-3.967811346054077, -0.6449354290962219, -0.5423302054405212),
|
| 335 |
+
"CB": (-1.0093992948532104, 0.0004731413209810853, -0.8981552124023438),
|
| 336 |
+
"CG": (0.4520410895347595, 0.021162061020731926, -0.5305932760238647),
|
| 337 |
+
"CD1": (1.0992432832717896, 1.1877919435501099, -0.3579142987728119),
|
| 338 |
+
"CD2": (1.1803174018859863, -1.253401279449463, -0.31122180819511414),
|
| 339 |
+
"CE1": (2.5253450870513916, 1.1990256309509277, 0.029804613441228867),
|
| 340 |
+
"CE2": (2.471151113510132, -1.240687608718872, 0.043534230440855026),
|
| 341 |
+
"CZ": (3.180687665939331, 0.04672492295503616, 0.2214856892824173),
|
| 342 |
+
"OH": (4.523719787597656, 0.0671030730009079, 0.5877485871315002),
|
| 343 |
+
},
|
| 344 |
+
"VAL": {
|
| 345 |
+
"N": (0.5987519025802612, -1.569443702697754, -0.7379124760627747),
|
| 346 |
+
"CA": (0.6014357209205627, -0.10503966361284256, -0.6336286664009094),
|
| 347 |
+
"C": (1.8391697406768799, 0.4067850410938263, 0.06351757049560547),
|
| 348 |
+
"O": (2.3952062129974365, -0.2666190266609192, 0.9731166958808899),
|
| 349 |
+
"CB": (-0.694736897945404, 0.4259096384048462, 0.03581475466489792),
|
| 350 |
+
"CG1": (-1.9276031255722046, 0.09515828639268875, -0.8172357082366943),
|
| 351 |
+
"CG2": (-0.8938426971435547, -0.08640842139720917, 1.472349762916565),
|
| 352 |
+
},
|
| 353 |
+
"UNK": {
|
| 354 |
+
"N": (0.0, 0.0, 0.0),
|
| 355 |
+
"CA": (0.0, 0.0, 0.0),
|
| 356 |
+
"C": (0.0, 0.0, 0.0),
|
| 357 |
+
"O": (0.0, 0.0, 0.0),
|
| 358 |
+
},
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
# Protonated nitrogens at physiological pH (matches CHARGED_ATOMS in the
|
| 362 |
+
# opensource constants for the protein subset).
|
| 363 |
+
PROTEIN_CHARGED_ATOMS: dict[tuple[str, str], int] = {
|
| 364 |
+
("LYS", "NZ"): 1,
|
| 365 |
+
("ARG", "NH2"): 1,
|
| 366 |
+
("HIS", "ND1"): 1,
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
# Only the elements that appear in canonical protein heavy atoms.
|
| 370 |
+
_PROTEIN_ELEMENT_TO_ATOMIC_NUM: dict[str, int] = {"C": 6, "N": 7, "O": 8, "S": 16}
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _encode_atom_name(name: str) -> list[int]:
|
| 374 |
+
padded = name.ljust(4)[:4]
|
| 375 |
+
return [ord(c) - 32 if c != " " else 0 for c in padded]
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def prepare_protein_features(sequence: str) -> dict[str, Tensor]:
|
| 379 |
+
"""Featurize a single protein sequence for ESMFold2ExperimentalModel.forward.
|
| 380 |
+
|
| 381 |
+
Returns the same keys with the same dtypes/shapes as
|
| 382 |
+
``ESMFold2InputBuilder.prepare_input(StructurePredictionInput(...))``
|
| 383 |
+
restricted to a single-chain protein with no MSA, modifications,
|
| 384 |
+
distogram conditioning, or covalent bonds. All tensors have a
|
| 385 |
+
leading batch dim of 1; the caller is responsible for moving them
|
| 386 |
+
to the model device.
|
| 387 |
+
"""
|
| 388 |
+
if not sequence:
|
| 389 |
+
raise ValueError("sequence must be non-empty")
|
| 390 |
+
|
| 391 |
+
res_3letter = [PROTEIN_1TO3.get(c, "UNK") for c in sequence]
|
| 392 |
+
L = len(sequence)
|
| 393 |
+
|
| 394 |
+
token_atom_starts: list[int] = []
|
| 395 |
+
atom_records: list[tuple[int, str, str, int, tuple[float, float, float]]] = []
|
| 396 |
+
res_type_vals: list[int] = []
|
| 397 |
+
input_id_vals: list[int] = []
|
| 398 |
+
distogram_rep_atom_idx: list[int] = []
|
| 399 |
+
|
| 400 |
+
atom_cursor = 0
|
| 401 |
+
for t_idx, (letter, res_3) in enumerate(zip(sequence, res_3letter)):
|
| 402 |
+
atom_names = PROTEIN_HEAVY_ATOMS[res_3]
|
| 403 |
+
res_type = PROTEIN_RESIDUE_TO_RES_TYPE.get(res_3, PROTEIN_UNK_RES_TYPE)
|
| 404 |
+
input_id = ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"])
|
| 405 |
+
|
| 406 |
+
token_atom_starts.append(atom_cursor)
|
| 407 |
+
for name in atom_names:
|
| 408 |
+
charge = PROTEIN_CHARGED_ATOMS.get((res_3, name), 0)
|
| 409 |
+
element = name[0] # protein heavy atoms are all single-letter C/N/O/S
|
| 410 |
+
ref_pos = PROTEIN_REF_POS[res_3][name]
|
| 411 |
+
atom_records.append((t_idx, name, element, charge, ref_pos))
|
| 412 |
+
atom_cursor += 1
|
| 413 |
+
|
| 414 |
+
rep_name = "CB" if "CB" in atom_names else "CA"
|
| 415 |
+
distogram_rep_atom_idx.append(
|
| 416 |
+
token_atom_starts[t_idx] + atom_names.index(rep_name)
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
res_type_vals.append(res_type)
|
| 420 |
+
input_id_vals.append(input_id)
|
| 421 |
+
|
| 422 |
+
n_real_atoms = len(atom_records)
|
| 423 |
+
n_atoms = math.ceil(n_real_atoms / 32) * 32 if n_real_atoms > 0 else 32
|
| 424 |
+
|
| 425 |
+
ref_pos = torch.zeros(n_atoms, 3, dtype=torch.float32)
|
| 426 |
+
ref_element = torch.zeros(n_atoms, dtype=torch.int64)
|
| 427 |
+
ref_charge = torch.zeros(n_atoms, dtype=torch.int8)
|
| 428 |
+
ref_atom_name_chars = torch.zeros(n_atoms, 4, dtype=torch.int64)
|
| 429 |
+
ref_space_uid = torch.zeros(n_atoms, dtype=torch.int64)
|
| 430 |
+
atom_attention_mask = torch.zeros(n_atoms, dtype=torch.bool)
|
| 431 |
+
atom_to_token = torch.zeros(n_atoms, dtype=torch.int64)
|
| 432 |
+
|
| 433 |
+
for i, (t_idx, name, element, charge, pos) in enumerate(atom_records):
|
| 434 |
+
ref_pos[i] = torch.tensor(pos, dtype=torch.float32)
|
| 435 |
+
ref_element[i] = _PROTEIN_ELEMENT_TO_ATOMIC_NUM[element]
|
| 436 |
+
ref_charge[i] = charge
|
| 437 |
+
ref_atom_name_chars[i] = torch.tensor(
|
| 438 |
+
_encode_atom_name(name), dtype=torch.int64
|
| 439 |
+
)
|
| 440 |
+
ref_space_uid[i] = t_idx
|
| 441 |
+
atom_attention_mask[i] = True
|
| 442 |
+
atom_to_token[i] = t_idx
|
| 443 |
+
|
| 444 |
+
token_index = torch.arange(L, dtype=torch.int64)
|
| 445 |
+
residue_index = torch.arange(L, dtype=torch.int64)
|
| 446 |
+
asym_id = torch.zeros(L, dtype=torch.int64)
|
| 447 |
+
sym_id = torch.zeros(L, dtype=torch.int64)
|
| 448 |
+
entity_id = torch.ones(L, dtype=torch.int64)
|
| 449 |
+
mol_type = torch.full((L,), MOL_TYPE_PROTEIN, dtype=torch.int64)
|
| 450 |
+
res_type = torch.tensor(res_type_vals, dtype=torch.int64)
|
| 451 |
+
input_ids = torch.tensor(input_id_vals, dtype=torch.int64)
|
| 452 |
+
token_bonds = torch.zeros(L, L, 1, dtype=torch.float32)
|
| 453 |
+
token_attention_mask = torch.ones(L, dtype=torch.bool)
|
| 454 |
+
distogram_atom_idx = torch.tensor(distogram_rep_atom_idx, dtype=torch.int64)
|
| 455 |
+
|
| 456 |
+
# Single-sequence MSA: depth 1, row 0 is the sequence itself.
|
| 457 |
+
msa = res_type.unsqueeze(0)
|
| 458 |
+
msa_attention_mask = torch.ones(1, L, dtype=torch.bool)
|
| 459 |
+
has_deletion = torch.zeros(1, L, dtype=torch.bool)
|
| 460 |
+
deletion_value = torch.zeros(1, L, dtype=torch.float32)
|
| 461 |
+
deletion_mean = torch.zeros(L, dtype=torch.float32)
|
| 462 |
+
|
| 463 |
+
features = {
|
| 464 |
+
"token_index": token_index,
|
| 465 |
+
"residue_index": residue_index,
|
| 466 |
+
"asym_id": asym_id,
|
| 467 |
+
"sym_id": sym_id,
|
| 468 |
+
"entity_id": entity_id,
|
| 469 |
+
"mol_type": mol_type,
|
| 470 |
+
"res_type": res_type,
|
| 471 |
+
"input_ids": input_ids,
|
| 472 |
+
"token_bonds": token_bonds,
|
| 473 |
+
"token_attention_mask": token_attention_mask,
|
| 474 |
+
"ref_pos": ref_pos,
|
| 475 |
+
"ref_element": ref_element,
|
| 476 |
+
"ref_charge": ref_charge,
|
| 477 |
+
"ref_atom_name_chars": ref_atom_name_chars,
|
| 478 |
+
"ref_space_uid": ref_space_uid,
|
| 479 |
+
"atom_attention_mask": atom_attention_mask,
|
| 480 |
+
"atom_to_token": atom_to_token,
|
| 481 |
+
"distogram_atom_idx": distogram_atom_idx,
|
| 482 |
+
"msa": msa,
|
| 483 |
+
"msa_attention_mask": msa_attention_mask,
|
| 484 |
+
"has_deletion": has_deletion,
|
| 485 |
+
"deletion_value": deletion_value,
|
| 486 |
+
"deletion_mean": deletion_mean,
|
| 487 |
+
}
|
| 488 |
+
return {k: v.unsqueeze(0) for k, v in features.items()}
|