lhallee committed
Commit fb8a87c · verified · 1 Parent(s): 5e243b2

Upload folder using huggingface_hub

LICENSE ADDED
@@ -0,0 +1,9 @@
+ **License (MIT)**
+
+ Copyright 2026 Chan Zuckerberg Biohub, Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,199 +1,122 @@
  ---
  library_name: transformers
- tags: []
+ tags:
+ - biology
+ - protein-structure
+ - esmfold2
+ - multimodal-protein-model
  ---

- # Model Card for Model ID
+ # FastPLMs ESMFold2

- <!-- Provide a quick summary of what the model is/does. -->
+ FastPLMs ESMFold2 is a self-contained Hugging Face `AutoModel` wrapper for Biohub's ESMFold2 and ESMFold2-Fast structure predictors. It vendors the released Biohub ESMFold2 model code, ESMC backbone code, input builder, MSA helpers, and structure export utilities needed for remote-code loading.

+ ## Load With AutoModel

+ ```python
+ import torch
+ from transformers import AutoModel

- ## Model Details
+ model = AutoModel.from_pretrained(
+     "Synthyra/ESMFold2-Fast",
+     trust_remote_code=True,
+     dtype=torch.bfloat16,
+     device_map="cuda",
+ ).eval()
+ ```

- ### Model Description
+ Use `Synthyra/ESMFold2` for the full model and `Synthyra/ESMFold2-Fast` for the faster release variant.

- <!-- Provide a longer summary of what this model is. -->
+ ## Fold One Protein

- This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+ ```python
+ sequence = "MKTLLILAVVAAALA"

- - **Developed by:** [More Information Needed]
- - **Funded by [optional]:** [More Information Needed]
- - **Shared by [optional]:** [More Information Needed]
- - **Model type:** [More Information Needed]
- - **Language(s) (NLP):** [More Information Needed]
- - **License:** [More Information Needed]
- - **Finetuned from model [optional]:** [More Information Needed]
+ result = model.fold_protein(
+     sequence,
+     num_loops=3,
+     num_sampling_steps=50,
+     num_diffusion_samples=1,
+     seed=0,
+ )

- ### Model Sources [optional]
+ print(float(result.plddt.mean()))
+ print(float(result.ptm))
+ ```

- <!-- Provide the basic links for the model. -->
+ ## Save mmCIF or PDB

- - **Repository:** [More Information Needed]
- - **Paper [optional]:** [More Information Needed]
- - **Demo [optional]:** [More Information Needed]
+ ```python
+ model.save_as_cif(result, "prediction.cif")
+ model.save_as_pdb(result, "prediction.pdb")

- ## Uses
+ cif_text = model.result_to_cif(result)
+ pdb_text = model.result_to_pdb(result)
+ ```

- <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+ `result_to_cif` preserves the full `MolecularComplex`. `result_to_pdb` converts through Biohub's protein-only `ProteinComplex` representation, so use mmCIF for complexes with ligands or nucleic acids.

- ### Direct Use
+ ## Fold Complexes

- <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+ ```python
+ types = model.input_types

- [More Information Needed]
+ complex_input = types.StructurePredictionInput(
+     sequences=[
+         types.ProteinInput(id="A", sequence="MKTLLILAVVAAALA"),
+         types.DNAInput(id="B", sequence="GATAGC"),
+         types.LigandInput(id="L", ccd=["SAH"]),
+     ]
+ )

- ### Downstream Use [optional]
+ result = model.fold(
+     complex_input,
+     num_loops=3,
+     num_sampling_steps=50,
+     num_diffusion_samples=1,
+     seed=0,
+ )

- <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+ model.save_as_cif(result, "complex_prediction.cif")
+ ```

- [More Information Needed]
+ ## Use MSAs

- ### Out-of-Scope Use
+ ```python
+ types = model.input_types

- <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+ msa = types.MSA.from_a3m("query.a3m", max_sequences=128)
+ input_with_msa = types.StructurePredictionInput(
+     sequences=[
+         types.ProteinInput(id="A", sequence=msa.query, msa=msa),
+     ]
+ )

- [More Information Needed]
+ result = model.fold(input_with_msa, num_sampling_steps=50, seed=0)
+ ```

- ## Bias, Risks, and Limitations
+ ## Raw Tensor Inference

- <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+ ```python
+ features, chain_infos = model.prepare_structure_input(complex_input, seed=0)

- [More Information Needed]
+ with torch.inference_mode():
+     output = model(
+         **features,
+         num_loops=3,
+         num_sampling_steps=50,
+         num_diffusion_samples=1,
+     )

- ### Recommendations
+ decoded = model.input_builder.decode(output, features, chain_infos)
+ ```

- <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
-
- Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
-
- ## How to Get Started with the Model
-
- Use the code below to get started with the model.
-
- [More Information Needed]
-
- ## Training Details
-
- ### Training Data
-
- <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
-
- [More Information Needed]
-
- ### Training Procedure
-
- <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
-
- #### Preprocessing [optional]
-
- [More Information Needed]
-
-
- #### Training Hyperparameters
-
- - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
-
- #### Speeds, Sizes, Times [optional]
-
- <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
-
- [More Information Needed]
-
- ## Evaluation
-
- <!-- This section describes the evaluation protocols and provides the results. -->
-
- ### Testing Data, Factors & Metrics
-
- #### Testing Data
-
- <!-- This should link to a Dataset Card if possible. -->
-
- [More Information Needed]
-
- #### Factors
-
- <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
-
- [More Information Needed]
-
- #### Metrics
-
- <!-- These are the evaluation metrics being used, ideally with a description of why. -->
-
- [More Information Needed]
-
- ### Results
-
- [More Information Needed]
-
- #### Summary
-
-
-
- ## Model Examination [optional]
-
- <!-- Relevant interpretability work for the model goes here -->
-
- [More Information Needed]
-
- ## Environmental Impact
-
- <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
-
- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
-
- - **Hardware Type:** [More Information Needed]
- - **Hours used:** [More Information Needed]
- - **Cloud Provider:** [More Information Needed]
- - **Compute Region:** [More Information Needed]
- - **Carbon Emitted:** [More Information Needed]
-
- ## Technical Specifications [optional]
-
- ### Model Architecture and Objective
-
- [More Information Needed]
-
- ### Compute Infrastructure
-
- [More Information Needed]
-
- #### Hardware
-
- [More Information Needed]
-
- #### Software
-
- [More Information Needed]
-
- ## Citation [optional]
-
- <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
-
- **BibTeX:**
-
- [More Information Needed]
-
- **APA:**
-
- [More Information Needed]
-
- ## Glossary [optional]
-
- <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
-
- [More Information Needed]
-
- ## More Information [optional]
-
- [More Information Needed]
-
- ## Model Card Authors [optional]
-
- [More Information Needed]
-
- ## Model Card Contact
-
- [More Information Needed]
+ Set `load_esmc=False` when loading if you want to provide precomputed `lm_hidden_states` manually or run folding-trunk tests without loading the 6B ESMC backbone:
+
+ ```python
+ model = AutoModel.from_pretrained(
+     "Synthyra/ESMFold2-Fast",
+     trust_remote_code=True,
+     load_esmc=False,
+ ).cuda().eval()
+ ```
__init__.py ADDED
@@ -0,0 +1,12 @@
+ import importlib
+ import sys
+
+ from .configuration_esmfold2 import ESMFold2Config
+ from .modeling_esmfold2 import ESMFold2Model
+
+
+ def ensure_vendored_esm() -> None:
+     sys.modules["esm"] = importlib.import_module(f"{__name__}.esm")
+
+
+ __all__ = ["ESMFold2Config", "ESMFold2Model", "ensure_vendored_esm"]
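Note on `ensure_vendored_esm`: it registers the vendored `esm` subpackage under the top-level name `esm` in `sys.modules`, so a later `import esm` resolves to the bundled copy. The following is a minimal sketch of that aliasing technique, using the standard-library `json` module as a stand-in. The helper `alias_module` and the alias name `vendored_json_demo` are hypothetical and not part of this commit.

```python
import importlib
import sys


def alias_module(alias: str, target: str) -> None:
    """Register the importable module `target` under the extra name `alias`."""
    sys.modules[alias] = importlib.import_module(target)


# Illustration only: expose the stdlib `json` module under a made-up name.
alias_module("vendored_json_demo", "json")

# The import system checks sys.modules first, so this finds the alias
# without searching the filesystem.
import vendored_json_demo

print(vendored_json_demo.dumps({"ok": True}))
```

Because the alias replaces whatever was registered under that name, a sketch like this would shadow any separately installed package with the same name for the rest of the process.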
configuration_esmc.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright 2026 Biohub. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ESMC model configuration."""
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class ESMCConfig(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`ESMCModel`]. It is used to
+     instantiate an ESMC model according to the specified arguments, defining the model architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
+     outputs. Read the documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 64):
+             Vocabulary size of the ESMC model. Defines the number of different amino acid tokens that
+             can be represented by the ``input_ids`` passed to [`ESMCModel`].
+         d_model (`int`, *optional*, defaults to 2560):
+             Dimensionality of the encoder layers and the pooler layer.
+         n_heads (`int`, *optional*, defaults to 40):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         n_layers (`int`, *optional*, defaults to 80):
+             Number of hidden layers in the Transformer encoder.
+         pad_token_id (`int`, *optional*, defaults to 1):
+             Index of the padding token in the vocabulary (``"<pad>"``).
+         mask_token_id (`int`, *optional*, defaults to 32):
+             Index of the mask token in the vocabulary (``"<mask>"``), used for masked language modelling.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated normal initialiser for weight matrix initialisation.
+         classifier_dropout (`float`, *optional*, defaults to 0.1):
+             Dropout ratio for the classification head.
+
+     Examples:
+
+     ```python
+     >>> from transformers import ESMCConfig, ESMCModel
+
+     >>> # Initializing an ESMC EvolutionaryScale/esmc-600m-2024-12 style configuration
+     >>> configuration = ESMCConfig()
+
+     >>> # Initializing a model (with random weights) from the EvolutionaryScale/esmc-600m-2024-12 style configuration
+     >>> model = ESMCModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```
+     """
+
+     model_type = "esmc"
+
+     def __init__(
+         self,
+         vocab_size: int = 64,
+         d_model: int = 2560,
+         n_heads: int = 40,
+         n_layers: int = 80,
+         pad_token_id: int = 1,
+         mask_token_id: int = 32,
+         initializer_range: float = 0.02,
+         classifier_dropout: float = 0.1,
+         **kwargs,
+     ):
+         super().__init__(
+             pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs
+         )
+
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.initializer_range = initializer_range
+         self.classifier_dropout = classifier_dropout
+         self.tie_word_embeddings = False
+
+
+ __all__ = ["ESMCConfig"]
configuration_esmc_sae.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright 2026 Biohub. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ESMC sparse autoencoder (SAE) configuration."""
+
+ from dataclasses import dataclass
+
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ @dataclass
+ class ESMCSAEParams:
+     """Parameters for one backbone layer's SAE inside :class:`ESMCSAEModel`.
+
+     The SAE itself is an internal ``nn.Module``; this dataclass just bundles
+     the handful of fields needed to instantiate one.
+     """
+
+     d_model: int = 2560
+     codebook_dim: int = 65536
+     k: int = 64
+     layer: int = 0
+
+
+ class ESMCSAEConfig(PretrainedConfig):
+     """
+     Configuration class for [`ESMCSAEModel`] — a container that holds one
+     SAE per backbone layer for a fixed ``(model, codebook_dim, k)`` group.
+
+     All SAEs in a container share ``d_model``, ``codebook_dim``, and ``k``;
+     they differ only in the backbone layer they were trained on.
+     ``available_layers`` lists the backbone-layer indices the repo ships;
+     each entry ``i`` is stored on disk as ``layer_{i}.safetensors`` (the
+     filename index *is* the backbone layer, so a single-layer repo for
+     layer 23 stores ``layer_23.safetensors``).
+
+     Args:
+         d_model (`int`, *optional*, defaults to 2560):
+             Dimensionality of the ESMC hidden states fed into the SAEs.
+         codebook_dim (`int`, *optional*, defaults to 65536):
+             Number of sparse features in each SAE's codebook.
+         k (`int`, *optional*, defaults to 64):
+             Top-k sparsity per SAE.
+         available_layers (`list[int]`, *optional*, defaults to ``[0]``):
+             Which backbone-layer indices the repo ships.
+     """
+
+     model_type = "esmc_sae"
+
+     def __init__(
+         self,
+         d_model: int = 2560,
+         codebook_dim: int = 65536,
+         k: int = 64,
+         available_layers: list[int] | None = None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.d_model = d_model
+         self.codebook_dim = codebook_dim
+         self.k = k
+         self.available_layers = (
+             list(available_layers) if available_layers is not None else [0]
+         )
+
+
+ __all__ = ["ESMCSAEConfig", "ESMCSAEParams"]
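Note on `available_layers`: the `ESMCSAEConfig` docstring says each backbone-layer index `i` is stored as `layer_{i}.safetensors`. A small sketch of that naming rule follows; the helper name `sae_weight_files` is hypothetical and only illustrates the mapping, it is not a function from this commit.

```python
def sae_weight_files(available_layers: list[int]) -> dict[int, str]:
    """Map each backbone-layer index to the filename described in the docstring."""
    return {layer: f"layer_{layer}.safetensors" for layer in available_layers}


# A single-layer repo for backbone layer 23, as in the docstring's example.
print(sae_weight_files([23]))
```

The point of the convention is that the filename index is the backbone layer itself, not a position in `available_layers`, so no separate lookup table is needed.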
configuration_esmfold2.py ADDED
@@ -0,0 +1,298 @@
+ # Copyright 2026 Biohub. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ESMFold2 model configuration."""
+
+ from __future__ import annotations
+
+ from dataclasses import asdict, dataclass, field
+
+ from transformers.configuration_utils import PretrainedConfig
+
+ # ---------------------------------------------------------------------------
+ # Nested dataclass configs
+ # ---------------------------------------------------------------------------
+
+ _DEFAULT_ESMC_HF_REPO = "biohub/ESMC-6B"
+
+
+ @dataclass
+ class MSAEncoderConfig:
+     """Config for the optional MSA encoder module (Large MSA models only)."""
+
+     enabled: bool = False
+     d_msa: int = 128
+     d_hidden: int = 32
+     n_layers: int = 4
+     n_heads_msa: int = 8
+     msa_head_width: int = 32
+
+
+ @dataclass
+ class ParcaeConfig:
+     """Release-only config for the parcae diffusion-loop scheduler."""
+
+     enabled: bool = True
+     poisson_mean: float = 3.0
+     min_steps: int = 1
+     max_steps: int | None = 6
+     coda_n_layers: int = 2
+
+
+ @dataclass
+ class LMEncoderConfig:
+     """Release-only config for the LM-side pair encoder."""
+
+     enabled: bool = True
+     n_layers: int = 4
+     lm_dropout: float = 0.25
+     per_loop_lm_dropout: bool = True
+
+
+ @dataclass
+ class AtomAttentionConfig:
+     """Config for SWA atom encoder/decoder with 3D RoPE."""
+
+     d_atom: int = 128
+     d_token: int = 768
+     n_blocks: int = 3
+     n_heads: int = 4
+     swa_window_size: int = 128
+     expansion_ratio: int = 2
+     # 3D RoPE config
+     spatial_rope_base_frequency: float = 20.0
+     n_spatial_rope_pairs_per_axis: int = 2
+     n_uid_rope_pairs: int = 10
+     uid_rope_base_frequency: float = 10000.0
+
+
+ @dataclass
+ class FoldingTrunkConfig:
+     n_layers: int = 24
+     n_heads: int = 8
+     dropout: float = 0.0
+
+
+ @dataclass
+ class InputsEmbedderConfig:
+     d_inputs: int = 451
+     atom_encoder: AtomAttentionConfig = field(default_factory=AtomAttentionConfig)
+
+     def __post_init__(self):
+         if isinstance(self.atom_encoder, dict):
+             self.atom_encoder = AtomAttentionConfig(**self.atom_encoder)
+
+
+ @dataclass
+ class DiffusionModuleConfig:
+     """Config for the DiffusionModule."""
+
+     sigma_data: float = 16.0
+     c_atom: int = 128
+     c_token: int = 768
+     c_z: int = 256
+     c_s_inputs: int = 451
+     fourier_dim: int = 256
+     relpos_r_max: int = 32
+     relpos_s_max: int = 2
+     atom_num_blocks: int = 3
+     atom_num_heads: int = 4
+     token_num_blocks: int = 12
+     token_num_heads: int = 16
+     transition_multiplier: int = 2
+
+
+ @dataclass
+ class DiffusionStructureHeadConfig:
+     """Config for the diffusion-based structure prediction head."""
+
+     diffusion_module: DiffusionModuleConfig = field(
+         default_factory=DiffusionModuleConfig
+     )
+     distogram_bins: int = 128
+
+     # Training noise: sigma ~ sigma_data * exp(mu + sigma * N(0,1))
+     train_noise_log_mean: float = -1.2
+     train_noise_log_std: float = 1.5
+
+     # Sampling defaults (ODE)
+     gamma_0: float = 0.605
+     gamma_min: float = 1.107
+     noise_scale: float = 0.0
+     step_scale: float = 1.0
+
+     # Inference schedule defaults
+     inference_s_max: float = 160.0
+     inference_s_min: float = 4e-4
+     inference_p: float = 8.0
+     inference_num_steps: int = 68
+
+     def __post_init__(self):
+         if isinstance(self.diffusion_module, dict):
+             self.diffusion_module = DiffusionModuleConfig(**self.diffusion_module)
+
+
+ @dataclass
+ class ConfidenceHeadConfig:
+     enabled: bool = True
+     num_plddt_bins: int = 50
+     num_pde_bins: int = 64
+     num_pae_bins: int = 64
+     min_dist: float = 2.0
+     max_dist: float = 52.0
+     distogram_bins: int = 128
+     folding_trunk: FoldingTrunkConfig = field(
+         default_factory=lambda: FoldingTrunkConfig(n_layers=4)
+     )
+
+     def __post_init__(self):
+         if isinstance(self.folding_trunk, dict):
+             self.folding_trunk = FoldingTrunkConfig(**self.folding_trunk)
+
+
+ # ---------------------------------------------------------------------------
+ # Top-level config
+ # ---------------------------------------------------------------------------
+
+
+ class ESMFold2Config(PretrainedConfig):
+     """
+     Configuration for the ESMFold2 structure prediction model.
+
+     Uses SWA atom encoders with 3D RoPE, a diffusion transformer,
+     a folding trunk, and an ESMC 6B PLM backbone.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control
+     the model outputs. Read the documentation from [`PretrainedConfig`] for more
+     information.
+
+     Args:
+         d_single (`int`, defaults to 384):
+             Dimensionality of single (per-residue) representations.
+         d_pair (`int`, defaults to 256):
+             Dimensionality of pair (residue-residue) representations.
+         n_relative_residx_bins (`int`, defaults to 32):
+             Number of bins for relative residue index encoding.
+         n_relative_chain_bins (`int`, defaults to 2):
+             Number of bins for relative chain encoding.
+         num_loops (`int`, defaults to 10):
+             Number of trunk loops for iterative refinement.
+         num_diffusion_samples (`int`, defaults to 8):
+             Number of parallel structure predictions to generate.
+         lm_dropout (`float`, defaults to 0.0):
+             Dropout probability on LM pair embeddings. When > 0, dropout is
+             applied with ``training=True`` (including at inference) to match
+             the experimental training recipe used by binder design.
+         force_lm_dropout_during_inference (`bool`, defaults to False):
+             When True, apply ``lm_dropout`` even when ``model.eval()`` and
+             ``lm_dropout`` > 0. Binder-design loads set this to True.
+         disable_msa_features (`bool`, defaults to False):
+             When True, zero out MSA-derived ``profile`` and ``deletion_mean``
+             before the inputs embedder (experimental medium/large checkpoints).
+         inputs (`InputsEmbedderConfig`):
+             Configuration for the inputs embedder module.
+         folding_trunk (`FoldingTrunkConfig`):
+             Configuration for the folding trunk.
+         structure_head (`DiffusionStructureHeadConfig`):
+             Configuration for the diffusion-based structure prediction head.
+         confidence_head (`ConfidenceHeadConfig`):
+             Configuration for the confidence prediction head.
+
+     Examples:
+
+     ```python
+     >>> from transformers import ESMFold2Config, ESMFold2ExperimentalModel
+
+     >>> # Initializing an ESMFold2 configuration
+     >>> configuration = ESMFold2Config(type="experimental")
+
+     >>> # Initializing a model (with random weights) from the configuration
+     >>> model = ESMFold2ExperimentalModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```
+     """
+
+     model_type = "esmfold2"
+     has_no_defaults_at_init = True
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+
+         self.type: str = kwargs.get("type", "release")
+         if self.type not in ("release", "experimental"):
+             raise ValueError(
+                 f"ESMFold2Config.type must be 'release' or 'experimental', "
+                 f"got {self.type!r}"
+             )
+
+         # Top-level scalar fields
+         self.d_single: int = kwargs.get("d_single", 384)
+         self.d_pair: int = kwargs.get("d_pair", 256)
+         self.n_relative_residx_bins: int = kwargs.get("n_relative_residx_bins", 32)
+         self.n_relative_chain_bins: int = kwargs.get("n_relative_chain_bins", 2)
+         self.num_loops: int = kwargs.get("num_loops", 10)
+         self.num_diffusion_samples: int = kwargs.get("num_diffusion_samples", 8)
+         # If True, ``profile`` / ``deletion_mean`` are zeroed before the inputs
+         # embedder.
+         self.disable_msa_features: bool = kwargs.get("disable_msa_features", False)
+         self.lm_dropout: float = kwargs.get("lm_dropout", 0.0)
+         self.force_lm_dropout_during_inference: bool = kwargs.get(
+             "force_lm_dropout_during_inference", False
+         )
+
+         self.lm_d_model: int = kwargs.get("lm_d_model", 2560)
+         self.lm_num_layers: int = kwargs.get("lm_num_layers", 80)
+         # Required, no default — every shipped HF export must name its ESMC backbone.
+         self.esmc_id: str = kwargs.get("esmc_id", _DEFAULT_ESMC_HF_REPO)
+
+         def _init_nested(cls, val):
+             if isinstance(val, cls):
+                 return val
+             if isinstance(val, dict):
+                 return cls(**val)
+             return cls()
+
+         self.inputs = _init_nested(InputsEmbedderConfig, kwargs.get("inputs"))
+         self.folding_trunk = _init_nested(
+             FoldingTrunkConfig, kwargs.get("folding_trunk")
+         )
+         self.structure_head = _init_nested(
+             DiffusionStructureHeadConfig, kwargs.get("structure_head")
+         )
+         self.confidence_head = _init_nested(
+             ConfidenceHeadConfig, kwargs.get("confidence_head")
+         )
+         self.msa_encoder = _init_nested(MSAEncoderConfig, kwargs.get("msa_encoder"))
+         # Release-only modules — ignored when ``type == "experimental"``.
+         self.parcae = _init_nested(ParcaeConfig, kwargs.get("parcae"))
+         self.lm_encoder = _init_nested(LMEncoderConfig, kwargs.get("lm_encoder"))
+         # If True, MSA encoder output replaces the pair stream; if False, it is added.
+         self.msa_encoder_overwrite: bool = bool(
+             kwargs.get("msa_encoder_overwrite", True)
+         )
+
+     def to_dict(self):
+         output = super().to_dict()
+         output["inputs"] = asdict(self.inputs)
+         output["folding_trunk"] = asdict(self.folding_trunk)
+         output["structure_head"] = asdict(self.structure_head)
+         output["confidence_head"] = asdict(self.confidence_head)
+         output["msa_encoder"] = asdict(self.msa_encoder)
+         output["parcae"] = asdict(self.parcae)
+         output["lm_encoder"] = asdict(self.lm_encoder)
+         return output
+
+
+ __all__ = ["ESMFold2Config", "MSAEncoderConfig", "ParcaeConfig", "LMEncoderConfig"]
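Note on the nested configs: `to_dict` flattens each nested dataclass with `asdict`, while `_init_nested` and the `__post_init__` hooks rebuild dataclasses from plain dicts on load. The sketch below shows that round trip with the standard library only. The classes `Inner` and `Outer` and the module-level `init_nested` are hypothetical stand-ins, not names from this commit.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class Inner:
    width: int = 128


@dataclass
class Outer:
    depth: int = 4
    inner: Inner = field(default_factory=Inner)

    def __post_init__(self):
        # A plain dict arrives here when the config was rebuilt from JSON.
        if isinstance(self.inner, dict):
            self.inner = Inner(**self.inner)


def init_nested(cls, val):
    """Accept an instance, a dict of fields, or None (fall back to defaults)."""
    if isinstance(val, cls):
        return val
    if isinstance(val, dict):
        return cls(**val)
    return cls()


original = Outer(depth=2, inner=Inner(width=64))
as_plain = asdict(original)              # nested dataclasses become nested dicts
restored = init_nested(Outer, as_plain)  # __post_init__ rebuilds Inner

print(as_plain)
print(restored == original)
```

Without the `__post_init__` hook, `restored.inner` would stay a plain dict, which is presumably why the nested configs in this file each convert their own children.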
esmfold2_affine3d.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import typing as T
4
+ from abc import ABC
5
+ from dataclasses import dataclass
6
+
7
+ import torch
8
+ from torch.nn import functional as F
9
+ from typing_extensions import Self
10
+
11
+ from .esmfold2_misc import fp32_autocast_context
12
+
13
+
14
+ class Rotation(ABC):
15
+ @classmethod
16
+ def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
17
+
18
+ @classmethod
19
+ def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ...
20
+
21
+ def __getitem__(self, idx: T.Any) -> Self: ...
22
+
23
+ @property
24
+ def tensor(self) -> torch.Tensor:
25
+ # We claim that this should be zero-cost abstraction that returns the raw tensor backing this
26
+ # object. The raw tensor should always have exactly 1 more dim than self.shape, which should be
27
+ # implemented using reshaping
28
+ ...
29
+
30
+ @property
31
+ def shape(self) -> torch.Size:
32
+ # The "shape" of the rotation, as if it was a torch.tensor object
33
+ # This means that 1x4 quaternions are treated as size (1,) for example
34
+ ...
35
+
36
+ def as_matrix(self) -> RotationMatrix: ...
37
+
38
+ def as_quat(self, normalize: bool = False) -> RotationQuat: ...
39
+
40
+ def compose(self, other: Self) -> Self:
41
+ # To be safe, we force users to explicitly convert between rotation types.
42
+ ...
43
+
44
+ def convert_compose(self, other: Self) -> Self:
45
+ # This function will automatically convert between types of rotations
46
+ ...
47
+
48
+ def apply(self, p: torch.Tensor) -> torch.Tensor:
49
+ # rotates points by this rotation object
50
+ ...
51
+
52
+ def invert(self) -> Self: ...
53
+
54
+ @property
55
+ def dtype(self) -> torch.dtype:
56
+ return self.tensor.dtype
57
+
58
+ @property
59
+ def device(self) -> torch.device:
60
+ return self.tensor.device
61
+
62
+ @property
63
+ def requires_grad(self) -> bool:
64
+ return self.tensor.requires_grad
65
+
66
+ @classmethod
67
+ def _from_tensor(cls, t: torch.Tensor) -> Self:
68
+ # This function exists to simplify the below functions, esp type signatures
69
+ # Its implementation is different from Affine3D.from_tensor and does not
70
+ # autodetect rotation types.
71
+ return cls(t) # type: ignore
72
+
73
+ def to(self, **kwargs) -> Self:
74
+ return self._from_tensor(self.tensor.to(**kwargs))
75
+
76
+ def detach(self, *args, **kwargs) -> Self:
77
+ return self._from_tensor(self.tensor.detach(**kwargs))
78
+
79
+ def tensor_apply(self, func) -> Self:
80
+ # Applies a function to the underlying tensor
81
+ return self._from_tensor(
82
+ torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1)
83
+ )
84
+
85
+
86
+ class RotationMatrix(Rotation):
87
+ def __init__(self, rots: torch.Tensor):
88
+ if rots.shape[-1] == 9:
89
+ rots = rots.unflatten(-1, (3, 3))
90
+ assert rots.shape[-1] == 3
91
+ assert rots.shape[-2] == 3
92
+ # Force full precision
93
+ rots = rots.to(torch.float32)
94
+ self._rots = rots
95
+
96
+ @classmethod
97
+ def identity(cls, shape, **tensor_kwargs):
98
+ rots = torch.eye(3, **tensor_kwargs)
99
+ rots = rots.view(*[1 for _ in range(len(shape))], 3, 3)
100
+ rots = rots.expand(*shape, -1, -1)
101
+ return cls(rots)
102
+
103
+ @classmethod
104
+ def random(cls, shape, **tensor_kwargs):
105
+ return RotationQuat.random(shape, **tensor_kwargs).as_matrix()
106
+
107
+ def __getitem__(self, idx: T.Any) -> RotationMatrix:
108
+ indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
109
+ return RotationMatrix(self._rots[indices + (slice(None), slice(None))])
110
+
111
+ @property
112
+ def shape(self) -> torch.Size:
113
+ return self._rots.shape[:-2]
114
+
115
+ def as_matrix(self) -> RotationMatrix:
116
+ return self
117
+
118
+ def as_quat(self, normalize: bool = False) -> RotationQuat:
119
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
120
+ self._rots.flatten(-2), dim=-1
121
+ )
122
+ q_abs = _sqrt_subgradient(
123
+ torch.stack(
124
+ [
125
+ 1.0 + m00 + m11 + m22,
126
+ 1.0 + m00 - m11 - m22,
127
+ 1.0 - m00 + m11 - m22,
128
+ 1.0 - m00 - m11 + m22,
129
+ ],
130
+ dim=-1,
131
+ )
132
+ )
133
+ # we produce the desired quaternion multiplied by each of r, i, j, k
134
+ quat_by_rijk = torch.stack(
135
+ [
136
+ x
137
+ for lst in [
138
+ [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
139
+ [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
140
+ [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
141
+ [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
142
+ ]
143
+ for x in lst
144
+ ],
145
+ dim=-1,
146
+ ).unflatten(-1, (4, 4))
147
+
148
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
149
+ # the candidate won't be picked.
150
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
151
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
152
+
153
+ # If not for numerical problems, quat_candidates[i] would be the same (up to a sign)
154
+ # for all i; we pick the best-conditioned one (with the largest denominator).
155
+ # We manually implement one_hot so torch.compile works
156
+ one_hot = torch.zeros_like(q_abs, dtype=torch.bool)
157
+ one_hot.scatter_(-1, q_abs.argmax(dim=-1, keepdim=True), True)
158
+ quat = quat_candidates[one_hot, :].reshape(q_abs.shape)
159
+ return RotationQuat(quat)
160
+
161
+ def compose(self, other: RotationMatrix) -> RotationMatrix:
162
+ with fp32_autocast_context(self._rots.device.type):
163
+ return RotationMatrix(self._rots @ other._rots)
164
+
165
+ def convert_compose(self, other: Rotation):
166
+ return self.compose(other.as_matrix())
167
+
168
+ def apply(self, p: torch.Tensor) -> torch.Tensor:
169
+ with fp32_autocast_context(self.device.type):
170
+ if self._rots.shape[-3] == 1:
171
+ # This is a slight speedup over einsum for batched rotations
172
+ return p @ self._rots.transpose(-1, -2).squeeze(-3)
173
+ else:
174
+ # einsum is much faster than bmm here
175
+ return torch.einsum("...ij,...j", self._rots, p)
176
+
177
+ def invert(self) -> RotationMatrix:
178
+ return RotationMatrix(self._rots.transpose(-1, -2))
179
+
180
+ @property
181
+ def tensor(self) -> torch.Tensor:
182
+ return self._rots.flatten(-2)
183
+
184
+ def to_3x3(self) -> torch.Tensor:
185
+ return self._rots
186
+
187
+ @staticmethod
188
+ def from_graham_schmidt(
189
+ x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12
190
+ ) -> RotationMatrix:
191
+ # A low eps here is necessary for good stability!
192
+ return RotationMatrix(_graham_schmidt(x_axis, xy_plane, eps))
193
+
194
+
195
+ class RotationQuat(Rotation):
196
+ def __init__(self, quats: torch.Tensor, normalized=False):
197
+ assert quats.shape[-1] == 4
198
+ self._normalized = normalized
199
+ # Force float32 as well
200
+ if normalized:
201
+ self._quats = F.normalize(quats.to(torch.float32), dim=-1)
202
+ self._quats = self._quats.where(self._quats[..., :1] >= 0, -self._quats)
203
+ else:
204
+ self._quats = quats.to(torch.float32)
205
+
206
+ @classmethod
207
+ def identity(cls, shape, **tensor_kwargs):
208
+ q = torch.ones((*shape, 4), **tensor_kwargs)
209
+ mult = torch.tensor([1, 0, 0, 0], device=q.device)
210
+ return RotationQuat(q * mult)
211
+
212
+ @classmethod
213
+ def random(cls, shape, **tensor_kwargs):
214
+ quat = torch.randn((*shape, 4), **tensor_kwargs)
215
+ return RotationQuat(quat, normalized=True)
216
+
217
+ def __getitem__(self, idx: T.Any) -> RotationQuat:
218
+ indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
219
+ return RotationQuat(self._quats[indices + (slice(None),)])
220
+
221
+ @property
222
+ def shape(self) -> torch.Size:
223
+ return self._quats.shape[:-1]
224
+
225
+ def compose(self, other: RotationQuat) -> RotationQuat:
226
+ with fp32_autocast_context(self._quats.device.type):
227
+ return RotationQuat(_quat_mult(self._quats, other._quats))
228
+
229
+ def convert_compose(self, other: Rotation):
230
+ return self.compose(other.as_quat())
231
+
232
+ def as_matrix(self) -> RotationMatrix:
233
+ q = self.normalized().tensor
234
+ r, i, j, k = torch.unbind(q, -1)
235
+ two_s = 2.0 / torch.linalg.norm(q, dim=-1)
236
+
237
+ o = torch.stack(
238
+ (
239
+ 1 - two_s * (j * j + k * k),
240
+ two_s * (i * j - k * r),
241
+ two_s * (i * k + j * r),
242
+ two_s * (i * j + k * r),
243
+ 1 - two_s * (i * i + k * k),
244
+ two_s * (j * k - i * r),
245
+ two_s * (i * k - j * r),
246
+ two_s * (j * k + i * r),
247
+ 1 - two_s * (i * i + j * j),
248
+ ),
249
+ -1,
250
+ )
251
+ return RotationMatrix(o.reshape(q.shape[:-1] + (3, 3)))
252
+
253
+ def as_quat(self, normalize: bool = False) -> RotationQuat:
254
+ return self
255
+
256
+ def apply(self, p: torch.Tensor) -> torch.Tensor:
257
+ return _quat_rotation(self.normalized()._quats, p)
258
+
259
+ def invert(self) -> RotationQuat:
260
+ return RotationQuat(_quat_invert(self._quats))
261
+
262
+ @property
263
+ def tensor(self) -> torch.Tensor:
264
+ return self._quats
265
+
266
+ def normalized(self) -> RotationQuat:
267
+ return self if self._normalized else RotationQuat(self._quats, normalized=True)
268
+
269
+
270
+ @dataclass(frozen=True)
271
+ class Affine3D:
272
+ trans: torch.Tensor
273
+ rot: Rotation
274
+
275
+ def __post_init__(self):
276
+ assert self.trans.shape[:-1] == self.rot.shape
277
+
278
+ @staticmethod
279
+ def identity(
280
+ shape_or_affine: T.Union[tuple[int, ...], "Affine3D"],
281
+ rotation_type: T.Type[Rotation] = RotationMatrix,
282
+ **tensor_kwargs,
283
+ ):
284
+ # Creates a new identity Affine3D object with a specified shape
285
+ # or the same shape as another Affine3D object.
286
+ if isinstance(shape_or_affine, Affine3D):
287
+ kwargs = {"dtype": shape_or_affine.dtype, "device": shape_or_affine.device}
288
+ kwargs.update(tensor_kwargs)
289
+ shape = shape_or_affine.shape
290
+ rotation_type = type(shape_or_affine.rot)
291
+ else:
292
+ kwargs = tensor_kwargs
293
+ shape = shape_or_affine
294
+ return Affine3D(
295
+ torch.zeros((*shape, 3), **kwargs), rotation_type.identity(shape, **kwargs)
296
+ )
297
+
298
+ @staticmethod
299
+ def random(
300
+ shape: tuple[int, ...],
301
+ std: float = 1,
302
+ rotation_type: T.Type[Rotation] = RotationMatrix,
303
+ **tensor_kwargs,
304
+ ) -> "Affine3D":
305
+ return Affine3D(
306
+ trans=torch.randn((*shape, 3), **tensor_kwargs).mul(std),
307
+ rot=rotation_type.random(shape, **tensor_kwargs),
308
+ )
309
+
310
+ def __getitem__(self, idx: T.Any) -> "Affine3D":
311
+ indices = (idx,) if isinstance(idx, int) or idx is None else tuple(idx)
312
+ return Affine3D(trans=self.trans[indices + (slice(None),)], rot=self.rot[idx])
313
+
314
+ @property
315
+ def shape(self) -> torch.Size:
316
+ return self.trans.shape[:-1]
317
+
318
+ @property
319
+ def dtype(self) -> torch.dtype:
320
+ return self.trans.dtype
321
+
322
+ @property
323
+ def device(self) -> torch.device:
324
+ return self.trans.device
325
+
326
+ @property
327
+ def requires_grad(self) -> bool:
328
+ return self.trans.requires_grad
329
+
330
+ def to(self, **kwargs) -> "Affine3D":
331
+ return Affine3D(self.trans.to(**kwargs), self.rot.to(**kwargs))
332
+
333
+ def detach(self, *args, **kwargs) -> "Affine3D":
334
+ return Affine3D(self.trans.detach(**kwargs), self.rot.detach(**kwargs))
335
+
336
+ def tensor_apply(self, func) -> "Affine3D":
337
+ # Applies a function to the underlying tensor
338
+ return self.from_tensor(
339
+ torch.stack([func(x) for x in self.tensor.unbind(dim=-1)], dim=-1)
340
+ )
341
+
342
+ def as_matrix(self):
343
+ return Affine3D(trans=self.trans, rot=self.rot.as_matrix())
344
+
345
+ def as_quat(self, normalize: bool = False):
346
+ return Affine3D(trans=self.trans, rot=self.rot.as_quat(normalize))
347
+
348
+ def compose(self, other: "Affine3D", autoconvert: bool = False):
349
+ rot = self.rot
350
+ new_rot = (rot.convert_compose if autoconvert else rot.compose)(other.rot)
351
+ new_trans = rot.apply(other.trans) + self.trans
352
+ return Affine3D(trans=new_trans, rot=new_rot)
353
+
354
+ def compose_rotation(self, other: Rotation, autoconvert: bool = False):
355
+ return Affine3D(
356
+ trans=self.trans,
357
+ rot=(self.rot.convert_compose if autoconvert else self.rot.compose)(other),
358
+ )
359
+
360
+ def scale(self, v: torch.Tensor | float):
361
+ return Affine3D(self.trans * v, self.rot)
362
+
363
+ def mask(self, mask: torch.Tensor, with_zero=False):
364
+ # Returns a transform where True positions in mask are replaced by the identity (or by zeros if with_zero)
365
+ if with_zero:
366
+ tensor = self.tensor
367
+ return Affine3D.from_tensor(
368
+ torch.zeros_like(tensor).where(mask[..., None], tensor)
369
+ )
370
+ else:
371
+ identity = self.identity(
372
+ self.shape,
373
+ rotation_type=type(self.rot),
374
+ device=self.device,
375
+ dtype=self.dtype,
376
+ ).tensor
377
+ return Affine3D.from_tensor(identity.where(mask[..., None], self.tensor))
378
+
379
+ def apply(self, p: torch.Tensor) -> torch.Tensor:
380
+ return self.rot.apply(p) + self.trans
381
+
382
+ def invert(self):
383
+ inv_rot = self.rot.invert()
384
+ return Affine3D(trans=-inv_rot.apply(self.trans), rot=inv_rot)
385
+
386
+ @property
387
+ def tensor(self) -> torch.Tensor:
388
+ return torch.cat([self.rot.tensor, self.trans], dim=-1)
389
+
390
+ @staticmethod
391
+ def from_tensor(t: torch.Tensor) -> "Affine3D":
392
+ match t.shape[-1]:
393
+ case 4:
394
+ # Assume tensor 4x4 for backward compat with alphafold
395
+ trans = t[..., :3, 3]
396
+ rot = RotationMatrix(t[..., :3, :3])
397
+ case 6:
398
+ # Assume quaternion representation with real part = 1
399
+ trans = t[..., -3:]
400
+ rot = RotationQuat(F.pad(t[..., :3], (1, 0), value=1))
401
+ case 7:
402
+ trans = t[..., -3:]
403
+ rot = RotationQuat(t[..., :4])
404
+ case 12:
405
+ trans = t[..., -3:]
406
+ rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3)))
407
+ case _:
408
+ raise RuntimeError(
409
+ f"Cannot detect rotation format from {t.shape[-1] - 3}-d flat vector"
410
+ )
411
+ return Affine3D(trans, rot)
412
+
413
+ @staticmethod
414
+ def from_tensor_pair(t: torch.Tensor, r: torch.Tensor) -> "Affine3D":
415
+ return Affine3D(t, RotationMatrix(r))
416
+
417
+ @staticmethod
418
+ def from_graham_schmidt(
419
+ neg_x_axis: torch.Tensor,
420
+ origin: torch.Tensor,
421
+ xy_plane: torch.Tensor,
422
+ eps: float = 1e-10,
423
+ ):
424
+ # The arguments of this function are for parity with AlphaFold
425
+ x_axis = origin - neg_x_axis
426
+ xy_plane = xy_plane - origin
427
+ return Affine3D(
428
+ trans=origin, rot=RotationMatrix.from_graham_schmidt(x_axis, xy_plane, eps)
429
+ )
430
+
431
+ @staticmethod
432
+ def cat(affines: list["Affine3D"], dim: int = 0):
433
+ if dim < 0:
434
+ dim = len(affines[0].shape) + dim
435
+ return Affine3D.from_tensor(torch.cat([x.tensor for x in affines], dim=dim))
436
+
437
+
438
+ def _quat_mult(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
439
+ """
440
+ Multiply two quaternions.
441
+ Usual torch rules for broadcasting apply.
442
+
443
+ Args:
444
+ a: Quaternions as tensor of shape (..., 4), real part first.
445
+ b: Quaternions as tensor of shape (..., 4), real part first.
446
+
447
+ Returns:
448
+ The product of a and b, a tensor of quaternions shape (..., 4).
449
+ """
450
+ aw, ax, ay, az = torch.unbind(a, -1)
451
+ bw, bx, by, bz = torch.unbind(b, -1)
452
+ ow = aw * bw - ax * bx - ay * by - az * bz
453
+ ox = aw * bx + ax * bw + ay * bz - az * by
454
+ oy = aw * by - ax * bz + ay * bw + az * bx
455
+ oz = aw * bz + ax * by - ay * bx + az * bw
456
+ return torch.stack((ow, ox, oy, oz), -1)
457
+
458
+
459
+ def _quat_rotation(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
460
+ """
461
+ Rotates p by quaternion q. Usual torch rules for broadcasting apply.
462
+
463
+ Args:
464
+ q: Quaternions as tensor of shape (..., 4), real part first.
465
+ p: Points as tensor of shape (..., 3)
466
+
467
+ Returns:
468
+ The rotated version of p, of shape (..., 3)
469
+ """
470
+ aw, ax, ay, az = torch.unbind(q, -1)
471
+ bx, by, bz = torch.unbind(p, -1)
472
+ # fmt: off
473
+ ow = - ax * bx - ay * by - az * bz
474
+ ox = aw * bx + ay * bz - az * by
475
+ oy = aw * by - ax * bz + az * bx
476
+ oz = aw * bz + ax * by - ay * bx
477
+ # fmt: on
478
+ q_mul_pts = torch.stack((ow, ox, oy, oz), -1)
479
+ return _quat_mult(q_mul_pts, _quat_invert(q))[..., 1:]
480
+
481
+
482
+ def _quat_invert(q: torch.Tensor):
483
+ return q * torch.tensor([1, -1, -1, -1], device=q.device)
484
+
485
+
486
+ def _sqrt_subgradient(x: torch.Tensor) -> torch.Tensor:
487
+ # Returns torch.sqrt(torch.max(0, x)) but with a zero subgradient where x is 0.
488
+ ret = torch.zeros_like(x)
489
+ positive_mask = x > 0
490
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
491
+ return ret
492
+
493
+
494
+ def _graham_schmidt(x_axis: torch.Tensor, xy_plane: torch.Tensor, eps: float = 1e-12):
495
+ # A low eps here is necessary for good stability!
496
+ with fp32_autocast_context(x_axis.device.type):
497
+ e1 = xy_plane
498
+
499
+ denom = torch.sqrt((x_axis**2).sum(dim=-1, keepdim=True) + eps)
500
+ x_axis = x_axis / denom
501
+ dot = (x_axis * e1).sum(dim=-1, keepdim=True)
502
+ e1 = e1 - x_axis * dot
503
+ denom = torch.sqrt((e1**2).sum(dim=-1, keepdim=True) + eps)
504
+ e1 = e1 / denom
505
+ e2 = torch.cross(x_axis, e1, dim=-1)
506
+
507
+ rots = torch.stack([x_axis, e1, e2], dim=-1)
508
+
509
+ return rots
510
+
511
+
512
+ def build_affine3d_from_coordinates(
513
+ coords: torch.Tensor, # (N, CA, C).
514
+ ) -> tuple[Affine3D, torch.Tensor]:
515
+ _MAX_SUPPORTED_DISTANCE = 1e6
516
+ coord_mask = torch.all(
517
+ torch.all(torch.isfinite(coords) & (coords < _MAX_SUPPORTED_DISTANCE), dim=-1),
518
+ dim=-1,
519
+ )
520
+
521
+ def atom3_to_backbone_affine(bb_positions: torch.Tensor) -> Affine3D:
522
+ N, CA, C = bb_positions.unbind(dim=-2)
523
+ return Affine3D.from_graham_schmidt(C, CA, N)
524
+
525
+ coords = coords.clone().float()
526
+ coords[~coord_mask] = 0
527
+
528
+ # NOTE(thayes): If you have already normalized the coordinates, then
529
+ # the black hole affine translations will be zeros and the rotations will be
530
+ # the identity.
531
+ average_per_n_ca_c = coords.masked_fill(~coord_mask[..., None, None], 0).sum(1) / (
532
+ coord_mask.sum(-1)[..., None, None] + 1e-8
533
+ )
534
+ affine_from_average = atom3_to_backbone_affine(
535
+ average_per_n_ca_c.float()
536
+ ).as_matrix()
537
+
538
+ B, S, _, _ = coords.shape
539
+ assert isinstance(B, int)
540
+ assert isinstance(S, int)
541
+ affine_rot_mats = affine_from_average.rot.tensor[..., None, :].expand(B, S, 9)
542
+ affine_trans = affine_from_average.trans[..., None, :].expand(B, S, 3)
543
+
544
+ # We use the identity rotation wherever we have no coordinates. This is
545
+ # important because otherwise the rotation matrices will be all zeros, which
546
+ # will cause collapse in the distance/direction attention mechanism.
547
+ identity_rot = RotationMatrix.identity(
548
+ (B, S), dtype=torch.float32, device=coords.device, requires_grad=False
549
+ )
550
+ affine_rot_mats = affine_rot_mats.where(
551
+ coord_mask.any(-1)[..., None, None], identity_rot.tensor
552
+ )
553
+ black_hole_affine = Affine3D(affine_trans, RotationMatrix(affine_rot_mats))
554
+
555
+ affine = atom3_to_backbone_affine(coords.float())
556
+ affine = Affine3D.from_tensor(
557
+ affine.tensor.where(coord_mask[..., None], black_hole_affine.tensor)
558
+ )
559
+
560
+ return affine, coord_mask
561
+
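The quaternion helpers in the file above (`_quat_mult`, `_quat_invert`, `_quat_rotation`) use a real-part-first layout and rotate a point as q · (0, p) · q⁻¹ for a unit quaternion q. The following is a minimal pure-Python sketch of that convention for a single quaternion and point. The names `quat_mult`, `quat_conj`, `quat_rotate`, and `axis_angle_quat` are illustrative assumptions and are not part of the commit, which works on batched torch tensors instead.

```python
import math

Quat = tuple[float, float, float, float]
Vec3 = tuple[float, float, float]


def quat_mult(a: Quat, b: Quat) -> Quat:
    # Hamilton product with the real part first, term by term as in _quat_mult.
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    )


def quat_conj(q: Quat) -> Quat:
    # For a unit quaternion the conjugate is also the inverse.
    w, x, y, z = q
    return (w, -x, -y, -z)


def quat_rotate(q: Quat, p: Vec3) -> Vec3:
    # Embed p as a pure quaternion (0, p), then compute q * p * conj(q).
    # q is assumed to be normalized by the caller.
    qp = quat_mult(q, (0.0, p[0], p[1], p[2]))
    _, x, y, z = quat_mult(qp, quat_conj(q))
    return (x, y, z)


def axis_angle_quat(axis: Vec3, angle: float) -> Quat:
    # Hypothetical helper: unit quaternion for a rotation of `angle` radians
    # about a unit-length `axis`.
    s = math.sin(angle / 2.0)
    return (math.cos(angle / 2.0), axis[0] * s, axis[1] * s, axis[2] * s)


# A quarter turn about z maps the x axis onto the y axis; composing two
# quarter turns gives a half turn, which maps x onto -x.
quarter_z = axis_angle_quat((0.0, 0.0, 1.0), math.pi / 2.0)
rotated_once = quat_rotate(quarter_z, (1.0, 0.0, 0.0))
rotated_twice = quat_rotate(quat_mult(quarter_z, quarter_z), (1.0, 0.0, 0.0))
```

Composition here is plain quaternion multiplication, which is what `RotationQuat.compose` delegates to in the diff.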
esmfold2_aligner.py ADDED
@@ -0,0 +1,102 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import Field, replace
4
+ from typing import Any, ClassVar, Protocol, TypeVar
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from .esmfold2_protein_structure import compute_affine_and_rmsd
10
+
11
+
12
+ class Alignable(Protocol):
13
+ # Trick to detect whether an object is a dataclass
14
+ __dataclass_fields__: ClassVar[dict[str, Field[Any]]]
15
+
16
+ @property
17
+ def atom37_positions(self) -> np.ndarray: # type: ignore
18
+ pass
19
+
20
+ @property
21
+ def atom37_mask(self) -> np.ndarray: # type: ignore
22
+ pass
23
+
24
+ def __len__(self) -> int: ...
25
+
26
+
27
+ T = TypeVar("T", bound=Alignable)
28
+
29
+
30
+ class Aligner:
31
+ def __init__(
32
+ self,
33
+ mobile: Alignable,
34
+ target: Alignable,
35
+ only_use_backbone: bool = False,
36
+ use_reflection: bool = False,
37
+ ):
38
+ """
39
+ Aligns a mobile protein chain against a target protein chain.
40
+
41
+ Args:
42
+ mobile (ProteinChain): Protein chain to be aligned.
43
+ target (ProteinChain): Protein chain target.
44
+ only_use_backbone (bool): Whether to only use backbone atoms.
45
+ use_reflection (bool): Whether to align to target reflection.
46
+ """
47
+ # Both proteins must have the same number of residues
48
+ assert len(mobile) == len(target)
49
+
50
+ # Determine overlapping atoms
51
+ joint_atom37_mask = mobile.atom37_mask.astype(bool) & target.atom37_mask.astype(
52
+ bool
53
+ )
54
+
55
+ # Backbone atoms (N, CA, C) are the first three sites in the atom37 representation
56
+ if only_use_backbone:
57
+ joint_atom37_mask[:, 3:] = False
58
+
59
+ # Extract matching atom positions and convert to batched tensors
60
+ mobile_atom_tensor = (
61
+ torch.from_numpy(mobile.atom37_positions).type(torch.double).unsqueeze(0)
62
+ )
63
+ target_atom_tensor = (
64
+ torch.from_numpy(target.atom37_positions).type(torch.double).unsqueeze(0)
65
+ )
66
+ joint_atom37_mask = (
67
+ torch.from_numpy(joint_atom37_mask).type(torch.bool).unsqueeze(0)
68
+ )
69
+
70
+ # If using reflection, flip the target by negating its coordinates
71
+ if use_reflection:
72
+ target_atom_tensor = -target_atom_tensor
73
+
74
+ # Compute alignment and rmsd
75
+ affine3D, rmsd = compute_affine_and_rmsd(
76
+ mobile_atom_tensor, target_atom_tensor, atom_exists_mask=joint_atom37_mask
77
+ )
78
+ self._affine3D = affine3D
79
+ self._rmsd = rmsd.item()
80
+
81
+ @property
82
+ def rmsd(self):
83
+ return self._rmsd
84
+
85
+ def apply(self, mobile: T) -> T:
86
+ """Apply alignment to a protein chain"""
87
+ # Extract atom positions and convert to batched tensors
88
+ mobile_atom_tensor = (
89
+ torch.from_numpy(mobile.atom37_positions[mobile.atom37_mask])
90
+ .type(torch.float32)
91
+ .unsqueeze(0)
92
+ )
93
+
94
+ # Transform atom arrays
95
+ aligned_atom_tensor = self._affine3D.apply(mobile_atom_tensor).squeeze(0)
96
+
97
+ # Rebuild atom37 positions
98
+ aligned_atom37_positions = np.full_like(mobile.atom37_positions, np.nan)
99
+ aligned_atom37_positions[mobile.atom37_mask] = aligned_atom_tensor
100
+
101
+ return replace(mobile, atom37_positions=aligned_atom37_positions)
102
+
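`Aligner.__init__` above only scores atoms that exist in both structures: it takes the logical AND of the two `atom37_mask` arrays before calling `compute_affine_and_rmsd`. The sketch below illustrates only that joint-mask step, on a flat list of atoms and without any superposition. The function `masked_rmsd` and its list-based inputs are assumptions made for illustration; the actual alignment and RMSD computation live in `compute_affine_and_rmsd`, which is not shown in this diff.

```python
import math

Point = tuple[float, float, float]


def masked_rmsd(
    mobile_pos: list[Point],
    target_pos: list[Point],
    mobile_mask: list[bool],
    target_mask: list[bool],
) -> float:
    # Both structures must describe the same number of atoms.
    assert len(mobile_pos) == len(target_pos)

    sq_sum = 0.0
    count = 0
    for p, q, p_ok, q_ok in zip(mobile_pos, target_pos, mobile_mask, target_mask):
        # Joint mask: an atom contributes only if it is present in both.
        if p_ok and q_ok:
            sq_sum += sum((a - b) ** 2 for a, b in zip(p, q))
            count += 1

    if count == 0:
        raise ValueError("no atoms are present in both structures")
    return math.sqrt(sq_sum / count)


# The third atom is missing from the target, so it is ignored even though
# the mobile structure has coordinates for it.
example_rmsd = masked_rmsd(
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 5.0, 5.0)],
    [(0.0, 0.0, 0.0), (1.0, 2.0, 0.0), (0.0, 0.0, 0.0)],
    [True, True, True],
    [True, True, False],
)
```

Restricting the mask to the first three atom37 columns, as `only_use_backbone` does, amounts to setting every other entry of the joint mask to `False` before this loop.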
esmfold2_atom_indexer.py ADDED
@@ -0,0 +1,16 @@
1
+ import numpy as np
2
+
3
+ from .esmfold2_protein_structure import index_by_atom_name
4
+
5
+
6
+ class AtomIndexer:
7
+ def __init__(self, structure, property: str, dim: int):
8
+ self.structure = structure
9
+ self.property = property
10
+ self.dim = dim
11
+
12
+ def __getitem__(self, atom_names: str | list[str]) -> np.ndarray:
13
+ return index_by_atom_name(
14
+ getattr(self.structure, self.property), atom_names, self.dim
15
+ )
16
+
esmfold2_conformers.py ADDED
@@ -0,0 +1,292 @@
1
+ """CCD conformer loading utilities.
2
+
3
+ Loads idealized conformer coordinates from a CCD pickle file containing RDKit molecules.
4
+ Conformer priority follows AF3 Section 2.8: Computed > Ideal > first available.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+ import pickle
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ from huggingface_hub import hf_hub_download
15
+
16
+ from .esmfold2_constants import RES_TYPE_TO_CCD
17
+
18
+ if os.environ.get("ESMCFOLD_CCD_PATH"):
19
+ CCD_PICKLE_PATH = Path(os.environ["ESMCFOLD_CCD_PATH"])
20
+ else:
21
+ CCD_PICKLE_PATH = None
22
+
23
+
24
+ # Lazily loaded CCD dictionary
25
+ _CCD_MOLECULES: dict | None = None
26
+
27
+ # Caches
28
+ _CCD_CONFORMERS: dict[str, dict[str, np.ndarray]] = {}
29
+ _CCD_ATOM_CACHE: dict[str, list[tuple[str, str, int]]] = {}
30
+ _CCD_BONDS_CACHE: dict[str, list[tuple[str, str]]] = {}
31
+ _CCD_LEAVING_ATOMS_CACHE: dict[str, set[str]] = {}
32
+ _IDEALIZED_POS_CACHE: dict[tuple[int, str], np.ndarray | None] = {}
33
+ _LIGAND_IDEALIZED_POS_CACHE: dict[tuple[str, str], np.ndarray | None] = {}
34
+
35
+
36
+ def load_ccd(cache_dir: Path | str | None = None) -> dict:
37
+ """Load CCD molecules from pickle file, downloading if needed.
38
+
39
+ Args:
40
+ cache_dir: Directory to cache the downloaded CCD pickle.
41
+ If None, uses the ESMCFOLD_CCD_PATH environment variable if set; otherwise downloads ccd.pkl via the Hugging Face Hub.
42
+ """
43
+ global _CCD_MOLECULES
44
+ if _CCD_MOLECULES is not None:
45
+ return _CCD_MOLECULES
46
+
47
+ # Determine pickle path
48
+ if CCD_PICKLE_PATH is not None and CCD_PICKLE_PATH.exists():
49
+ pkl_path = CCD_PICKLE_PATH
50
+ elif cache_dir is not None:
51
+ cache_dir = Path(cache_dir)
52
+ cache_dir.mkdir(parents=True, exist_ok=True)
53
+ pkl_path = cache_dir / "ccd.pkl"
54
+ else:
55
+ try:
56
+ pkl_path = Path(
57
+ hf_hub_download(repo_id="biohub/ESMFold2", filename="ccd.pkl")
58
+ )
59
+ except Exception as e:
60
+ raise FileNotFoundError(
61
+ f"Failed to download CCD pickle file from Hugging Face repository: {e}"
62
+ )
63
+
64
+ if not pkl_path.exists():
65
+ raise FileNotFoundError(
66
+ f"CCD pickle file not found: {pkl_path}. Please set the ESMCFOLD_CCD_PATH environment variable to the path of a valid CCD pickle file or download the file from the Hugging Face repository."
67
+ )
68
+
69
+ print(f"Loading CCD dictionary from {pkl_path}")
70
+ with open(pkl_path, "rb") as f:
71
+ _CCD_MOLECULES = pickle.load(f)
72
+
73
+ if _CCD_MOLECULES is None:
74
+ _CCD_MOLECULES = {}
75
+
76
+ return _CCD_MOLECULES
77
+
78
+
79
+ def _get_ccd_molecules() -> dict:
80
+ """Get CCD molecules, loading lazily on first call."""
81
+ global _CCD_MOLECULES
82
+ if _CCD_MOLECULES is None:
83
+ return load_ccd()
84
+ return _CCD_MOLECULES
85
+
86
+
87
+ def _get_ccd_mol_with_significant_h(comp_id: str):
88
+ """Get CCD molecule with only chemically significant hydrogens.
89
+
90
+ Returns (mol, conformer) tuple or (None, None) if not available.
91
+ """
92
+ ccd = _get_ccd_molecules()
93
+ if comp_id not in ccd:
94
+ return None, None
95
+
96
+ mol = ccd[comp_id]
97
+ if mol.GetNumConformers() == 0:
98
+ return None, None
99
+
100
+ # Find the "Computed" conformer (RDKit ETKDGv3), fall back to "Ideal"
101
+ conf_idx = 0
102
+ for i, c in enumerate(mol.GetConformers()):
103
+ props = c.GetPropsAsDict()
104
+ if props.get("name") == "Computed":
105
+ conf_idx = i
106
+ break
107
+ else:
108
+ for i, c in enumerate(mol.GetConformers()):
109
+ props = c.GetPropsAsDict()
110
+ if props.get("name") == "Ideal":
111
+ conf_idx = i
112
+ break
113
+
114
+ from rdkit import Chem
115
+
116
+ mol_no_h = Chem.RemoveHs(mol, sanitize=False)
117
+
118
+ if mol_no_h.GetNumConformers() == 0:
119
+ return None, None
120
+
121
+ return mol_no_h, mol_no_h.GetConformer(
122
+ min(conf_idx, mol_no_h.GetNumConformers() - 1)
123
+ )
124
+
125
+
126
+ def get_ccd_conformer(comp_id: str) -> dict[str, np.ndarray] | None:
127
+ """Get idealized conformer as dict of atom_name -> position [3].
128
+
129
+ Conformer priority: Computed > Ideal > first available.
130
+ """
131
+ if comp_id in _CCD_CONFORMERS:
132
+ cached = _CCD_CONFORMERS[comp_id]
133
+ return cached if cached else None
134
+
135
+ mol, conf = _get_ccd_mol_with_significant_h(comp_id)
136
+ if mol is None or conf is None:
137
+ _CCD_CONFORMERS[comp_id] = {}
138
+ return None
139
+
140
+ conformer: dict[str, np.ndarray] = {}
141
+ for atom in mol.GetAtoms():
142
+ props = atom.GetPropsAsDict()
143
+ atom_name = props.get("name")
144
+ if not isinstance(atom_name, str) or not atom_name:
145
+ continue
146
+ idx = atom.GetIdx()
147
+ pos = conf.GetAtomPosition(idx)
148
+ conformer[atom_name] = np.array([pos.x, pos.y, pos.z], dtype=np.float32)
149
+
150
+ _CCD_CONFORMERS[comp_id] = conformer
151
+ return conformer if conformer else None
152
+
153
+
154
+ def get_idealized_atom_pos(res_type: int, atom_name: str) -> np.ndarray | None:
155
+ """Get idealized position for a standard residue atom.
156
+
157
+ Uses res_type index to look up CCD component, then returns position.
158
+ Returns None if not found.
159
+ """
160
+ cache_key = (res_type, atom_name)
161
+ if cache_key in _IDEALIZED_POS_CACHE:
162
+ return _IDEALIZED_POS_CACHE[cache_key]
163
+
164
+ comp_id = RES_TYPE_TO_CCD.get(res_type)
165
+ if comp_id:
166
+ ccd_conformer = get_ccd_conformer(comp_id)
167
+ if ccd_conformer and atom_name in ccd_conformer:
168
+ pos = ccd_conformer[atom_name]
169
+ _IDEALIZED_POS_CACHE[cache_key] = pos
170
+ return pos
171
+
172
+ _IDEALIZED_POS_CACHE[cache_key] = None
173
+ return None
174
+
175
+
176
+ def get_ligand_idealized_atom_pos(res_name: str, atom_name: str) -> np.ndarray | None:
177
+ """Get idealized position for a ligand/modified residue atom.
178
+
179
+ Returns None if not found.
180
+ """
181
+ cache_key = (res_name, atom_name)
182
+ if cache_key in _LIGAND_IDEALIZED_POS_CACHE:
183
+ return _LIGAND_IDEALIZED_POS_CACHE[cache_key]
184
+
185
+ ccd_conformer = get_ccd_conformer(res_name)
186
+ if ccd_conformer and atom_name in ccd_conformer:
187
+ pos = ccd_conformer[atom_name]
188
+ _LIGAND_IDEALIZED_POS_CACHE[cache_key] = pos
189
+ return pos
190
+
191
+ _LIGAND_IDEALIZED_POS_CACHE[cache_key] = None
192
+ return None
193
+
194
+
195
+ def get_ligand_ccd_atoms_with_charges(
196
+ comp_id: str,
197
+ ) -> list[tuple[str, str, int]] | None:
198
+ """Get list of (atom_name, element, charge) for a CCD component.
199
+
200
+ Uses RDKit RemoveHs(sanitize=False) to keep chemically significant hydrogens.
201
+ Returns None if CCD data not available.
202
+ """
203
+ if comp_id in _CCD_ATOM_CACHE:
204
+ cached = _CCD_ATOM_CACHE[comp_id]
205
+ return cached if cached else None
206
+
207
+ mol, _ = _get_ccd_mol_with_significant_h(comp_id)
208
+ if mol is None:
209
+ _CCD_ATOM_CACHE[comp_id] = []
210
+ return None
211
+
212
+ atoms: list[tuple[str, str, int]] = []
213
+ for atom in mol.GetAtoms():
214
+ props = atom.GetPropsAsDict()
215
+ atom_name = props.get("name")
216
+ if not isinstance(atom_name, str) or not atom_name:
217
+ continue
218
+ element = atom.GetSymbol()
219
+ charge = atom.GetFormalCharge()
220
+ atoms.append((atom_name, element, charge))
221
+
222
+ _CCD_ATOM_CACHE[comp_id] = atoms
223
+ return atoms if atoms else None
224
+
225
+
226
+ def get_ligand_ccd_bonds(comp_id: str) -> list[tuple[str, str]] | None:
227
+ """Get list of (atom1_name, atom2_name) bonds for a CCD component.
228
+
229
+ Returns None if CCD data not available.
230
+ """
231
+ if comp_id in _CCD_BONDS_CACHE:
232
+ cached = _CCD_BONDS_CACHE[comp_id]
233
+ return cached if cached else None
234
+
235
+ mol, _ = _get_ccd_mol_with_significant_h(comp_id)
236
+ if mol is None:
237
+ _CCD_BONDS_CACHE[comp_id] = []
238
+ return None
239
+
240
+ # Get included atom names
241
+ included_atoms = set()
242
+ for atom in mol.GetAtoms():
243
+ props = atom.GetPropsAsDict()
244
+ atom_name = props.get("name")
245
+ if isinstance(atom_name, str) and atom_name:
246
+ included_atoms.add(atom_name)
247
+
248
+ bonds: list[tuple[str, str]] = []
249
+ for bond in mol.GetBonds():
250
+ a1 = bond.GetBeginAtom()
251
+ a2 = bond.GetEndAtom()
252
+ n1 = a1.GetPropsAsDict().get("name")
253
+ n2 = a2.GetPropsAsDict().get("name")
254
+ if (
255
+ isinstance(n1, str)
256
+ and isinstance(n2, str)
257
+ and n1
258
+ and n2
259
+ and n1 in included_atoms
260
+ and n2 in included_atoms
261
+ ):
262
+ bonds.append((n1, n2))
263
+
264
+ _CCD_BONDS_CACHE[comp_id] = bonds
265
+ return bonds if bonds else None
266
+
267
+
268
+ def get_ccd_leaving_atoms(comp_id: str) -> set[str]:
269
+ """Get set of atom names marked as leaving atoms in CCD.
270
+
271
+ Leaving atoms are removed during polymerization (e.g., OP3 in nucleotides).
272
+ """
273
+ if comp_id in _CCD_LEAVING_ATOMS_CACHE:
274
+ return _CCD_LEAVING_ATOMS_CACHE[comp_id]
275
+
276
+ ccd = _get_ccd_molecules()
277
+ if comp_id not in ccd:
278
+ _CCD_LEAVING_ATOMS_CACHE[comp_id] = set()
279
+ return set()
280
+
281
+ mol = ccd[comp_id]
282
+ leaving_atoms = set()
283
+ for atom in mol.GetAtoms():
284
+ if atom.HasProp("leaving_atom"):
285
+ if atom.GetProp("leaving_atom") == "1":
286
+ name = atom.GetProp("name") if atom.HasProp("name") else ""
287
+ if name:
288
+ leaving_atoms.add(name)
289
+
290
+ _CCD_LEAVING_ATOMS_CACHE[comp_id] = leaving_atoms
291
+ return leaving_atoms
292
+
esmfold2_constants.py ADDED
@@ -0,0 +1,563 @@
1
+ """Constants for the ESMFold2 input pipeline.
2
+
3
+ Includes molecule types, residue types, vocabularies, atom lists, and element data.
4
+ """
5
+
6
+ # =============================================================================
7
+ # Molecule types
8
+ # =============================================================================
9
+
10
+ MOL_TYPE_PROTEIN = 0
11
+ MOL_TYPE_DNA = 1
12
+ MOL_TYPE_RNA = 2
13
+ MOL_TYPE_NONPOLYMER = 3
14
+
15
+ # =============================================================================
16
+ # Residue type indices
17
+ # =============================================================================
18
+
19
+ # Standard amino acids (indices 2-21), MSE mapped to MET
20
+ PROTEIN_RESIDUE_TO_RES_TYPE = {
21
+ "ALA": 2,
22
+ "ARG": 3,
23
+ "ASN": 4,
24
+ "ASP": 5,
25
+ "CYS": 6,
26
+ "GLN": 7,
27
+ "GLU": 8,
28
+ "GLY": 9,
29
+ "HIS": 10,
30
+ "ILE": 11,
31
+ "LEU": 12,
32
+ "LYS": 13,
33
+ "MET": 14,
34
+ "PHE": 15,
35
+ "PRO": 16,
36
+ "SER": 17,
37
+ "THR": 18,
38
+ "TRP": 19,
39
+ "TYR": 20,
40
+ "VAL": 21,
41
+ "MSE": 14, # Selenomethionine -> MET
42
+ }
43
+ PROTEIN_UNK_RES_TYPE = 22
44
+
45
+ # RNA nucleotides (indices 23-26, unknown=27)
46
+ RNA_RESIDUE_TO_RES_TYPE = {"A": 23, "G": 24, "C": 25, "U": 26}
47
+ RNA_UNK_RES_TYPE = 27
48
+
49
+ # DNA nucleotides (indices 28-31, unknown=32)
50
+ DNA_RESIDUE_TO_RES_TYPE = {"DA": 28, "DG": 29, "DC": 30, "DT": 31}
51
+ DNA_UNK_RES_TYPE = 32
52
+
53
+ GAP_RES_TYPE = 32
54
+
55
+ # =============================================================================
56
+ # Vocabularies
57
+ # =============================================================================
58
+
59
+ # 3-letter to 1-letter codes for proteins
60
+ PROTEIN_3TO1 = {
61
+ "ALA": "A",
62
+ "ARG": "R",
63
+ "ASN": "N",
64
+ "ASP": "D",
65
+ "CYS": "C",
66
+ "GLN": "Q",
67
+ "GLU": "E",
68
+ "GLY": "G",
69
+ "HIS": "H",
70
+ "ILE": "I",
71
+ "LEU": "L",
72
+ "LYS": "K",
73
+ "MET": "M",
74
+ "PHE": "F",
75
+ "PRO": "P",
76
+ "SER": "S",
77
+ "THR": "T",
78
+ "TRP": "W",
79
+ "TYR": "Y",
80
+ "VAL": "V",
81
+ "MSE": "M",
82
+ }
83
+
84
+ # 1-letter to 3-letter codes
85
+ PROTEIN_1TO3 = {v: k for k, v in PROTEIN_3TO1.items() if k != "MSE"}
86
+ PROTEIN_1TO3["X"] = "UNK"
87
+
88
+ # DNA 1-letter to CCD code
89
+ DNA_1TO3 = {"A": "DA", "T": "DT", "C": "DC", "G": "DG"}
90
+
91
+ # RNA 1-letter to CCD code
92
+ RNA_1TO3 = {"A": "A", "U": "U", "C": "C", "G": "G"}
93
+
94
+ # ESM-2 input_ids vocabulary for proteins
95
+ ESM_PROTEIN_VOCAB = {
96
+ "L": 4,
97
+ "A": 5,
98
+ "G": 6,
99
+ "V": 7,
100
+ "S": 8,
101
+ "E": 9,
102
+ "R": 10,
103
+ "T": 11,
104
+ "I": 12,
105
+ "D": 13,
106
+ "P": 14,
107
+ "K": 15,
108
+ "Q": 16,
109
+ "N": 17,
110
+ "F": 18,
111
+ "Y": 19,
112
+ "M": 20,
113
+ "H": 21,
114
+ "W": 22,
115
+ "C": 23,
116
+ "X": 3, # Unknown
117
+ }
118
+
119
+ # For DNA/RNA/ligands
120
+ DNA_RNA_LIGAND_INPUT_ID = 24
121
+
122
+ # MSA tokens
123
+ MSA_PAD_TOKEN_ID = 0
124
+ MSA_GAP_TOKEN_ID = 1 # Gap/insertion token for MSA
125
+
126
+ # res_type int -> CCD component ID (for conformer lookup)
127
+ RES_TYPE_TO_CCD = {
128
+ # Proteins (2-22)
129
+ 2: "ALA",
130
+ 3: "ARG",
131
+ 4: "ASN",
132
+ 5: "ASP",
133
+ 6: "CYS",
134
+ 7: "GLN",
135
+ 8: "GLU",
136
+ 9: "GLY",
137
+ 10: "HIS",
138
+ 11: "ILE",
139
+ 12: "LEU",
140
+ 13: "LYS",
141
+ 14: "MET",
142
+ 15: "PHE",
143
+ 16: "PRO",
144
+ 17: "SER",
145
+ 18: "THR",
146
+ 19: "TRP",
147
+ 20: "TYR",
148
+ 21: "VAL",
149
+ 22: "UNK",
150
+ # RNA (23-27)
151
+ 23: "A",
152
+ 24: "G",
153
+ 25: "C",
154
+ 26: "U",
155
+ 27: "N",
156
+ # DNA (28-32)
157
+ 28: "DA",
158
+ 29: "DG",
159
+ 30: "DC",
160
+ 31: "DT",
161
+ 32: "DN",
162
+ }
163
+
164
+ # =============================================================================
165
+ # Charged atoms at physiological pH
166
+ # =============================================================================
167
+
168
+ CHARGED_ATOMS: dict[tuple[str, str], int] = {
169
+ ("LYS", "NZ"): 1,
170
+ ("ARG", "NH2"): 1,
171
+ ("HIS", "ND1"): 1,
172
+ ("PO4", "O2"): -1,
173
+ ("PO4", "O3"): -1,
174
+ ("PO4", "O4"): -1,
175
+ ("SO4", "O3"): -1,
176
+ ("SO4", "O4"): -1,
177
+ ("MG", "MG"): 2,
178
+ ("ZN", "ZN"): 2,
179
+ ("CA", "CA"): 2,
180
+ ("FE2", "FE"): 2,
181
+ ("MN", "MN"): 2,
182
+ ("CO", "CO"): 2,
183
+ ("NCO", "CO"): 3,
184
+ ("CU", "CU"): 2,
185
+ ("NI", "NI"): 2,
186
+ ("K", "K"): 1,
187
+ ("NA", "NA"): 1,
188
+ ("CD", "CD"): 2,
189
+ ("CL", "CL"): -1,
190
+ ("ACT", "OXT"): -1,
191
+ ("NAD", "O2N"): -1,
192
+ ("NAD", "N1N"): 1,
193
+ ("NAP", "O2N"): -1,
194
+ ("NAP", "N1N"): 1,
195
+ ("IMD", "N3"): 1,
196
+ ("SAM", "SD"): 1,
197
+ ("FE", "FE"): 3,
198
+ ("A1BH3", "N3"): 1,
199
+ }
200
+
201
+ # =============================================================================
202
+ # Element atomic numbers (Z=1 to 92; helium, Z=2, is not listed)
203
+ # =============================================================================
204
+
205
+ ELEMENT_TO_ATOMIC_NUM = {
206
+ "H": 1,
207
+ "LI": 3,
208
+ "BE": 4,
209
+ "B": 5,
210
+ "C": 6,
211
+ "N": 7,
212
+ "O": 8,
213
+ "F": 9,
214
+ "NE": 10,
215
+ "NA": 11,
216
+ "MG": 12,
217
+ "AL": 13,
218
+ "SI": 14,
219
+ "P": 15,
220
+ "S": 16,
221
+ "CL": 17,
222
+ "AR": 18,
223
+ "K": 19,
224
+ "CA": 20,
225
+ "SC": 21,
226
+ "TI": 22,
227
+ "V": 23,
228
+ "CR": 24,
229
+ "MN": 25,
230
+ "FE": 26,
231
+ "CO": 27,
232
+ "NI": 28,
233
+ "CU": 29,
234
+ "ZN": 30,
235
+ "GA": 31,
236
+ "GE": 32,
237
+ "AS": 33,
238
+ "SE": 34,
239
+ "BR": 35,
240
+ "KR": 36,
241
+ "RB": 37,
242
+ "SR": 38,
243
+ "Y": 39,
244
+ "ZR": 40,
245
+ "NB": 41,
246
+ "MO": 42,
247
+ "TC": 43,
248
+ "RU": 44,
249
+ "RH": 45,
250
+ "PD": 46,
251
+ "AG": 47,
252
+ "CD": 48,
253
+ "IN": 49,
254
+ "SN": 50,
255
+ "SB": 51,
256
+ "TE": 52,
257
+ "I": 53,
258
+ "XE": 54,
259
+ "CS": 55,
260
+ "BA": 56,
261
+ "LA": 57,
262
+ "CE": 58,
263
+ "PR": 59,
264
+ "ND": 60,
265
+ "PM": 61,
266
+ "SM": 62,
267
+ "EU": 63,
268
+ "GD": 64,
269
+ "TB": 65,
270
+ "DY": 66,
271
+ "HO": 67,
272
+ "ER": 68,
273
+ "TM": 69,
274
+ "YB": 70,
275
+ "LU": 71,
276
+ "HF": 72,
277
+ "TA": 73,
278
+ "W": 74,
279
+ "RE": 75,
280
+ "OS": 76,
281
+ "IR": 77,
282
+ "PT": 78,
283
+ "AU": 79,
284
+ "HG": 80,
285
+ "TL": 81,
286
+ "PB": 82,
287
+ "BI": 83,
288
+ "PO": 84,
289
+ "AT": 85,
290
+ "RN": 86,
291
+ "FR": 87,
292
+ "RA": 88,
293
+ "AC": 89,
294
+ "TH": 90,
295
+ "PA": 91,
296
+ "U": 92,
297
+ }
298
+
299
+ # Inverse mapping: atomic number → element symbol
300
+ ELEMENT_NUMBER_TO_SYMBOL = {v: k for k, v in ELEMENT_TO_ATOMIC_NUM.items()}
301
+
302
+ # =============================================================================
303
+ # Standard heavy atoms per residue type
304
+ # =============================================================================
305
+
306
+ PROTEIN_HEAVY_ATOMS = {
307
+ "ALA": ["N", "CA", "C", "O", "CB"],
308
+ "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
309
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
310
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
311
+ "CYS": ["N", "CA", "C", "O", "CB", "SG"],
312
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
313
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
314
+ "GLY": ["N", "CA", "C", "O"],
315
+ "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
316
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
317
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
318
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
319
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
320
+ "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
321
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
322
+ "SER": ["N", "CA", "C", "O", "CB", "OG"],
323
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
324
+ "TRP": [
325
+ "N",
326
+ "CA",
327
+ "C",
328
+ "O",
329
+ "CB",
330
+ "CG",
331
+ "CD1",
332
+ "CD2",
333
+ "NE1",
334
+ "CE2",
335
+ "CE3",
336
+ "CZ2",
337
+ "CZ3",
338
+ "CH2",
339
+ ],
340
+ "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
341
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
342
+ "MSE": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
343
+ "UNK": ["N", "CA", "C", "O"],
344
+ }
345
+
346
+ DNA_HEAVY_ATOMS = {
347
+ "DA": [
348
+ "P",
349
+ "OP1",
350
+ "OP2",
351
+ "O5'",
352
+ "C5'",
353
+ "C4'",
354
+ "O4'",
355
+ "C3'",
356
+ "O3'",
357
+ "C2'",
358
+ "C1'",
359
+ "N9",
360
+ "C8",
361
+ "N7",
362
+ "C5",
363
+ "C6",
364
+ "N6",
365
+ "N1",
366
+ "C2",
367
+ "N3",
368
+ "C4",
369
+ ],
370
+ "DG": [
371
+ "P",
372
+ "OP1",
373
+ "OP2",
374
+ "O5'",
375
+ "C5'",
376
+ "C4'",
377
+ "O4'",
378
+ "C3'",
379
+ "O3'",
380
+ "C2'",
381
+ "C1'",
382
+ "N9",
383
+ "C8",
384
+ "N7",
385
+ "C5",
386
+ "C6",
387
+ "O6",
388
+ "N1",
389
+ "C2",
390
+ "N2",
391
+ "N3",
392
+ "C4",
393
+ ],
394
+ "DC": [
395
+ "P",
396
+ "OP1",
397
+ "OP2",
398
+ "O5'",
399
+ "C5'",
400
+ "C4'",
401
+ "O4'",
402
+ "C3'",
403
+ "O3'",
404
+ "C2'",
405
+ "C1'",
406
+ "N1",
407
+ "C2",
408
+ "O2",
409
+ "N3",
410
+ "C4",
411
+ "N4",
412
+ "C5",
413
+ "C6",
414
+ ],
415
+ "DT": [
416
+ "P",
417
+ "OP1",
418
+ "OP2",
419
+ "O5'",
420
+ "C5'",
421
+ "C4'",
422
+ "O4'",
423
+ "C3'",
424
+ "O3'",
425
+ "C2'",
426
+ "C1'",
427
+ "N1",
428
+ "C2",
429
+ "O2",
430
+ "N3",
431
+ "C4",
432
+ "O4",
433
+ "C5",
434
+ "C7",
435
+ "C6",
436
+ ],
437
+ }
438
+
439
+ RNA_HEAVY_ATOMS = {
440
+ "A": [
441
+ "P",
442
+ "OP1",
443
+ "OP2",
444
+ "O5'",
445
+ "C5'",
446
+ "C4'",
447
+ "O4'",
448
+ "C3'",
449
+ "O3'",
450
+ "C2'",
451
+ "O2'",
452
+ "C1'",
453
+ "N9",
454
+ "C8",
455
+ "N7",
456
+ "C5",
457
+ "C6",
458
+ "N6",
459
+ "N1",
460
+ "C2",
461
+ "N3",
462
+ "C4",
463
+ ],
464
+ "G": [
465
+ "P",
466
+ "OP1",
467
+ "OP2",
468
+ "O5'",
469
+ "C5'",
470
+ "C4'",
471
+ "O4'",
472
+ "C3'",
473
+ "O3'",
474
+ "C2'",
475
+ "O2'",
476
+ "C1'",
477
+ "N9",
478
+ "C8",
479
+ "N7",
480
+ "C5",
481
+ "C6",
482
+ "O6",
483
+ "N1",
484
+ "C2",
485
+ "N2",
486
+ "N3",
487
+ "C4",
488
+ ],
489
+ "C": [
490
+ "P",
491
+ "OP1",
492
+ "OP2",
493
+ "O5'",
494
+ "C5'",
495
+ "C4'",
496
+ "O4'",
497
+ "C3'",
498
+ "O3'",
499
+ "C2'",
500
+ "O2'",
501
+ "C1'",
502
+ "N1",
503
+ "C2",
504
+ "O2",
505
+ "N3",
506
+ "C4",
507
+ "N4",
508
+ "C5",
509
+ "C6",
510
+ ],
511
+ "U": [
512
+ "P",
513
+ "OP1",
514
+ "OP2",
515
+ "O5'",
516
+ "C5'",
517
+ "C4'",
518
+ "O4'",
519
+ "C3'",
520
+ "O3'",
521
+ "C2'",
522
+ "O2'",
523
+ "C1'",
524
+ "N1",
525
+ "C2",
526
+ "O2",
527
+ "N3",
528
+ "C4",
529
+ "O4",
530
+ "C5",
531
+ "C6",
532
+ ],
533
+ }
534
+
535
+ # Unknown nucleotide backbone atoms
536
+ DNA_BACKBONE_ATOMS = [
537
+ "P",
538
+ "OP1",
539
+ "OP2",
540
+ "O5'",
541
+ "C5'",
542
+ "C4'",
543
+ "O4'",
544
+ "C3'",
545
+ "O3'",
546
+ "C2'",
547
+ "C1'",
548
+ ]
549
+ RNA_BACKBONE_ATOMS = [
550
+ "P",
551
+ "OP1",
552
+ "OP2",
553
+ "O5'",
554
+ "C5'",
555
+ "C4'",
556
+ "O4'",
557
+ "C3'",
558
+ "O3'",
559
+ "C2'",
560
+ "O2'",
561
+ "C1'",
562
+ ]
563
+
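The tables in this file can be combined to turn a one-letter protein sequence into residue-type indices and ESM input IDs. The sketch below copies a small subset of the constants so that it runs on its own. The helper `encode_protein` is hypothetical, and whether the actual input pipeline performs the lookup in this order is an assumption.

```python
# Subsets copied from the constants above; the helper itself is hypothetical.
PROTEIN_1TO3 = {"A": "ALA", "G": "GLY", "M": "MET", "X": "UNK"}
PROTEIN_RESIDUE_TO_RES_TYPE = {"ALA": 2, "GLY": 9, "MET": 14, "MSE": 14}
PROTEIN_UNK_RES_TYPE = 22
ESM_PROTEIN_VOCAB = {"A": 5, "G": 6, "M": 20, "X": 3}


def encode_protein(sequence: str) -> tuple[list[int], list[int]]:
    """Return (res_type indices, ESM input_ids) for a one-letter sequence."""
    res_types: list[int] = []
    input_ids: list[int] = []
    for letter in sequence.upper():
        # Unknown letters fall back to UNK / "X" in both vocabularies.
        ccd = PROTEIN_1TO3.get(letter, "UNK")
        res_types.append(PROTEIN_RESIDUE_TO_RES_TYPE.get(ccd, PROTEIN_UNK_RES_TYPE))
        input_ids.append(ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"]))
    return res_types, input_ids


res_types, input_ids = encode_protein("MAGX")
```

Note that the two index spaces are independent: `res_type` starts the amino acids at 2, while the ESM vocabulary starts them at 4.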
esmfold2_constants_esm3.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ from functools import cache
3
+ from pathlib import Path
4
+
5
+ from huggingface_hub import snapshot_download
6
+
7
+ SEQUENCE_BOS_TOKEN = 0
8
+ SEQUENCE_PAD_TOKEN = 1
9
+ SEQUENCE_EOS_TOKEN = 2
10
+ SEQUENCE_CHAINBREAK_TOKEN = 31
11
+ SEQUENCE_MASK_TOKEN = 32
12
+
13
+ VQVAE_CODEBOOK_SIZE = 4096
14
+ VQVAE_SPECIAL_TOKENS = {
15
+ "MASK": VQVAE_CODEBOOK_SIZE,
16
+ "EOS": VQVAE_CODEBOOK_SIZE + 1,
17
+ "BOS": VQVAE_CODEBOOK_SIZE + 2,
18
+ "PAD": VQVAE_CODEBOOK_SIZE + 3,
19
+ "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
20
+ }
21
+ VQVAE_DIRECTION_LOSS_BINS = 16
22
+ VQVAE_PAE_BINS = 64
23
+ VQVAE_MAX_PAE_BIN = 31.0
24
+ VQVAE_PLDDT_BINS = 50
25
+
26
+ STRUCTURE_MASK_TOKEN = VQVAE_SPECIAL_TOKENS["MASK"]
27
+ STRUCTURE_BOS_TOKEN = VQVAE_SPECIAL_TOKENS["BOS"]
28
+ STRUCTURE_EOS_TOKEN = VQVAE_SPECIAL_TOKENS["EOS"]
29
+ STRUCTURE_PAD_TOKEN = VQVAE_SPECIAL_TOKENS["PAD"]
30
+ STRUCTURE_CHAINBREAK_TOKEN = VQVAE_SPECIAL_TOKENS["CHAINBREAK"]
31
+ STRUCTURE_UNDEFINED_TOKEN = 955
32
+
33
+ SASA_PAD_TOKEN = 0
34
+
35
+ SS8_PAD_TOKEN = 0
36
+
37
+ INTERPRO_PAD_TOKEN = 0
38
+
39
+ RESIDUE_PAD_TOKEN = 0
40
+
41
+ CHAIN_BREAK_STR = "|"
42
+
43
+ SEQUENCE_BOS_STR = "<cls>"
44
+ SEQUENCE_EOS_STR = "<eos>"
45
+
46
+ MASK_STR_SHORT = "_"
47
+ SEQUENCE_MASK_STR = "<mask>"
48
+ SASA_MASK_STR = "<unk>"
49
+ SS8_MASK_STR = "<unk>"
50
+
51
+ # fmt: off
52
+ SEQUENCE_VOCAB = [
53
+ "<cls>", "<pad>", "<eos>", "<unk>",
54
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
55
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
56
+ "O", ".", "-", "|",
57
+ "<mask>",
58
+ ]
59
+ # fmt: on
60
+
61
+ SEQUENCE_STANDARD_AA_MIN_TOKEN = 4 # L
62
+ SEQUENCE_STANDARD_AA_MAX_TOKEN = 24 # X (exclusive)
63
+
64
+ SSE_8CLASS_VOCAB = "GHITEBSC"
65
+ SSE_3CLASS_VOCAB = "HEC"
66
+ SSE_8CLASS_TO_3CLASS_MAP = {
67
+ "G": "H",
68
+ "H": "H",
69
+ "I": "H",
70
+ "T": "C",
71
+ "E": "E",
72
+ "B": "E",
73
+ "S": "C",
74
+ "C": "C",
75
+ }
76
+
77
+ SASA_DISCRETIZATION_BOUNDARIES = [
78
+ 0.8,
79
+ 4.0,
80
+ 9.6,
81
+ 16.4,
82
+ 24.5,
83
+ 32.9,
84
+ 42.0,
85
+ 51.5,
86
+ 61.2,
87
+ 70.9,
88
+ 81.6,
89
+ 93.3,
90
+ 107.2,
91
+ 125.4,
92
+ 151.4,
93
+ ]
94
+
95
+ MAX_RESIDUE_ANNOTATIONS = 16
96
+
97
+
98
+ TFIDF_VECTOR_SIZE = 58641
99
+
100
+ FUNCTION_TOKENS_DEPTH = 8
101
+
102
+
103
+ @staticmethod
104
+ @cache
105
+ def data_root(model: str):
106
+ if "INFRA_PROVIDER" in os.environ:
107
+ return Path("")
108
+ # Resolve the weights via the Hugging Face Hub (snapshot_download reuses the local cache when present)
109
+ if model.startswith("esm3"):
110
+ path = Path(snapshot_download(repo_id="biohub/esm3-sm-open-v1"))
111
+ elif model.startswith("esmc-300"):
112
+ path = Path(snapshot_download(repo_id="biohub/esmc-300m-2024-12"))
113
+ elif model.startswith("esmc-600"):
114
+ path = Path(snapshot_download(repo_id="biohub/esmc-600m-2024-12"))
115
+ elif model.startswith("esmc-6b"):
116
+ path = Path(snapshot_download(repo_id="biohub/esmc-6b-2024-12"))
117
+ else:
118
+ raise ValueError(f"{model=} is an invalid model name.")
119
+ return path
120
+
121
+
122
+ IN_REPO_DATA_FOLDER = Path(__file__).parents[2] / "data"
123
+
124
+ INTERPRO_ENTRY = IN_REPO_DATA_FOLDER / "entry_list_safety_29026.list"
125
+ INTERPRO_HIERARCHY = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
126
+ INTERPRO2GO = IN_REPO_DATA_FOLDER / "ParentChildTreeFile.txt"
127
+ INTERPRO_2ID = "data/tag_dict_4_safety_filtered.json"
128
+
129
+ LSH_TABLE_PATHS = {"8bit": "data/hyperplanes_8bit_58641.npz"}
130
+
131
+ KEYWORDS_VOCABULARY = (
132
+ IN_REPO_DATA_FOLDER / "keyword_vocabulary_safety_filtered_58641.txt"
133
+ )
134
+ KEYWORDS_IDF = IN_REPO_DATA_FOLDER / "keyword_idf_safety_filtered_58641.npy"
135
+
136
+ RESID_CSV = "data/uniref90_and_mgnify90_residue_annotations_gt_1k_proteins.csv"
137
+ INTERPRO2KEYWORDS = IN_REPO_DATA_FOLDER / "interpro_29026_to_keywords_58641.csv"
138
+
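Two of the tables above are simple enough to exercise directly: the 8-class to 3-class secondary-structure map and the SASA bin boundaries. The sketch below is an assumption about how they might be applied. `ss8_to_ss3` and `sasa_bin` are hypothetical helpers, and any offset the real tokenizer adds for special tokens such as padding is not modeled.

```python
from bisect import bisect_right

# Copied from the constants above.
SSE_8CLASS_TO_3CLASS_MAP = {
    "G": "H", "H": "H", "I": "H", "T": "C",
    "E": "E", "B": "E", "S": "C", "C": "C",
}
SASA_DISCRETIZATION_BOUNDARIES = [
    0.8, 4.0, 9.6, 16.4, 24.5, 32.9, 42.0, 51.5,
    61.2, 70.9, 81.6, 93.3, 107.2, 125.4, 151.4,
]


def ss8_to_ss3(ss8: str) -> str:
    """Collapse a DSSP-style 8-class string to helix/strand/coil."""
    return "".join(SSE_8CLASS_TO_3CLASS_MAP[c] for c in ss8)


def sasa_bin(value: float) -> int:
    """Index of the bin containing `value`; 15 boundaries give 16 bins (0..15)."""
    return bisect_right(SASA_DISCRETIZATION_BOUNDARIES, value)
```

`bisect_right` places a value equal to a boundary in the higher bin. Whether the original tokenizer uses the same edge convention is not stated in this file.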
esmfold2_input_builder.py ADDED
@@ -0,0 +1,255 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Sequence, TypeAlias, Union
3
+
4
+ import numpy as np
5
+
6
+ from .esmfold2_msa import MSA
7
+
8
+ # fmt: off
9
+ MSAInput: TypeAlias = Union[
10
+ MSA,
11
+ None,
12
+ ]
13
+ # fmt: on
14
+
15
+
16
+ @dataclass
17
+ class Modification:
18
+ position: int # zero-indexed
19
+ ccd: str
20
+ smiles: str | None = None # TODO(mlee): add smiles support
21
+
22
+
23
+ @dataclass
24
+ class ProteinInput:
25
+ id: str | list[str]
26
+ sequence: str
27
+ modifications: list[Modification] | None = None
28
+ msa: MSAInput = None
29
+
30
+
31
+ @dataclass
32
+ class RNAInput:
33
+ id: str | list[str]
34
+ sequence: str
35
+ modifications: list[Modification] | None = None
36
+
37
+
38
+ @dataclass
39
+ class DNAInput:
40
+ id: str | list[str]
41
+ sequence: str
42
+ modifications: list[Modification] | None = None
43
+
44
+
45
+ @dataclass
46
+ class LigandInput:
47
+ id: str | list[str]
48
+ smiles: str | None = None
49
+ ccd: list[str] | None = None
50
+
51
+
52
+ @dataclass
53
+ class DistogramConditioning:
54
+ chain_id: str
55
+ distogram: np.ndarray
56
+
57
+
58
+ @dataclass
59
+ class PocketConditioning:
60
+ binder_chain_id: str
61
+ contacts: list[tuple[str, int]]
62
+
63
+
64
+ @dataclass
65
+ class CovalentBond:
66
+ chain_id1: str
67
+ res_idx1: int
68
+ atom_idx1: int
69
+ chain_id2: str
70
+ res_idx2: int
71
+ atom_idx2: int
72
+
73
+
74
+ @dataclass
75
+ class StructurePredictionInput:
76
+ sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
77
+ pocket: PocketConditioning | None = None
78
+ distogram_conditioning: list[DistogramConditioning] | None = None
79
+ covalent_bonds: list[CovalentBond] | None = None
80
+
81
+
82
+ def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
83
+ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
84
+ chain_data: dict[str, Any] = {
85
+ "sequence": seq_input.sequence,
86
+ "id": seq_input.id,
87
+ "type": chain_type,
88
+ }
89
+ if hasattr(seq_input, "modifications") and seq_input.modifications:
90
+ mods = [
91
+ {"position": mod.position, "ccd": mod.ccd}
92
+ for mod in seq_input.modifications
93
+ ]
94
+ chain_data["modifications"] = mods
95
+ if not hasattr(seq_input, "msa"):
96
+ pass
97
+ elif seq_input.msa is None:
98
+ chain_data["msa"] = None
99
+ elif isinstance(seq_input.msa, MSA):
100
+ chain_data["msa"] = {"sequences": seq_input.msa.sequences}
101
+ else:
102
+ error_msg = f"MSA must be None or MSA. Got {seq_input.msa} instead."
103
+ raise AttributeError(error_msg)
104
+ return chain_data
105
+
106
+ sequences = []
107
+ for seq_input in all_atom_input.sequences:
108
+ if isinstance(seq_input, ProteinInput):
109
+ sequences.append(create_chain_data(seq_input, "protein"))
110
+ elif isinstance(seq_input, RNAInput):
111
+ sequences.append(create_chain_data(seq_input, "rna"))
112
+ elif isinstance(seq_input, DNAInput):
113
+ sequences.append(create_chain_data(seq_input, "dna"))
114
+ elif isinstance(seq_input, LigandInput):
115
+ sequences.append(
116
+ {
117
+ "smiles": seq_input.smiles,
118
+ "id": seq_input.id,
119
+ "ccd": seq_input.ccd,
120
+ "type": "ligand",
121
+ }
122
+ )
123
+ else:
124
+ raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")
125
+
126
+ result: dict[str, Any] = {"sequences": sequences}
127
+
128
+ if all_atom_input.covalent_bonds is not None:
129
+ result["covalent_bonds"] = [
130
+ {
131
+ "chain_id1": bond.chain_id1,
132
+ "res_idx1": bond.res_idx1,
133
+ "atom_idx1": bond.atom_idx1,
134
+ "chain_id2": bond.chain_id2,
135
+ "res_idx2": bond.res_idx2,
136
+ "atom_idx2": bond.atom_idx2,
137
+ }
138
+ for bond in all_atom_input.covalent_bonds
139
+ ]
140
+
141
+ if all_atom_input.pocket is not None:
142
+ result["pocket"] = {
143
+ "binder_chain_id": all_atom_input.pocket.binder_chain_id,
144
+ "contacts": all_atom_input.pocket.contacts,
145
+ }
146
+
147
+ if all_atom_input.distogram_conditioning is not None:
148
+ result["distogram_conditioning"] = [
149
+ {"chain_id": disto.chain_id, "distogram": disto.distogram.tolist()}
150
+ for disto in all_atom_input.distogram_conditioning
151
+ ]
152
+
153
+ return result
154
+
155
+
156
+ def deserialize_structure_prediction_input(
157
+ data: dict[str, Any],
158
+ ) -> StructurePredictionInput:
159
+ """Inverse of :func:`serialize_structure_prediction_input`.
160
+
161
+ Reconstructs a :class:`StructurePredictionInput` from the JSON-safe dict
162
+ produced by ``serialize_structure_prediction_input``. Values round-trip;
163
+ ``DistogramConditioning.distogram`` dtype follows from JSON (``int64``
164
+ for integer entries, ``float64`` for floats) — cast back to the original
165
+ dtype if downstream code requires a specific one.
166
+ """
167
+
168
+ def _mods(chain: dict[str, Any]) -> list[Modification] | None:
169
+ raw = chain.get("modifications")
170
+ if not raw:
171
+ return None
172
+ return [Modification(position=m["position"], ccd=m["ccd"]) for m in raw]
173
+
174
+ def _msa(chain: dict[str, Any]) -> MSAInput:
175
+ if "msa" not in chain or chain["msa"] is None:
176
+ return None
177
+ msa_blk = chain["msa"]
178
+ if isinstance(msa_blk, str):
179
+ raise ValueError(f"Unexpected MSA string value: {msa_blk!r}")
180
+ return MSA.from_sequences(msa_blk["sequences"])
181
+
182
+ sequences: list[ProteinInput | RNAInput | DNAInput | LigandInput] = []
183
+ for chain in data["sequences"]:
184
+ t = chain["type"]
185
+ if t == "protein":
186
+ sequences.append(
187
+ ProteinInput(
188
+ id=chain["id"],
189
+ sequence=chain["sequence"],
190
+ modifications=_mods(chain),
191
+ msa=_msa(chain),
192
+ )
193
+ )
194
+ elif t == "rna":
195
+ sequences.append(
196
+ RNAInput(
197
+ id=chain["id"],
198
+ sequence=chain["sequence"],
199
+ modifications=_mods(chain),
200
+ )
201
+ )
202
+ elif t == "dna":
203
+ sequences.append(
204
+ DNAInput(
205
+ id=chain["id"],
206
+ sequence=chain["sequence"],
207
+ modifications=_mods(chain),
208
+ )
209
+ )
210
+ elif t == "ligand":
211
+ sequences.append(
212
+ LigandInput(
213
+ id=chain["id"], smiles=chain.get("smiles"), ccd=chain.get("ccd")
214
+ )
215
+ )
216
+ else:
217
+ raise ValueError(f"Unsupported sequence type: {t!r}")
218
+
219
+ pocket: PocketConditioning | None = None
220
+ if (pocket_blk := data.get("pocket")) is not None:
221
+ pocket = PocketConditioning(
222
+ binder_chain_id=pocket_blk["binder_chain_id"],
223
+ contacts=[tuple(c) for c in pocket_blk["contacts"]],
224
+ )
225
+
226
+ distogram_conditioning: list[DistogramConditioning] | None = None
227
+ if (disto_blk := data.get("distogram_conditioning")) is not None:
228
+ distogram_conditioning = [
229
+ DistogramConditioning(
230
+ chain_id=d["chain_id"], distogram=np.asarray(d["distogram"])
231
+ )
232
+ for d in disto_blk
233
+ ]
234
+
235
+ covalent_bonds: list[CovalentBond] | None = None
236
+ if (bonds_blk := data.get("covalent_bonds")) is not None:
237
+ covalent_bonds = [
238
+ CovalentBond(
239
+ chain_id1=b["chain_id1"],
240
+ res_idx1=b["res_idx1"],
241
+ atom_idx1=b["atom_idx1"],
242
+ chain_id2=b["chain_id2"],
243
+ res_idx2=b["res_idx2"],
244
+ atom_idx2=b["atom_idx2"],
245
+ )
246
+ for b in bonds_blk
247
+ ]
248
+
249
+ return StructurePredictionInput(
250
+ sequences=sequences,
251
+ pocket=pocket,
252
+ distogram_conditioning=distogram_conditioning,
253
+ covalent_bonds=covalent_bonds,
254
+ )
255
+
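The deserializer rebuilds `contacts` with `tuple(c)` because JSON has no tuple type: a list of `(chain, index)` tuples comes back as a list of lists. The sketch below shows this with the standard library only. It redefines a standalone copy of `PocketConditioning` for illustration and does not import the module above.

```python
import json
from dataclasses import dataclass


@dataclass
class PocketConditioning:
    # Standalone copy of the dataclass above, for illustration only.
    binder_chain_id: str
    contacts: list[tuple[str, int]]


pocket = PocketConditioning(binder_chain_id="L", contacts=[("A", 10), ("A", 42)])

# Round-trip through JSON text: the tuples become lists.
payload = json.loads(
    json.dumps(
        {"binder_chain_id": pocket.binder_chain_id, "contacts": pocket.contacts}
    )
)

# Restore the tuples, as the deserializer above does.
restored = PocketConditioning(
    binder_chain_id=payload["binder_chain_id"],
    contacts=[tuple(c) for c in payload["contacts"]],
)
```

Without the `tuple(c)` conversion the restored object would compare unequal to the original, since `["A", 10] != ("A", 10)` in Python.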
esmfold2_metrics.py ADDED
@@ -0,0 +1,374 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from torch import Tensor
6
+ from torch.amp import autocast # type: ignore
7
+
8
+ from . import esmfold2_residue_constants
9
+ from .esmfold2_misc import binpack, unbinpack
10
+ from .esmfold2_protein_structure import (
11
+ compute_alignment_tensors,
12
+ compute_gdt_ts_no_alignment,
13
+ compute_rmsd_no_alignment,
14
+ )
15
+
16
+
17
+ def contact_precision(
18
+ predictions: Tensor,
19
+ targets: Tensor,
20
+ src_lengths: Tensor | None = None,
21
+ minsep: int = 6,
22
+ maxsep: int | None = None,
23
+ override_length: int | None = None, # for casp
24
+ ):
25
+ """Computes contact precisions.
26
+
27
+ For protein contact prediction, precision is measured for the top (L/K) highest confidence predictions,
28
+ with L being the length of the protein sequence and K generally being equal to 1 or 5.
29
+
30
+ K = 5 measures the predictions of the very highest confidence contacts, while K = 1 is a more general measure
31
+ over all relatively high confidence predictions.
32
+
33
+ Since a protein has roughly L true contacts, this is a reasonable cutoff.
34
+
35
+
36
+ Args:
37
+ predictions (Tensor): Tensor of probabilities of size (B, L, L)
38
+ targets (Tensor): Tensor of true contacts of size (B, L, L)
39
+ src_lengths (Tensor, optional): Lengths of each sample in the batch, if using variable lengths.
40
+ If not provided, inferred from the size of the predictions.
41
+ minsep (int): Minimum separation distance to consider. We often want to measure contacts at a
42
+ certain range. Typical ranges are short [6, 12), medium [12, 24), and long [24, inf).
43
+ maxsep (int, optional): Used in conjunction with minsep to specify a contact range. If not provided,
44
+ no maximum separation is applied.
45
+ override_length (int, optional): Used for CASP evaluation, where the "true" length is sometimes not
46
+ the same as the length of the input. Kept for posterity; this argument is probably not needed.
47
+ """
48
+ if predictions.dim() == 2:
49
+ predictions = predictions.unsqueeze(0)
50
+ if targets.dim() == 2:
51
+ targets = targets.unsqueeze(0)
52
+
53
+ # Check sizes
54
+ if predictions.size() != targets.size():
55
+ raise ValueError(
56
+ f"Size mismatch. Received predictions of size {predictions.size()}, "
57
+ f"targets of size {targets.size()}"
58
+ )
59
+ device = predictions.device
60
+
61
+ batch_size, seqlen, _ = predictions.size()
62
+
63
+ # Step 1) Construct a mask of size [B, L, L] to mask invalid contacts
64
+ seqlen_range = torch.arange(seqlen, device=device)
65
+ sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
66
+ sep = sep.unsqueeze(0)
67
+ # Mask contacts that are closer than minsep
68
+ valid_mask = sep >= minsep
69
+ # Mask contacts where target is negative (padding or unknown)
70
+ valid_mask = valid_mask & (targets >= 0) # negative targets are invalid
71
+
72
+ # Mask contacts that are farther than maxsep, if provided
73
+ if maxsep is not None:
74
+ valid_mask &= sep < maxsep
75
+
76
+ if src_lengths is not None:
77
+ # If the lengths of the individual sequences are provided, mask positions
78
+ # that are farther than the end of the sequence.
79
+ valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
80
+ valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)
81
+ else:
82
+ src_lengths = torch.full([batch_size], seqlen, device=device, dtype=torch.long)
83
+
84
+ # Fill in the logit tensor with -inf for all invalid positions
85
+ predictions = predictions.masked_fill(~valid_mask, float("-inf"))
86
+
87
+ # Step 2) Select the upper triangle of the predictions (the contact map should be symmetric)
88
+ x_ind, y_ind = np.triu_indices(seqlen, minsep)
89
+ predictions_upper = predictions[:, x_ind, y_ind]
90
+ targets_upper = targets[:, x_ind, y_ind]
91
+
92
+ # Step 3) Select the topk values in each batch where k = L (length of sequence)
93
+ topk = seqlen if override_length is None else max(seqlen, override_length)
94
+ # Indices are the indices into the predictions corresponding to the most confident predictions
95
+ indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
96
+ # topk_targets are the target values corresponding to the above indices
97
+ topk_targets = targets_upper[torch.arange(batch_size, device=device).unsqueeze(1), indices]
98
+ if topk_targets.size(1) < topk:
99
+ # If there aren't enough targets, pad to the output.
100
+ topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])
101
+
102
+ # Step 4) Sum the accuracy of the top-i predictions for i in 1..L
103
+ # topk_targets => 1/0 true vs. false contact, sorted by confidence of prediction
104
+ # cumulative sum => Number of correct answers for the top-i predictions.
105
+ cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)
106
+
107
+ # Step 5) Find the gather indices. This should be P@(L / K) for various values of K
108
+ # The values will differ for each batch.
109
+ gather_lengths = src_lengths.unsqueeze(1)
110
+ if override_length is not None:
111
+ gather_lengths = override_length * torch.ones_like(
112
+ gather_lengths, device=device
113
+ )
114
+
115
+ # This gets you (0.1 * L, 0.2 * L, 0.3 * L, etc.)
116
+ gather_indices = (
117
+ (torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths).type(
118
+ torch.long
119
+ )
120
+ - 1
121
+ ).clamp_min(0)
122
+
123
+ # Step 6) Gather the results and divide by the number of guesses to get the precision.
124
+ binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
125
+ binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
126
+ binned_cumulative_dist
127
+ )
128
+
129
+ # Select specific P@L/k. pl5 is index 1 b/c that corresponds to L * 0.2 in
130
+ # gather_indices above
131
+ pl5 = binned_precisions[:, 1]
132
+ # pl2 = binned_precisions[:, 4]
133
+ pl = binned_precisions[:, 9]
134
+ # AUC here is the mean of the precisions at 0.1 * L, 0.2 * L, ..., 1.0 * L
135
+ auc = binned_precisions.mean(-1)
136
+
137
+ return {"AUC": auc, "P@L": pl, "P@L5": pl5}
138
+
139
+
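The P@L/K metric computed by `contact_precision` can be illustrated on a single small contact map without tensors. The sketch below is a simplified pure-Python version, not the implementation above: `precision_at_l_over_k` is a hypothetical helper, it handles one unbatched map, assumes at least one pair satisfies `minsep`, and divides by the number of pairs actually selected rather than padding to a fixed count.

```python
def precision_at_l_over_k(probs, contacts, minsep=6, k=1):
    """Precision of the top L // k scored pairs (i, j) with j - i >= minsep."""
    length = len(probs)
    # Only the upper triangle is scored, since the contact map is symmetric.
    pairs = [
        (probs[i][j], contacts[i][j])
        for i in range(length)
        for j in range(i + minsep, length)
    ]
    pairs.sort(key=lambda p: p[0], reverse=True)
    top = pairs[: max(1, length // k)]
    return sum(hit for _, hit in top) / len(top)


# Toy 4-residue example; only the upper triangle is read.
probs = [
    [0.0, 0.9, 0.8, 0.1],
    [0.0, 0.0, 0.7, 0.6],
    [0.0, 0.0, 0.0, 0.2],
    [0.0, 0.0, 0.0, 0.0],
]
contacts = [
    [0, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]

p_at_l = precision_at_l_over_k(probs, contacts, minsep=1, k=1)
p_at_l2 = precision_at_l_over_k(probs, contacts, minsep=1, k=2)
```

With `minsep=1` there are six candidate pairs. The four most confident contain three true contacts, so P@L is 0.75, while the top two contain one, so P@L/2 is 0.5.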
140
+ def compute_lddt(
141
+ all_atom_pred_pos: torch.Tensor,
142
+ all_atom_positions: torch.Tensor,
143
+ all_atom_mask: torch.Tensor,
144
+ pairwise_all_atom_mask: torch.Tensor | None = None,
145
+ cutoff: float | torch.Tensor = 15.0,
146
+ eps: float = 1e-10,
147
+ per_residue: bool = True,
148
+ sequence_id: torch.Tensor | None = None,
149
+ ) -> torch.Tensor:
150
+ """
151
+ Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
152
+ Nstates:
153
+ all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
154
+ Natoms:
155
+ LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.
156
+
157
+ Args:
158
+ all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
159
+ all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
160
+ all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
161
+ pairwise_all_atom_mask (Tensor[float], [B x (L * Natoms x L * Natoms)], optional): Tensor of masks, indicating whether a pair of atoms should be considered in the LDDT calculation.
162
+ cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
163
+ per_residue (bool): Whether to return per-residue or full-protein lddt.
164
+ sequence_id (Tensor, optional): Sequence id tensor for binpacking. NOTE: only supported for lddt_ca calculations, not when Natoms is passed!
165
+
166
+ Returns:
167
+ LDDT Tensor:
168
+ if per_residue:
169
+ Tensor[float], [(Nstates x) B x (L * Natoms)]
170
+ else:
171
+ Tensor[float], [(Nstates x) B]
172
+ """
173
+ all_atom_mask = all_atom_mask[..., None] # add a dimension for broadcasting
174
+ dmat_true = torch.sqrt(
175
+ eps
176
+ + torch.sum(
177
+ (all_atom_positions[..., None, :] - all_atom_positions[..., None, :, :])
178
+ ** 2,
179
+ dim=-1,
180
+ )
181
+ )
182
+
183
+ dmat_pred = torch.sqrt(
184
+ eps
185
+ + torch.sum(
186
+ (all_atom_pred_pos[..., None, :] - all_atom_pred_pos[..., None, :, :]) ** 2,
187
+ dim=-1,
188
+ )
189
+ )
190
+ mask = all_atom_mask * rearrange(all_atom_mask, "... a b -> ... b a")
191
+ if pairwise_all_atom_mask is not None:
192
+ mask = mask * pairwise_all_atom_mask
193
+
194
+ if sequence_id is not None:
195
+ # TODO: This will work for lddt_ca, but not for regular lddt
196
+ # Problem is that regular lddt has natoms * nres scores, so would need to repeat this mask by natoms
197
+ # Leaving for now because it won't fail silently so should be ok.
198
+ seqid_mask = sequence_id[..., None] == sequence_id[..., None, :]
199
+ mask = mask * seqid_mask.type_as(mask)
200
+
201
+ return compute_lddt_from_dmat(
202
+ dmat_pred, dmat_true, mask, cutoff=cutoff, eps=eps, per_residue=per_residue
203
+ )
204
+
205
+
206
+ def compute_lddt_from_dmat(
207
+ dmat_pred: torch.Tensor,
208
+ dmat_true: torch.Tensor,
209
+ pairwise_mask: torch.Tensor,
210
+ cutoff: float | torch.Tensor = 15.0,
211
+ eps: float = 1e-10,
212
+ per_residue: bool = True,
213
+ ):
214
+ """
215
+ Compute LDDT from pre-computed distance matrices.
216
+ This is useful when you want to compute LDDT with multiple different masks or cutoffs, e.g. for different molecule types (protein, nucleic acid, etc.).
217
+
218
+ Args:
219
+ dmat_pred (Tensor[float], [B x L x L]): Predicted distance matrix
220
+ dmat_true (Tensor[float], [B x L x L]): True distance matrix
221
+ pairwise_mask (Tensor[float], [B x L x L]): Pairwise mask indicating which pairs of atoms to consider
222
+ cutoff (float): Max distance to score lddt over. This can either be a float, or a tensor of shape [B, L, L] to allow for per-residue cutoffs, e.g. if you want to use a different cutoff for nucleic acids.
223
+ per_residue (bool): Whether to return per-residue or full-protein lddt.
224
+
225
+ Returns:
226
+ LDDT Tensor:
227
+ if per_residue:
228
+ Tensor[float], [B x L]
229
+ else:
230
+ Tensor[float], [B]
231
+ """
232
+ n = dmat_true.size(-1)
233
+ dists_to_score = (
234
+ (dmat_true < cutoff)
235
+ * pairwise_mask
236
+ * (1.0 - torch.eye(n, device=dmat_true.device))
237
+ )
238
+
239
+ dist_l1 = torch.abs(dmat_true - dmat_pred)
240
+ score = (
241
+ (dist_l1 < 0.5).type(dist_l1.dtype)
242
+ + (dist_l1 < 1.0).type(dist_l1.dtype)
243
+ + (dist_l1 < 2.0).type(dist_l1.dtype)
244
+ + (dist_l1 < 4.0).type(dist_l1.dtype)
245
+ )
246
+ score = score * 0.25
247
+
248
+ dims = (-1,) if per_residue else (-2, -1)
249
+ norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
250
+ score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
251
+ return score
252
+
253
+
254
+ def compute_lddt_ca(
255
+ all_atom_pred_pos: torch.Tensor,
256
+ all_atom_positions: torch.Tensor,
257
+ all_atom_mask: torch.Tensor,
258
+ cutoff: float = 15.0,
259
+ eps: float = 1e-10,
260
+ per_residue: bool = True,
261
+ sequence_id: torch.Tensor | None = None,
262
+ ) -> torch.Tensor:
263
+ ca_pos = residue_constants.atom_order["CA"]
264
+ if all_atom_pred_pos.dim() != 3:
265
+ all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
266
+ all_atom_positions = all_atom_positions[..., ca_pos, :]
267
+ all_atom_mask = all_atom_mask[..., ca_pos]
268
+
269
+ return compute_lddt(
270
+ all_atom_pred_pos,
271
+ all_atom_positions,
272
+ all_atom_mask,
273
+ cutoff=cutoff,
274
+ eps=eps,
275
+ per_residue=per_residue,
276
+ sequence_id=sequence_id,
277
+ )
278
+
279
+
280
+ # NOTE(roshan): no_grad required for stack_variable_length_tensors apparently... let's revisit if we want to backprop
281
+ @torch.no_grad()
282
+ @autocast("cuda", enabled=False)
283
+ def compute_rmsd(
284
+ mobile: torch.Tensor,
285
+ target: torch.Tensor,
286
+ atom_exists_mask: torch.Tensor | None = None,
287
+ sequence_id: torch.Tensor | None = None,
288
+ reduction: str = "batch",
289
+ ):
290
+ """
291
+ Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
292
+
293
+ Args:
294
+ - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
295
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
296
+ - atom_exists_mask (torch.Tensor, optional): Mask indicating whether an atom exists, of shape (B, N)
297
+ - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
298
+ - reduction (str): One of "batch", "per_sample", "per_residue".
299
+
300
+ Returns:
301
+ If reduction == "batch":
302
+ (torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
303
+ If reduction == "per_sample":
304
+ (torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each sample in the batch
305
+ If reduction == "per_residue":
306
+ (torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for each residue in the batch
307
+ """
308
+
309
+ (centered_mobile, _, centered_target, _, rotation_matrix, num_valid_atoms) = (
310
+ compute_alignment_tensors(
311
+ mobile=mobile,
312
+ target=target,
313
+ atom_exists_mask=atom_exists_mask,
314
+ sequence_id=sequence_id,
315
+ )
316
+ )
317
+
318
+ # Apply transformation to centered structure
319
+ rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
320
+
321
+ # Compute rmsd for centered structures
322
+ rmsd = compute_rmsd_no_alignment(
323
+ rotated_mobile, centered_target, num_valid_atoms, reduction=reduction
324
+ )
325
+ if reduction == "per_residue" and sequence_id is not None:
326
+ rmsd = binpack(rmsd, sequence_id, pad_value=0)
327
+ return rmsd
328
+
329
+
330
+ def compute_gdt_ts(
331
+ mobile: torch.Tensor,
332
+ target: torch.Tensor,
333
+ atom_exists_mask: torch.Tensor | None = None,
334
+ sequence_id: torch.Tensor | None = None,
335
+ reduction: str = "per_sample",
336
+ ):
337
+ """
338
+ Compute GDT_TS between two batches of structures with support for masking invalid atoms using PyTorch.
339
+
340
+ Args:
341
+ - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
342
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
343
+ - atom_exists_mask (torch.Tensor, optional): Mask indicating whether an atom exists, of shape (B, N)
344
+ - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
345
+ - reduction (str): One of "batch", "per_sample", "per_residue".
346
+
347
+ Returns:
348
+ If reduction == "batch":
349
+ (torch.Tensor): 0-dim, GDT_TS between the structures for each batch
350
+ If reduction == "per_sample":
351
+ (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
352
+ """
353
+ if atom_exists_mask is None:
354
+ atom_exists_mask = torch.isfinite(target).all(dim=-1)
355
+ (centered_mobile, _, centered_target, _, rotation_matrix, _) = (
356
+ compute_alignment_tensors(
357
+ mobile=mobile,
358
+ target=target,
359
+ atom_exists_mask=atom_exists_mask,
360
+ sequence_id=sequence_id,
361
+ )
362
+ )
363
+
364
+ # Apply transformation to centered structure
365
+ rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
366
+
367
+ # the coordinate tensors returned by `compute_alignment_tensors` are unbinpacked and contain zeros for invalid positions
368
+ # so `compute_gdt_ts_no_alignment` requires `atom_exists_mask` to be passed and be unbinpacked
369
+ if sequence_id is not None:
370
+ atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=False)
371
+ return compute_gdt_ts_no_alignment(
372
+ rotated_mobile, centered_target, atom_exists_mask, reduction
373
+ )
374
+
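The scoring rule in `compute_lddt_from_dmat` above can be sketched without tensors. The following is a minimal pure-Python illustration, not the code from this commit: the names `THRESHOLDS`, `pairwise_distances`, and `lddt_per_residue` are hypothetical, masking and batching are left out, and a residue with no scored neighbours is assumed to score 1.0, which is what the `eps` terms in the torch version evaluate to.

```python
import math

# Distance-error tolerances (in angstroms) taken from the scoring rule in the diff.
THRESHOLDS = (0.5, 1.0, 2.0, 4.0)


def pairwise_distances(coords):
    """Return the full Euclidean distance matrix for a list of 3D points."""
    return [[math.dist(a, b) for b in coords] for a in coords]


def lddt_per_residue(pred, true, cutoff=15.0):
    """Toy per-residue lDDT for a single, fully observed structure."""
    d_pred = pairwise_distances(pred)
    d_true = pairwise_distances(true)
    n = len(true)
    scores = []
    for i in range(n):
        kept, total = 0, 0.0
        for j in range(n):
            # Only off-diagonal pairs closer than `cutoff` in the true structure count.
            if i == j or d_true[i][j] >= cutoff:
                continue
            err = abs(d_true[i][j] - d_pred[i][j])
            # Fraction of the four tolerances that the distance error stays under.
            total += sum(err < t for t in THRESHOLDS) / len(THRESHOLDS)
            kept += 1
        # Assumption: mirror the eps behaviour of the torch code for empty neighbourhoods.
        scores.append(total / kept if kept else 1.0)
    return scores


true_xyz = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (6.0, 0.0, 0.0)]
pred_xyz = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (7.5, 0.0, 0.0)]
print(lddt_per_residue(pred_xyz, true_xyz))
```

In this toy case the third point is displaced by 1.5 Å, so every pair that involves it passes only the 2.0 Å and 4.0 Å tolerances and contributes 0.5, while the undisturbed pair contributes 1.0.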
esmfold2_misc.py ADDED
@@ -0,0 +1,505 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from collections import defaultdict
5
+ from contextlib import nullcontext
6
+ from dataclasses import is_dataclass
7
+ from io import BytesIO
8
+ from typing import (
9
+ Any,
10
+ ContextManager,
11
+ Generator,
12
+ Iterable,
13
+ Protocol,
14
+ Sequence,
15
+ TypeVar,
16
+ runtime_checkable,
17
+ )
18
+ from warnings import warn
19
+
20
+ import huggingface_hub
21
+ import numpy as np
22
+ import torch
23
+ import zstd
24
+
25
+ from .esmfold2_constants_esm3 import CHAIN_BREAK_STR
26
+ from .esmfold2_utils_types import FunctionAnnotation
27
+
28
+ MAX_SUPPORTED_DISTANCE = 1e6
29
+
30
+
31
+ TSequence = TypeVar("TSequence", bound=Sequence)
32
+
33
+
34
+ @runtime_checkable
35
+ class Concatable(Protocol):
36
+ @classmethod
37
+ def concat(cls, objs: list[Concatable]) -> Concatable: ...
38
+
39
+
40
+ def slice_python_object_as_numpy(
41
+ obj: TSequence, idx: int | list[int] | slice | np.ndarray
42
+ ) -> TSequence:
43
+ """
44
+ Slice a python object (like a list, string, or tuple) as if it was a numpy object.
45
+
46
+ Example:
47
+ >>> obj = "ABCDE"
48
+ >>> slice_python_object_as_numpy(obj, [1, 3, 4])
49
+ "BDE"
50
+
51
+ >>> obj = [1, 2, 3, 4, 5]
52
+ >>> slice_python_object_as_numpy(obj, np.arange(5) < 3)
53
+ [1, 2, 3]
54
+ """
55
+ if np.isscalar(idx):
56
+ idx = [int(idx)] # type: ignore
57
+
58
+ if isinstance(idx, np.ndarray) and idx.dtype == bool:
59
+ sliced_obj = [obj[i] for i in np.where(idx)[0]]
60
+ elif isinstance(idx, slice):
61
+ sliced_obj = obj[idx]
62
+ else:
63
+ sliced_obj = [obj[i] for i in idx] # type: ignore
64
+
65
+ match obj, sliced_obj:
66
+ case str(), list():
67
+ sliced_obj = "".join(sliced_obj)
68
+ case _:
69
+ sliced_obj = obj.__class__(sliced_obj) # type: ignore
70
+
71
+ return sliced_obj # type: ignore
72
+
73
+
74
+ def slice_any_object(
75
+ obj: TSequence, idx: int | list[int] | slice | np.ndarray
76
+ ) -> TSequence:
77
+ """
78
+ Slice an arbitrary object (like a list, string, or tuple) as if it were a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so.
79
+
80
+ If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correctly implemented numpy indexing.
81
+
82
+ Example:
83
+ >>> obj = "ABCDE"
84
+ >>> slice_any_object(obj, [1, 3, 4])
85
+ "BDE"
86
+
87
+ >>> obj = np.array([1, 2, 3, 4, 5])
88
+ >>> slice_any_object(obj, np.arange(5) < 3)
89
+ np.array([1, 2, 3])
90
+
91
+ >>> obj = ProteinChain.from_rcsb("1a3a", "A")
92
+ >>> slice_any_object(obj, np.arange(len(obj)) < 10)
93
+ # ProteinChain w/ length 10
94
+
95
+ """
96
+ if isinstance(obj, (np.ndarray, torch.Tensor)):
97
+ return obj[idx] # type: ignore
98
+ elif is_dataclass(obj):
99
+ # if passing a dataclass, assume it implements a custom slice
100
+ return obj[idx] # type: ignore
101
+ else:
102
+ return slice_python_object_as_numpy(obj, idx)
103
+
104
+
105
+ def rbf(values, v_min, v_max, n_bins=16):
106
+ """
107
+ Returns RBF encodings in a new dimension at the end.
108
+ """
109
+ rbf_centers = torch.linspace(
110
+ v_min, v_max, n_bins, device=values.device, dtype=values.dtype
111
+ )
112
+ rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
113
+ rbf_std = (v_max - v_min) / n_bins
114
+ z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
115
+ return torch.exp(-(z**2))
116
+
117
+
118
+ def batched_gather(data, inds, dim=0, no_batch_dims=0):
119
+ ranges = []
120
+ for i, s in enumerate(data.shape[:no_batch_dims]):
121
+ r = torch.arange(s)
122
+ r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
123
+ ranges.append(r)
124
+
125
+ remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
126
+ remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
127
+ ranges.extend(remaining_dims)
128
+ return data[tuple(ranges)]
129
+
130
+
131
+ def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
132
+ return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1)
133
+
134
+
135
+ def knn_graph(
136
+ coords: torch.Tensor,
137
+ coord_mask: torch.Tensor,
138
+ padding_mask: torch.Tensor,
139
+ sequence_id: torch.Tensor,
140
+ *,
141
+ no_knn: int,
142
+ ):
143
+ L = coords.shape[-2]
144
+ num_by_dist = min(no_knn, L)
145
+ device = coords.device
146
+
147
+ coords = coords.nan_to_num()
148
+ coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None])
149
+ padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None]
150
+ if sequence_id is not None:
151
+ padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze(
152
+ sequence_id, 2
153
+ )
154
+ dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1)
155
+ arange = torch.arange(L, device=device)
156
+ seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs()
157
+ # We only support up to a certain distance, above that, we use sequence distance
158
+ # instead. This is so that when a large portion of the structure is masked out,
159
+ # the edges are built according to sequence distance.
160
+ max_dist = MAX_SUPPORTED_DISTANCE
161
+ if not (dists[~coord_mask] < max_dist).all():
162
+ raise ValueError(
163
+ f"Coordinate pairwise distances exceed max supported distance ({max_dist}). "
164
+ )
165
+ struct_then_seq_dist = (
166
+ seq_dists.to(dists.dtype)
167
+ .mul(1e2)
168
+ .add(max_dist)
169
+ .where(coord_mask, dists)
170
+ .masked_fill(padding_pairwise_mask, torch.inf)
171
+ )
172
+ dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False)
173
+ # This is a L x L tensor, where we index by rows first,
174
+ # and columns are the edges we should pick.
175
+ chosen_edges = edges[..., :num_by_dist]
176
+ chosen_mask = dists[..., :num_by_dist].isfinite()
177
+ return chosen_edges, chosen_mask
178
+
179
+
180
+ def stack_variable_length_tensors(
181
+ sequences: Sequence[torch.Tensor],
182
+ constant_value: int | float = 0,
183
+ dtype: torch.dtype | None = None,
184
+ ) -> torch.Tensor:
185
+ """Automatically stack tensors together, padding variable lengths with the
186
+ value in constant_value. Handles an arbitrary number of dimensions.
187
+
188
+ Examples:
189
+ >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5])
190
+ >>> stack_variable_length_tensors([tensor1, tensor2])
191
+ tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones.
192
+
193
+ >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3])
194
+ >>> stack_variable_length_tensors([tensor1, tensor2])
195
+ tensor of shape [2, 5, 4]
196
+ """
197
+ batch_size = len(sequences)
198
+ shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist()
199
+
200
+ if dtype is None:
201
+ dtype = sequences[0].dtype
202
+ device = sequences[0].device
203
+
204
+ array = torch.full(shape, constant_value, dtype=dtype, device=device)
205
+ for arr, seq in zip(array, sequences):
206
+ arrslice = tuple(slice(dim) for dim in seq.shape)
207
+ arr[arrslice] = seq
208
+
209
+ return array
210
+
211
+
212
+ def binpack(
213
+ tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
214
+ ):
215
+ """
216
+ Args:
217
+ tensor (Tensor): [B, L, ...]
218
+
219
+ Returns:
220
+ Tensor: [B_binpacked, L_binpacked, ...]
221
+ """
222
+ if sequence_id is None:
223
+ return tensor
224
+
225
+ num_sequences = sequence_id.max(dim=-1).values + 1
226
+
227
+ dims = sequence_id.shape + tensor.shape[2:]
228
+ output_tensor = torch.full(
229
+ dims, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device
230
+ )
231
+
232
+ idx = 0
233
+ for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
234
+ zip(sequence_id, num_sequences)
235
+ ):
236
+ for seqid in range(batch_num_sequences):
237
+ mask = batch_seqid == seqid
238
+ output_tensor[batch_idx, mask] = tensor[idx, : mask.sum()]
239
+ idx += 1
240
+ return output_tensor
241
+
242
+
243
+ def unbinpack(
244
+ tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float
245
+ ):
246
+ """
247
+ Args:
248
+ tensor (Tensor): [B, L, ...]
249
+
250
+ Returns:
251
+ Tensor: [B_unbinpacked, L_unbinpacked, ...]
252
+ """
253
+ if sequence_id is None:
254
+ return tensor
255
+
256
+ unpacked_tensors = []
257
+ num_sequences = sequence_id.max(dim=-1).values + 1
258
+ for batch_idx, (batch_seqid, batch_num_sequences) in enumerate(
259
+ zip(sequence_id, num_sequences)
260
+ ):
261
+ for seqid in range(batch_num_sequences):
262
+ mask = batch_seqid == seqid
263
+ unpacked = tensor[batch_idx, mask]
264
+ unpacked_tensors.append(unpacked)
265
+ return stack_variable_length_tensors(unpacked_tensors, pad_value)
266
+
267
+
268
+ def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore
269
+ """
270
+ Returns an autocast context manager that disables downcasting by AMP.
271
+
272
+ Args:
273
+ device_type: The device type ('cpu', 'mps', or 'cuda')
274
+
275
+ Returns:
276
+ An autocast context manager with the specified behavior.
277
+ """
278
+ if device_type == "cpu":
279
+ return torch.amp.autocast(device_type, enabled=False) # type: ignore
280
+ elif device_type == "mps":
281
+ # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
282
+ return nullcontext()
283
+ elif device_type == "cuda":
284
+ return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
285
+ else:
286
+ raise ValueError(f"Unsupported device type: {device_type}")
287
+
288
+
289
+ def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]:
290
+ """Merge overlapping ranges into sorted, non-overlapping segments.
291
+
292
+ Args:
293
+ ranges: collection of ranges to merge.
294
+ merge_gap_max: optionally merge neighboring ranges that are separated by a gap
295
+ no larger than this size.
296
+ Returns:
297
+ non-overlapping ranges merged from the inputs, sorted by position.
298
+ """
299
+ ranges = sorted(ranges, key=lambda r: r.start)
300
+ merge_gap_max = merge_gap_max if merge_gap_max is not None else 0
301
+ assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}"
302
+
303
+ merged = []
304
+ for r in ranges:
305
+ if not merged:
306
+ merged.append(r)
307
+ else:
308
+ last = merged[-1]
309
+ if last.stop + merge_gap_max >= r.start:
310
+ merged[-1] = range(last.start, max(last.stop, r.stop))
311
+ else:
312
+ merged.append(r)
313
+ return merged
314
+
315
+
316
+ def merge_annotations(
317
+ annotations: list[FunctionAnnotation], merge_gap_max: int | None = None
318
+ ) -> list[FunctionAnnotation]:
319
+ """Merges annotations into non-overlapping segments.
320
+
321
+ Args:
322
+ annotations: annotations to merge.
323
+ merge_gap_max: optionally merge neighboring ranges that are separated by a gap
324
+ no larger than this size.
325
+ Returns:
326
+ non-overlapping annotations with gaps merged.
327
+ """
328
+ grouped: dict[str, list[range]] = defaultdict(list)
329
+ for a in annotations:
330
+ # +1 since FunctionAnnotation.end is inclusive.
331
+ grouped[a.label].append(range(a.start, a.end + 1))
332
+
333
+ merged = []
334
+ for label, ranges in grouped.items():
335
+ merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max)
336
+ for range_ in merged_ranges:
337
+ annotation = FunctionAnnotation(
338
+ label=label,
339
+ start=range_.start,
340
+ end=range_.stop - 1, # convert range.stop exclusive -> inclusive.
341
+ )
342
+ merged.append(annotation)
343
+ return merged
344
+
345
+
346
+ def replace_inf(data):
347
+ if data is None:
348
+ return None
349
+ array = np.asarray(data, dtype=np.float32)
350
+ array = np.where(np.isinf(array), 1000, array)
351
+ return array.tolist()
352
+
353
+
354
+ def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None:
355
+ if x is None:
356
+ return None
357
+ if isinstance(x, torch.Tensor):
358
+ return x
359
+ if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x):
360
+ return torch.stack(x)
361
+ if convert_none_to_nan:
362
+ x = np.asarray(x, dtype=np.float32)
363
+ x = np.where(x is None, np.nan, x)
364
+ return torch.tensor(x)
365
+
366
+
367
+ def maybe_list(x, convert_nan_to_none: bool = False) -> list | None:
368
+ if x is None:
369
+ return None
370
+ if not convert_nan_to_none:
371
+ return x.tolist()
372
+
373
+ # Handle both torch.tensor and np.ndarray input.
374
+ if isinstance(x, torch.Tensor):
375
+ nan_mask = torch.isnan(x).cpu().numpy()
376
+ np_arr = x.cpu().numpy().astype(object)
377
+ elif isinstance(x, np.ndarray):
378
+ nan_mask = np.isnan(x)
379
+ np_arr = x.astype(object)
380
+ else:
381
+ raise TypeError("maybe_list can only work with torch.tensor or np.ndarray.")
382
+
383
+ np_arr[nan_mask] = None
384
+ return np_arr.tolist()
385
+
386
+
387
+ def huggingfacehub_login():
388
+ """Authenticates with the Hugging Face Hub using the HF_TOKEN environment
389
+ variable, else by prompting the user"""
390
+ token = os.environ.get("HF_TOKEN")
391
+ huggingface_hub.login(token=token)
392
+
393
+
394
+ def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray:
395
+ chain_boundaries = [0]
396
+ for i, aa in enumerate(sequence):
397
+ if aa == CHAIN_BREAK_STR:
398
+ if i == (len(sequence) - 1):
399
+ raise ValueError(
400
+ "Encountered chain break token at end of sequence, this is unexpected."
401
+ )
402
+ if i == (len(sequence) - 2):
403
+ warn(
404
+ "Encountered chain break token at penultimate position, this is unexpected."
405
+ )
406
+ chain_boundaries.append(i)
407
+ chain_boundaries.append(i + 1)
408
+ chain_boundaries.append(len(sequence))
409
+ assert len(chain_boundaries) % 2 == 0
410
+ chain_boundaries = np.array(chain_boundaries).reshape(-1, 2)
411
+ return chain_boundaries
412
+
413
+
414
+ def deserialize_tensors(b: bytes) -> Any:
415
+ buf = BytesIO(zstd.ZSTD_uncompress(b))
416
+ d = torch.load(buf, map_location="cpu", weights_only=False)
417
+ return d
418
+
419
+
420
+ def join_lists(
421
+ lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None
422
+ ) -> list[Any]:
423
+ """Joins multiple lists with separator element. Like str.join but for lists.
424
+
425
+ Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4]
426
+
427
+ Args:
428
+ lists: Lists of elements to chain
429
+ separator: elements to insert between consecutive lists.
430
+ Returns:
431
+ Joined lists.
432
+ """
433
+ if not lists:
434
+ return []
435
+ joined = []
436
+ joined.extend(lists[0])
437
+ for l in lists[1:]:
438
+ if separator:
439
+ joined.extend(separator)
440
+ joined.extend(l)
441
+ return joined
442
+
443
+
444
+ def iterate_with_intermediate(
445
+ lists: Iterable, intermediate
446
+ ) -> Generator[Any, None, None]:
447
+ """
448
+ Iterate over the iterable, yielding the intermediate value between
449
+ every pair of consecutive elements. Useful for joining objects with
450
+ separator tokens.
451
+ """
452
+ it = iter(lists)
453
+ yield next(it)
454
+ for l in it:
455
+ yield intermediate
456
+ yield l
457
+
458
+
459
+ def concat_objects(objs: Sequence[Any], separator: Any | None = None):
460
+ """
461
+ Concat objects with each other using a separator token.
462
+
463
+ Supports:
464
+ - Concatable (objects that implement `concat` classmethod)
465
+ - strings
466
+ - lists
467
+ - numpy arrays
468
+ - torch Tensors
469
+
470
+ Example:
471
+ >>> foo = "abc"
472
+ >>> bar = "def"
473
+ >>> concat_objects([foo, bar], "|")
474
+ "abc|def"
475
+ """
476
+ match objs[0]:
477
+ case Concatable():
478
+ return objs[0].__class__.concat(objs) # type: ignore
479
+ case str():
480
+ assert isinstance(
481
+ separator, str
482
+ ), "Trying to join strings but separator is not a string"
483
+ return separator.join(objs)
484
+ case list():
485
+ if separator is not None:
486
+ return join_lists(objs, [separator])
487
+ else:
488
+ return join_lists(objs)
489
+ case np.ndarray():
490
+ if separator is not None:
491
+ return np.concatenate(
492
+ list(iterate_with_intermediate(objs, np.array([separator])))
493
+ )
494
+ else:
495
+ return np.concatenate(objs)
496
+ case torch.Tensor():
497
+ if separator is not None:
498
+ return torch.cat(
499
+ list(iterate_with_intermediate(objs, torch.tensor([separator])))
500
+ )
501
+ else:
502
+ return torch.cat(objs) # type: ignore
503
+ case _:
504
+ raise TypeError(type(objs[0]))
505
+
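The inclusive/exclusive conversion in `merge_annotations` above is easy to get wrong, so a small worked example may help. This is a sketch under stated assumptions: `merge_ranges` below is a condensed standalone copy of the merging rule from the diff, and `merge_inclusive_spans` is a hypothetical helper that stands in for the `FunctionAnnotation` handling by using plain `(start, end)` tuples with an inclusive `end`.

```python
def merge_ranges(ranges, merge_gap_max=0):
    """Condensed copy of the merging rule shown in the diff (illustration only)."""
    merged = []
    for r in sorted(ranges, key=lambda r: r.start):
        if merged and merged[-1].stop + merge_gap_max >= r.start:
            # Overlapping, touching, or within the allowed gap: extend the last range.
            last = merged[-1]
            merged[-1] = range(last.start, max(last.stop, r.stop))
        else:
            merged.append(r)
    return merged


def merge_inclusive_spans(spans, merge_gap_max=0):
    """Merge (start, end) spans whose `end` is inclusive (hypothetical helper)."""
    # Inclusive end -> exclusive stop, so that adjacent spans touch.
    as_ranges = [range(start, end + 1) for start, end in spans]
    # Exclusive stop -> inclusive end on the way out.
    return [(r.start, r.stop - 1) for r in merge_ranges(as_ranges, merge_gap_max)]


spans = [(12, 15), (1, 5), (6, 9)]
print(merge_inclusive_spans(spans))
print(merge_inclusive_spans(spans, merge_gap_max=2))
```

Because of the `+ 1` conversion, the spans `(1, 5)` and `(6, 9)` are treated as touching and merge even with no gap allowance, while `(12, 15)` only joins them once `merge_gap_max` covers the two missing positions 10 and 11.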
esmfold2_mmcif_parsing.py ADDED
@@ -0,0 +1,470 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import io
5
+ import os
6
+ from dataclasses import dataclass
7
+ from datetime import datetime
8
+ from typing import Union
9
+
10
+ import biotite.structure as bs
11
+ import biotite.structure.io.pdbx as pdbx
12
+
13
+ from . import esmfold2_residue_constants
14
+
15
+ # Define PathOrBuffer for the open-source version
16
+ PathOrBuffer = Union[str, os.PathLike, io.StringIO]
17
+
18
+
19
+ class NoProteinError(Exception):
20
+ pass
21
+
22
+
23
+ @dataclass
24
+ class Residue:
25
+ residue_number: int | None = None
26
+ insertion_code: str = ""
27
+ hetflag: bool = False
28
+
29
+
30
+ @dataclass
31
+ class MmcifHeader:
32
+ release_date: datetime | None = None
33
+ resolution: float | None = None
34
+ structure_method: str = "UNKNOWN"
35
+
36
+
37
+ class MmcifWrapper:
38
+ def __init__(self, id: str | None = None):
39
+ self.id: str = id or ""
40
+ self.raw: pdbx.CIFFile | None = None
41
+ self.structure: bs.AtomArray
42
+ self.header: MmcifHeader = MmcifHeader()
43
+ self.entities: dict[int, list[str]] = {}
44
+ self.chain_to_seqres: dict[str, str] = {}
45
+ self.seqres_to_structure: dict[str, dict[int, Residue]] = {}
46
+
47
+ @classmethod
48
+ def read(cls, path: PathOrBuffer, id: str | None = None) -> MmcifWrapper:
49
+ obj = cls(id=id)
50
+ obj._load(path)
51
+ return obj
52
+
53
+ def _load(self, path: PathOrBuffer, fileid: str | None = None):
54
+ """Load mmCIF data from file."""
55
+ self.raw = pdbx.CIFFile.read(path)
56
+
57
+ self._parse_structure()
58
+ self._parse_header()
59
+ self._parse_entities()
60
+ self._parse_sequences()
61
+
62
+ def _parse_structure(self):
63
+ """Parse the atomic structure from mmCIF."""
64
+ try:
65
+ structure = pdbx.get_structure(self.raw, model=1)
66
+ if structure is None or not isinstance(structure, bs.AtomArray):
67
+ raise NoProteinError("No structure found in mmCIF file")
68
+ if len(structure) == 0:
69
+ raise NoProteinError("Empty structure in mmCIF file")
70
+ self.structure = structure
71
+ except Exception as e:
72
+ raise ValueError(f"Failed to parse structure: {e}") from e
73
+
74
+ def _parse_header(self):
75
+ """Parse header information from mmCIF."""
76
+ if not self.raw:
77
+ return
78
+
79
+ try:
80
+ # Get the first (and usually only) block
81
+ block = self.raw.block
82
+
83
+ # Parse release date (taken from the initial deposition date field)
84
+ if "pdbx_database_status" in block:
85
+ status_cat = block["pdbx_database_status"]
86
+ if "recvd_initial_deposition_date" in status_cat:
87
+ date_str = status_cat["recvd_initial_deposition_date"].as_item()
88
+ if date_str and date_str != "?":
89
+ try:
90
+ self.header.release_date = datetime.strptime(
91
+ date_str, "%Y-%m-%d"
92
+ )
93
+ except ValueError:
94
+ pass
95
+
96
+ # Parse resolution
97
+ if "refine" in block:
98
+ refine_cat = block["refine"]
99
+ if "ls_d_res_high" in refine_cat:
100
+ res_str = refine_cat["ls_d_res_high"].as_item()
101
+ if res_str and res_str != "?":
102
+ try:
103
+ self.header.resolution = float(res_str)
104
+ except ValueError:
105
+ pass
106
+
107
+ # Parse structure method
108
+ if "exptl" in block:
109
+ exptl_cat = block["exptl"]
110
+ if "method" in exptl_cat:
111
+ method = exptl_cat["method"].as_item()
112
+ if method and method != "?":
113
+ self.header.structure_method = method.upper()
114
+
115
+ except Exception:
116
+ # If parsing fails, keep default values
117
+ pass
118
+
119
+ def _parse_entities(self):
120
+ """Parse entity information and map to chains."""
121
+ if not self.raw:
122
+ return
123
+
124
+ try:
125
+ block = self.raw.block
126
+
127
+ # Parse entity information
128
+ if "entity" in block:
129
+ entity_cat = block["entity"]
130
+ entity_ids = entity_cat["id"].as_array(str)
131
+ entity_types = entity_cat["type"].as_array(str)
132
+
133
+ # Initialize entities dict with all entities (not just polymers)
134
+ for i, (entity_id, entity_type) in enumerate(
135
+ zip(entity_ids, entity_types)
136
+ ):
137
+ self.entities[int(entity_id)] = []
138
+
139
+ # Map polymer chains to entities using entity_poly
140
+ if "entity_poly" in block:
141
+ poly_cat = block["entity_poly"]
142
+ entity_ids = poly_cat["entity_id"].as_array(str)
143
+ chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
144
+
145
+ for entity_id, chain_list in zip(entity_ids, chain_lists):
146
+ entity_id = int(entity_id)
147
+ # Chain list is comma-separated
148
+ chains = [c.strip() for c in chain_list.split(",") if c.strip()]
149
+ if entity_id in self.entities:
150
+ self.entities[entity_id] = chains
151
+
152
+ # Map non-polymer chains using struct_asym for entities not covered by entity_poly
153
+ if "struct_asym" in block:
154
+ asym_cat = block["struct_asym"]
155
+ asym_ids = asym_cat["id"].as_array(str)
156
+ entity_ids = asym_cat["entity_id"].as_array(str)
157
+
158
+ for asym_id, entity_id in zip(asym_ids, entity_ids):
159
+ entity_id = int(entity_id)
160
+ # Only add if entity exists but has no chains yet (non-polymer entities)
161
+ if entity_id in self.entities and not self.entities[entity_id]:
162
+ self.entities[entity_id].append(asym_id)
163
+
164
+ except Exception:
165
+ # If parsing fails, try to infer from structure
166
+ if (
167
+ self.structure
168
+ and hasattr(self.structure, "chain_id")
169
+ and self.structure.chain_id is not None
170
+ and hasattr(self.structure.chain_id, "__iter__")
171
+ ):
172
+ chain_ids = list(set(self.structure.chain_id))
173
+ self.entities = {1: chain_ids}
174
+
175
+ def _parse_sequences(self):
176
+ """Parse sequence information from mmCIF."""
177
+ if not self.raw:
178
+ return
179
+
180
+ block = self.raw.block
181
+
182
+ # Parse polymer sequences
183
+ if "entity_poly" in block:
184
+ poly_cat = block["entity_poly"]
185
+ entity_ids = poly_cat["entity_id"].as_array(str)
186
+ sequences = poly_cat["pdbx_seq_one_letter_code_can"].as_array(str)
187
+ chain_lists = poly_cat["pdbx_strand_id"].as_array(str)
188
+
189
+ for entity_id, sequence, chain_list in zip(
190
+ entity_ids, sequences, chain_lists
191
+ ):
192
+ # Clean up sequence (remove whitespace and newlines)
193
+ clean_seq = "".join(sequence.split())
194
+ chains = [c.strip() for c in chain_list.split(",") if c.strip()]
195
+
196
+ for chain_id in chains:
197
+ self.chain_to_seqres[chain_id] = clean_seq
198
+
199
+ # Parse sequence to structure mapping
200
+ if "pdbx_poly_seq_scheme" in block:
201
+ seq_cat = block["pdbx_poly_seq_scheme"]
202
+ asym_ids = seq_cat["asym_id"].as_array(str) # Internal chain IDs
203
+ seq_positions = seq_cat["seq_id"].as_array(str)
204
+ auth_seq_nums = seq_cat["auth_seq_num"].as_array(str)
205
+ ins_codes = (
206
+ seq_cat["pdb_ins_code"].as_array(str)
207
+ if "pdb_ins_code" in seq_cat
208
+ else [""] * len(asym_ids)
209
+ )
210
+ hetflags = (
211
+ seq_cat["hetflag"].as_array(str)
212
+ if "hetflag" in seq_cat
213
+ else ["N"] * len(asym_ids)
214
+ )
215
+
216
+ # Get author chain IDs if available
217
+ auth_chain_ids = (
218
+ seq_cat["pdb_strand_id"].as_array(str)
219
+ if "pdb_strand_id" in seq_cat
220
+ else asym_ids # Fallback to internal IDs
221
+ )
222
+
223
+ # Build mapping from internal chain ID to author chain ID
224
+ asym_to_auth_mapping = {}
225
+ for asym_id, auth_id in zip(asym_ids, auth_chain_ids):
226
+ asym_to_auth_mapping[asym_id] = auth_id
227
+
228
+ # Group by internal chain ID first, then map to author chain ID
229
+ chain_data = {}
230
+ for asym_id, seq_pos, auth_seq, ins_code, hetflag in zip(
231
+ asym_ids, seq_positions, auth_seq_nums, ins_codes, hetflags
232
+ ):
233
+ if asym_id not in chain_data:
234
+ chain_data[asym_id] = {}
235
+
236
+ try:
237
+ seq_index = int(seq_pos) - 1 # Convert to 0-based indexing
238
+ res_num = int(auth_seq) if auth_seq != "?" else None
239
+ except ValueError:
240
+ continue
241
+
242
+ if res_num is not None:
243
+ # Convert mmCIF "." and "?" to empty string
244
+ clean_ins_code = "" if ins_code in [".", "?"] else ins_code
245
+ else:
246
+ clean_ins_code = ""
247
+ res_num = None
248
+
249
+ is_het = hetflag.upper() == "Y" # type: ignore
250
+ chain_data[asym_id][seq_index] = Residue(
251
+ residue_number=res_num,
252
+ insertion_code=clean_ins_code, # type: ignore
253
+ hetflag=is_het,
254
+ )
255
+
256
+ # Handle cases where multiple residues have the same auth_seq_num
257
+ # by adjusting residue numbers to be unique within each chain
258
+ for asym_id, residue_data in chain_data.items():
259
+ # Check if there are duplicate residue numbers in this chain
260
+ positions_with_same_num = {}
261
+ for seq_idx, res_at_pos in residue_data.items():
262
+ if res_at_pos.residue_number is not None:
263
+ res_num = res_at_pos.residue_number
264
+ if res_num not in positions_with_same_num:
265
+ positions_with_same_num[res_num] = []
266
+ positions_with_same_num[res_num].append(seq_idx)
267
+
268
+ # Fix duplicate residue numbers by making them sequential
269
+ for res_num, seq_indices in positions_with_same_num.items():
270
+ if len(seq_indices) > 1:
271
+ # Multiple residues have the same residue number
272
+ # Make them sequential starting from the original number
273
+ seq_indices.sort() # Ensure consistent ordering
274
+ for i, seq_idx in enumerate(seq_indices):
275
+ original_pos = residue_data[seq_idx]
276
+ new_pos = Residue(
277
+ residue_number=res_num + i,
278
+ insertion_code=original_pos.insertion_code,
279
+ hetflag=original_pos.hetflag,
280
+ )
281
+ residue_data[seq_idx] = new_pos
282
+
283
+ # Create ordered mappings using author chain IDs
284
+ for asym_id in chain_data:
285
+ auth_chain_id = asym_to_auth_mapping.get(asym_id, asym_id)
286
+ if auth_chain_id in self.chain_to_seqres:
287
+ seq_len = len(self.chain_to_seqres[auth_chain_id])
288
+ ordered_mapping = {}
289
+
290
+ for i in range(seq_len):
291
+ if i in chain_data[asym_id]:
292
+ ordered_mapping[i] = chain_data[asym_id][i]
293
+ else:
294
+ # Missing residue - no structure coordinates
295
+ ordered_mapping[i] = Residue(
296
+ residue_number=None, insertion_code="", hetflag=False
297
+ )
298
+
299
+ self.seqres_to_structure[auth_chain_id] = ordered_mapping
300
+ else:
301
+ # Handle case where auth_chain_id is not in chain_to_seqres
302
+ # This can happen if the chain is not a polymer or if there's a parsing issue
303
+ # Create a basic mapping based on the chain_data
304
+ if chain_data[asym_id]:
305
+ # Sort by sequence index to create ordered mapping
306
+ sorted_indices = sorted(chain_data[asym_id].keys())
307
+ ordered_mapping = {}
308
+ for i, seq_idx in enumerate(sorted_indices):
309
+ ordered_mapping[i] = chain_data[asym_id][seq_idx]
310
+ self.seqres_to_structure[auth_chain_id] = ordered_mapping
311
+
312
+ # Ensure all chains have complete mappings
313
+ for chain_id in self.chain_to_seqres:
314
+ if chain_id not in self.seqres_to_structure:
315
+ seq_len = len(self.chain_to_seqres[chain_id])
316
+ self.seqres_to_structure[chain_id] = {
317
+ i: Residue(residue_number=None, insertion_code="", hetflag=False)
318
+ for i in range(seq_len)
319
+ }
320
+ else:
321
+ # Fill in any missing indices
322
+ seq_len = len(self.chain_to_seqres[chain_id])
323
+ mapping = self.seqres_to_structure[chain_id]
324
+ for i in range(seq_len):
325
+ if i not in mapping:
326
+ mapping[i] = Residue(
327
+ residue_number=None, insertion_code="", hetflag=False
328
+ )
329
+
330
+ # Fallback: create basic mappings from structure for missing chains
331
+ if (
332
+ self.structure
333
+ and hasattr(self.structure, "chain_id")
334
+ and self.structure.chain_id is not None
335
+ and hasattr(self.structure.chain_id, "__iter__")
336
+ ):
337
+ for chain_id in set(self.structure.chain_id):
338
+ if chain_id not in self.seqres_to_structure:
339
+ chain_structure = self.structure[
340
+ self.structure.chain_id == chain_id
341
+ ]
342
+ if (
343
+ hasattr(chain_structure, "res_id")
344
+ and chain_structure.res_id is not None
345
+ and hasattr(chain_structure.res_id, "__iter__")
346
+ ):
347
+ residue_ids = list(set(chain_structure.res_id))
348
+ residue_ids.sort()
349
+
350
+ self.seqres_to_structure[chain_id] = {
351
+ i: Residue(
352
+ residue_number=res_id, insertion_code="", hetflag=False
353
+ )
354
+ for i, res_id in enumerate(residue_ids)
355
+ }
356
+
357
+ def _parse_nonpoly_from_mmcif(self) -> dict[tuple, bs.AtomArray]:
358
+ """Parse non-polymer coordinates from mmCIF block data."""
359
+ nonpoly_coords = {}
360
+
361
+ # Get non-polymer entities from the mmCIF block
362
+ assert self.raw is not None
363
+ block = self.raw.block
364
+ nonpoly_entities = set()
365
+
366
+ # Find non-polymer entities
367
+ if "entity" in block:
368
+ entity_cat = block["entity"]
369
+ entity_ids = entity_cat["id"].as_array(str)
370
+ entity_types = entity_cat["type"].as_array(str)
371
+
372
+ for entity_id, entity_type in zip(entity_ids, entity_types):
373
+ if entity_type.upper() in ["NON-POLYMER", "WATER", "BRANCHED"]:
374
+ nonpoly_entities.add(entity_id)
375
+
376
+ # Map non-polymer entity IDs to their component IDs (comp_id)
377
+ entity_to_chains = {}
378
+ if "pdbx_entity_nonpoly" in block:
379
+ nonpoly_cat = block["pdbx_entity_nonpoly"]
380
+ entity_ids = nonpoly_cat["entity_id"].as_array(str)
381
+ comp_ids = nonpoly_cat["comp_id"].as_array(str)
382
+
383
+ for entity_id, comp_id in zip(entity_ids, comp_ids):
384
+ if entity_id in nonpoly_entities:
385
+ entity_to_chains[entity_id] = comp_id
386
+
387
+ # Get atom site information for non-polymers
388
+ if "atom_site" in block:
389
+ atom_cat = block["atom_site"]
390
+ atom_chain_ids = atom_cat["label_asym_id"].as_array(str)
391
+ atom_entity_ids = atom_cat["label_entity_id"].as_array(str)
392
+ atom_comp_ids = atom_cat["label_comp_id"].as_array(str)
393
+
394
+ # Group non-polymer atoms by entity and chain
395
+ nonpoly_atom_groups = {}
396
+ for i, (chain_id, entity_id, comp_id) in enumerate(
397
+ zip(atom_chain_ids, atom_entity_ids, atom_comp_ids)
398
+ ):
399
+ if entity_id in nonpoly_entities:
400
+ key = (comp_id, chain_id)
401
+ if key not in nonpoly_atom_groups:
402
+ nonpoly_atom_groups[key] = []
403
+ nonpoly_atom_groups[key].append(i)
404
+
405
+ # Extract coordinates for each non-polymer group
406
+ for (comp_id, chain_id), atom_indices in nonpoly_atom_groups.items():
407
+ # Match atoms by comparing chain_id and residue name
408
+ structure_mask = (self.structure.chain_id == chain_id) & (
409
+ self.structure.res_name == comp_id
410
+ )
411
+
412
+ if structure_mask.any():
413
+ nonpoly_array = self.structure[structure_mask]
414
+ if (
415
+ isinstance(nonpoly_array, (bs.AtomArray, bs.AtomArrayStack))
416
+ and len(nonpoly_array) > 0
417
+ ):
418
+ nonpoly_coords[(comp_id, chain_id)] = nonpoly_array
419
+
420
+ return nonpoly_coords
421
+
422
+ def _parse_nonpoly_fallback(self) -> dict[tuple, bs.AtomArray]:
423
+ """Fallback method to extract heteroatoms directly from structure."""
424
+ nonpoly_coords = {}
425
+
426
+ if not (self.structure and hasattr(self.structure, "chain_id")):
427
+ return nonpoly_coords
428
+
429
+ # Create set of standard residues from residue_constants
430
+ standard_residues = set(residue_constants.resnames[:-1]) # Exclude 'UNK'
431
+ standard_residues.update({"A", "C", "G", "T", "U"}) # Add nucleic acids
432
+
433
+ if hasattr(self.structure, "chain_id") and self.structure.chain_id is not None:
434
+ for chain_id in set(self.structure.chain_id):
435
+ chain_structure = self.structure[self.structure.chain_id == chain_id]
436
+
437
+ # Find non-standard residues
438
+ if (
439
+ hasattr(chain_structure, "res_name")
440
+ and chain_structure.res_name is not None
441
+ and hasattr(chain_structure.res_name, "__iter__")
442
+ ):
443
+ for res_name in set(chain_structure.res_name):
444
+ if res_name not in standard_residues:
445
+ res_mask = (chain_structure.chain_id == chain_id) & (
446
+ chain_structure.res_name == res_name
447
+ )
448
+ if res_mask.any() and isinstance(
449
+ chain_structure, (bs.AtomArray, bs.AtomArrayStack)
450
+ ):
451
+ nonpoly_array = chain_structure[res_mask]
452
+ nonpoly_coords[(res_name, chain_id)] = nonpoly_array
453
+
454
+ return nonpoly_coords
455
+
456
+ @functools.cached_property
457
+ def non_polymer_coords(self) -> dict[tuple, bs.AtomArray]:
458
+ """
459
+ Extract non-polymer coordinates (ligands, cofactors, etc.) from mmCIF structure.
460
+
461
+ Returns a dictionary mapping (comp_id, chain_id) tuples to AtomArrays.
462
+ """
463
+ if not self.structure or not self.raw:
464
+ return {}
465
+
466
+ try:
467
+ return self._parse_nonpoly_from_mmcif()
468
+ except Exception:
469
+ return self._parse_nonpoly_fallback()
470
+
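The sequence-mapping logic in `_parse_sequences` above can be sketched in isolation. This is a minimal illustration, not the repository's implementation: the `Residue` dataclass here is a hypothetical stand-in for the type used in the diff, and `split_strand_ids` and `fill_missing` are names invented for this sketch. It shows how a comma-separated `pdbx_strand_id` value is split into chain IDs, and how unobserved sequence positions receive a placeholder with `residue_number=None`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Residue:
    # Hypothetical stand-in for the Residue type referenced in the diff.
    residue_number: int | None
    insertion_code: str = ""
    hetflag: bool = False


def split_strand_ids(chain_list: str) -> list[str]:
    """Split a comma-separated pdbx_strand_id value, dropping empty entries."""
    return [c.strip() for c in chain_list.split(",") if c.strip()]


def fill_missing(observed: dict[int, Residue], seq_len: int) -> dict[int, Residue]:
    """Cover indices 0..seq_len-1, using a placeholder where no residue was observed."""
    placeholder = Residue(residue_number=None)
    return {i: observed.get(i, placeholder) for i in range(seq_len)}


chains = split_strand_ids("A, B,,C")
observed = {0: Residue(10), 2: Residue(12, "A")}
mapping = fill_missing(observed, seq_len=4)

print(chains)
print([mapping[i].residue_number for i in range(4)])
```

Keeping a placeholder at every sequence index means downstream code can iterate over `range(seq_len)` without checking for missing keys, which appears to be the intent of the "Ensure all chains have complete mappings" step.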
esmfold2_molecular_complex.py ADDED
@@ -0,0 +1,1226 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import os
5
+ import re
6
+ from dataclasses import asdict, dataclass
7
+ from pathlib import Path
8
+ from subprocess import check_output
9
+ from tempfile import TemporaryDirectory
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ import biotite.structure as bs
13
+ import biotite.structure.io.pdbx as pdbx
14
+ import brotli
15
+ import msgpack
16
+ import numpy as np
17
+ import torch
18
+ from biotite.structure.io.pdbx import (
19
+ CIFCategory,
20
+ CIFColumn,
21
+ CIFData,
22
+ CIFFile,
23
+ set_structure,
24
+ )
25
+
26
+ from . import esmfold2_residue_constants
27
+ from .esmfold2_metrics import compute_lddt, compute_rmsd
28
+ from .esmfold2_protein_complex import ProteinComplex, ProteinComplexMetadata
29
+
30
+
31
+ @dataclass
32
+ class MolecularComplexResult:
33
+ """Result of molecular complex folding"""
34
+
35
+ complex: MolecularComplex
36
+ plddt: torch.Tensor | None = None
37
+ ptm: float | None = None
38
+ iptm: float | None = None
39
+ pae: torch.Tensor | None = None
40
+ distogram: torch.Tensor | None = None
41
+ pair_chains_iptm: torch.Tensor | None = None
42
+ output_embedding_sequence: torch.Tensor | None = None
43
+ output_embedding_pair_pooled: torch.Tensor | None = None
44
+ residue_index: torch.Tensor | None = None
45
+ entity_id: torch.Tensor | None = None
46
+ sae_features: np.ndarray | None = None # [L, n_features]
47
+
48
+
49
+ @dataclass
50
+ class MolecularComplexMetadata:
51
+ """Metadata for MolecularComplex objects."""
52
+
53
+ entity_lookup: dict[int, str]
54
+ chain_lookup: dict[int, str]
55
+ assembly_composition: dict[str, list[str]] | None = None
56
+
57
+
58
+ @dataclass
59
+ class Molecule:
60
+ """Represents a single molecule/token within a MolecularComplex."""
61
+
62
+ token: str
63
+ token_idx: int
64
+ atom_positions: np.ndarray # [N_atoms, 3]
65
+ atom_elements: np.ndarray # [N_atoms] element strings
66
+ atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
67
+ atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
68
+ residue_type: int = 0
69
+ molecule_type: int = 0 # PROTEIN=0, RNA=1, DNA=2, LIGAND=3
70
+ confidence: float = 0.0
71
+
72
+
73
+ @dataclass(frozen=True)
74
+ class MolecularComplex:
75
+ """
76
+ Dataclass representing a molecular complex with support for proteins, nucleic acids, and ligands.
77
+
78
+ Uses a flat atom representation with token-based sequence indexing, supporting all atom types
79
+ beyond the traditional atom37 protein representation.
80
+ """
81
+
82
+ id: str
83
+ sequence: list[str] # Token sequence like ['MET', 'LYS', 'A', 'G', 'ATP']
84
+
85
+ # Flat atom arrays - simplified representation
86
+ atom_positions: np.ndarray # [N_atoms, 3] 3D coordinates
87
+ atom_elements: np.ndarray # [N_atoms] element strings
88
+
89
+ # Token-to-atom mapping for efficient access
90
+ token_to_atoms: np.ndarray # [N_tokens, 2] start/end indices into atoms array
91
+
92
+ # Chain information
93
+ chain_id: np.ndarray # [N_tokens] chain identifier for each token
94
+
95
+ # Confidence data
96
+ plddt: np.ndarray # Per-token confidence scores [N_tokens]
97
+
98
+ # Metadata
99
+ metadata: MolecularComplexMetadata
100
+
101
+ # Optional atom names and hetero flags (preserved from original structures)
102
+ atom_names: np.ndarray | None = None # [N_atoms] atom names (optional)
103
+ atom_hetero: np.ndarray | None = None # [N_atoms] hetero flags (optional)
104
+
105
+ def __post_init__(self):
106
+ """Validate array dimensions."""
107
+ n_tokens = len(self.sequence)
108
+ n_atoms = len(self.atom_positions)
109
+ assert (
110
+ self.token_to_atoms.shape[0] == n_tokens
111
+ ), f"token_to_atoms shape {self.token_to_atoms.shape} != {n_tokens} tokens"
112
+ assert (
113
+ self.chain_id.shape[0] == n_tokens
114
+ ), f"chain_id shape {self.chain_id.shape} != {n_tokens} tokens"
115
+ assert (
116
+ self.plddt.shape[0] == n_tokens
117
+ ), f"plddt shape {self.plddt.shape} != {n_tokens} tokens"
118
+ if self.atom_names is not None:
119
+ assert (
120
+ self.atom_names.shape[0] == n_atoms
121
+ ), f"atom_names shape {self.atom_names.shape} != {n_atoms} atoms"
122
+ if self.atom_hetero is not None:
123
+ assert (
124
+ self.atom_hetero.shape[0] == n_atoms
125
+ ), f"atom_hetero shape {self.atom_hetero.shape} != {n_atoms} atoms"
126
+
127
+ def __len__(self) -> int:
128
+ """Return number of tokens."""
129
+ return len(self.sequence)
130
+
131
+ def __getitem__(self, idx: int) -> Molecule:
132
+ """Access individual molecules/tokens by index."""
133
+ if idx >= len(self.sequence) or idx < 0:
134
+ raise IndexError(
135
+ f"Token index {idx} out of range for {len(self.sequence)} tokens"
136
+ )
137
+
138
+ token = self.sequence[idx]
139
+ start_atom, end_atom = self.token_to_atoms[idx]
140
+
141
+ # Extract atom data for this token
142
+ token_atom_positions = self.atom_positions[start_atom:end_atom]
143
+ token_atom_elements = self.atom_elements[start_atom:end_atom]
144
+ token_atom_names = None
145
+ if self.atom_names is not None:
146
+ token_atom_names = self.atom_names[start_atom:end_atom]
147
+ token_atom_hetero = None
148
+ if self.atom_hetero is not None:
149
+ token_atom_hetero = self.atom_hetero[start_atom:end_atom]
150
+
151
+ # Default values for residue/molecule type (would be extended based on actual implementation)
152
+ residue_type = 0 # Default to standard residue
153
+ molecule_type = 0 # Default to protein
154
+
155
+ return Molecule(
156
+ token=token,
157
+ token_idx=idx,
158
+ atom_positions=token_atom_positions,
159
+ atom_elements=token_atom_elements,
160
+ atom_names=token_atom_names,
161
+ atom_hetero=token_atom_hetero,
162
+ residue_type=residue_type,
163
+ molecule_type=molecule_type,
164
+ confidence=self.plddt[idx],
165
+ )
166
+
167
+ @property
168
+ def atom_coordinates(self) -> np.ndarray:
169
+ """Get flat array of all atom coordinates [N_atoms, 3]."""
170
+ return self.atom_positions
171
+
172
+ # Conversion methods
173
+ @classmethod
174
+ def from_protein_complex(cls, pc: ProteinComplex) -> "MolecularComplex":
175
+ """Convert a ProteinComplex to MolecularComplex.
176
+
177
+ Args:
178
+ pc: ProteinComplex object with atom37 representation
179
+
180
+ Returns:
181
+ MolecularComplex with flat atom arrays and token-based indexing
182
+ """
183
+ from . import esmfold2_residue_constants as residue_constants
184
+
185
+ # Extract sequence without chain breaks
186
+ sequence_no_breaks = pc.sequence.replace("|", "")
187
+ sequence_tokens = [
188
+ residue_constants.restype_1to3.get(aa, "UNK") for aa in sequence_no_breaks
189
+ ]
190
+
191
+ # Convert atom37 to flat arrays
192
+ flat_positions = []
193
+ flat_elements = []
194
+ flat_names = []
195
+ flat_hetero = []
196
+ token_to_atoms = []
197
+
198
+ atom_idx = 0
199
+
200
+ for i, aa in enumerate(pc.sequence):
201
+ if aa == "|":
202
+ # Skip chain break tokens
203
+ continue
204
+
205
+ # Get atom37 positions and mask for this residue.
206
+ # ProteinComplex arrays are indexed by sequence position (including |),
207
+ # so use `i` not a separate residue counter.
208
+ res_positions = pc.atom37_positions[i] # [37, 3]
209
+ res_mask = pc.atom37_mask[i] # [37]
210
+
211
+ # Track start position for this token
212
+ token_start = atom_idx
213
+
214
+ # Process each atom type in atom37 representation
215
+ for atom_type_idx, atom_name in enumerate(residue_constants.atom_types):
216
+ if res_mask[atom_type_idx]: # Atom is present
217
+ # Add position
218
+ flat_positions.append(res_positions[atom_type_idx])
219
+
220
+ # Determine element from atom name
221
+ element = (
222
+ atom_name[0] if atom_name else "C"
223
+ ) # First character is element
224
+ flat_elements.append(element)
225
+
226
+ # Add atom name
227
+ flat_names.append(atom_name)
228
+
229
+ # Add hetero flag (all proteins are non-hetero)
230
+ flat_hetero.append(False)
231
+
232
+ atom_idx += 1
233
+
234
+ # Record token-to-atom mapping [start_idx, end_idx)
235
+ token_to_atoms.append([token_start, atom_idx])
236
+
237
+ # Convert to numpy arrays
238
+ atom_positions = np.array(flat_positions, dtype=np.float32)
239
+ atom_elements = np.array(flat_elements, dtype=object)
240
+ atom_names = np.array(flat_names, dtype=object)
241
+ atom_hetero = np.array(flat_hetero, dtype=bool)
242
+ token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
243
+
244
+ # Extract confidence scores and chain_ids (skip chain breaks)
245
+ confidence_scores = []
246
+ chain_ids = []
247
+ for seq_idx, aa in enumerate(pc.sequence):
248
+ if aa != "|":
249
+ confidence_scores.append(pc.confidence[seq_idx])
250
+ chain_ids.append(pc.chain_id[seq_idx])
251
+
252
+ confidence_array = np.array(confidence_scores, dtype=np.float32)
253
+ chain_id_array = np.array(chain_ids, dtype=np.int64)
254
+
255
+ # Create metadata - convert entity IDs to strings for MolecularComplexMetadata
256
+ entity_lookup_str = {k: str(v) for k, v in pc.metadata.entity_lookup.items()}
257
+ metadata = MolecularComplexMetadata(
258
+ entity_lookup=entity_lookup_str,
259
+ chain_lookup=pc.metadata.chain_lookup,
260
+ assembly_composition=pc.metadata.assembly_composition,
261
+ )
262
+
263
+ return cls(
264
+ id=pc.id,
265
+ sequence=sequence_tokens,
266
+ atom_positions=atom_positions,
267
+ atom_elements=atom_elements,
268
+ token_to_atoms=token_to_atoms_array,
269
+ chain_id=chain_id_array,
270
+ plddt=confidence_array,
271
+ metadata=metadata,
272
+ atom_names=atom_names,
273
+ atom_hetero=atom_hetero,
274
+ )
275
+
276
+ def to_protein_complex(self) -> ProteinComplex:
277
+ """Convert MolecularComplex back to ProteinComplex format.
278
+
279
+ Extracts only protein tokens and converts from flat atom representation
280
+ back to atom37 format used by ProteinComplex.
281
+
282
+ Returns:
283
+ ProteinComplex with protein residues only, excluding ligands/nucleic acids
284
+ """
285
+ from . import esmfold2_residue_constants as residue_constants
286
+
287
+ # No need for element mapping - already using element characters
288
+
289
+ # Filter for protein tokens only (skip ligands, nucleic acids)
290
+ protein_tokens = []
291
+ protein_indices = []
292
+
293
+ for i, token in enumerate(self.sequence):
294
+ # Check if token is a standard 3-letter amino acid code
295
+ if token in residue_constants.restype_3to1:
296
+ protein_tokens.append(token)
297
+ protein_indices.append(i)
298
+
299
+ if not protein_tokens:
300
+ raise ValueError("No protein tokens found in MolecularComplex")
301
+
302
+ n_residues = len(protein_tokens)
303
+
304
+ # Initialize atom37 arrays
305
+ atom37_positions = np.full((n_residues, 37, 3), np.nan, dtype=np.float32)
306
+ atom37_mask = np.zeros((n_residues, 37), dtype=bool)
307
+
308
+ # Extract confidence scores and chain_ids for protein residues only
309
+ protein_confidence = self.plddt[protein_indices]
310
+ protein_chain_ids = self.chain_id[protein_indices]
311
+
312
+ # Convert tokens back to single-letter sequence with chain breaks
313
+ single_letter_residues = []
314
+ prev_chain_id = None
315
+
316
+ for i, (token, chain_id_val) in enumerate(
317
+ zip(protein_tokens, protein_chain_ids)
318
+ ):
319
+ # Add chain break if we're switching to a new chain
320
+ if prev_chain_id is not None and chain_id_val != prev_chain_id:
321
+ single_letter_residues.append("|")
322
+ single_letter_residues.append(residue_constants.restype_3to1[token])
323
+ prev_chain_id = chain_id_val
324
+
325
+ single_letter_sequence = "".join(single_letter_residues)
326
+
327
+ # Calculate final sequence length (includes chain breaks)
328
+ sequence_length = len(single_letter_sequence)
329
+
330
+ # Convert flat atoms back to atom37 representation using atom names
331
+ for res_idx, token_idx in enumerate(protein_indices):
332
+ token = self.sequence[token_idx]
333
+ start_atom, end_atom = self.token_to_atoms[token_idx]
334
+
335
+ res_atom_positions = self.atom_positions[start_atom:end_atom]
336
+ res_atom_names = (
337
+ np.array(self.atom_names[start_atom:end_atom], dtype=str)
338
+ if self.atom_names is not None
339
+ else np.array([], dtype=str)
340
+ )
341
+
342
+ # Build a mapping from normalized atom name -> position for this residue
343
+ # Normalize to uppercase and strip whitespace for robust matching
344
+ name_to_pos: dict[str, np.ndarray] = {}
345
+ for i, nm in enumerate(res_atom_names):
346
+ key = nm.upper().strip()
347
+ # Prefer first occurrence; ignore duplicates/altlocs
348
+ if key not in name_to_pos:
349
+ name_to_pos[key] = res_atom_positions[i]
350
+
351
+ # Place atoms into atom37 by matching stored atom names to atom37 indices.
352
+ # This handles all atoms present in the flat representation, not just
353
+ # the canonical residue_atoms for this residue type. This preserves
354
+ # atoms that were in the original atom37_mask even if they're atypical
355
+ # for the residue (e.g., from alternate conformations or data quirks).
356
+ for atom_name_str, pos in name_to_pos.items():
357
+ idx37 = residue_constants.atom_order.get(atom_name_str)
358
+ if idx37 is not None:
359
+ atom37_positions[res_idx, idx37] = pos
360
+ atom37_mask[res_idx, idx37] = True
361
+
362
+ # Create arrays that match sequence length (including chain breaks)
363
+ # Initialize arrays with proper size
364
+ chain_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
365
+ entity_id_expanded = np.full(sequence_length, -1, dtype=np.int64)
366
+ sym_id_expanded = np.zeros(sequence_length, dtype=np.int64)
367
+ residue_index_expanded = np.zeros(sequence_length, dtype=np.int64)
368
+ insertion_code_expanded = np.array([""] * sequence_length, dtype=object)
369
+ confidence_expanded = np.zeros(sequence_length, dtype=np.float32)
370
+ atom37_positions_expanded = np.full(
371
+ (sequence_length, 37, 3), np.nan, dtype=np.float32
372
+ )
373
+ atom37_mask_expanded = np.zeros((sequence_length, 37), dtype=bool)
374
+
375
+ # Map residue data to sequence positions (skipping chain breaks)
376
+ residue_idx = 0
377
+ residue_counter_per_chain = {}
378
+
379
+ for seq_pos, char in enumerate(single_letter_sequence):
380
+ if char != "|":
381
+ # This is a residue position
382
+ chain_id_val = protein_chain_ids[residue_idx]
383
+
384
+ chain_id_expanded[seq_pos] = chain_id_val
385
+ entity_id_expanded[seq_pos] = chain_id_val # Simplified mapping
386
+
387
+ # Track residue numbering per chain
388
+ if chain_id_val not in residue_counter_per_chain:
389
+ residue_counter_per_chain[chain_id_val] = 1
390
+ else:
391
+ residue_counter_per_chain[chain_id_val] += 1
392
+
393
+ residue_index_expanded[seq_pos] = residue_counter_per_chain[
394
+ chain_id_val
395
+ ]
396
+ confidence_expanded[seq_pos] = protein_confidence[residue_idx]
397
+ atom37_positions_expanded[seq_pos] = atom37_positions[residue_idx]
398
+ atom37_mask_expanded[seq_pos] = atom37_mask[residue_idx]
399
+
400
+ residue_idx += 1
401
+ # Chain break positions keep default values (-1, False, etc.)
402
+
403
+ # Use the expanded arrays
404
+ chain_id = chain_id_expanded
405
+ entity_id = entity_id_expanded
406
+ sym_id = sym_id_expanded
407
+ residue_index = residue_index_expanded
408
+ insertion_code = insertion_code_expanded
409
+ protein_confidence = confidence_expanded
410
+ atom37_positions = atom37_positions_expanded
411
+ atom37_mask = atom37_mask_expanded
412
+
413
+ # Create protein complex metadata preserving chain information
414
+ # Convert MolecularComplex metadata to ProteinComplex format
415
+ unique_chain_ids = np.unique(protein_chain_ids)
416
+ entity_lookup = {int(cid): int(cid) for cid in unique_chain_ids}
417
+ chain_lookup = {
418
+ int(cid): self.metadata.chain_lookup.get(int(cid), chr(65 + int(cid)))
419
+ for cid in unique_chain_ids
420
+ }
421
+
422
+ protein_metadata = ProteinComplexMetadata(
423
+ entity_lookup=entity_lookup,
424
+ chain_lookup=chain_lookup,
425
+ assembly_composition=self.metadata.assembly_composition,
426
+ )
427
+
428
+ return ProteinComplex(
429
+ id=self.id,
430
+ sequence=single_letter_sequence,
431
+ entity_id=entity_id,
432
+ chain_id=chain_id,
433
+ sym_id=sym_id,
434
+ residue_index=residue_index,
435
+ insertion_code=insertion_code,
436
+ atom37_positions=atom37_positions,
437
+ atom37_mask=atom37_mask,
438
+ confidence=protein_confidence,
439
+ metadata=protein_metadata,
440
+ )
441
+
442
+ @classmethod
443
+ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex":
444
+ """Read MolecularComplex from mmcif file or string.
445
+
446
+ Args:
447
+ inp: Path to mmCIF file or mmCIF content as string
448
+ id: Optional identifier to assign to the complex
449
+
450
+ Returns:
451
+ MolecularComplex with all molecules (proteins, ligands, nucleic acids)
452
+ """
453
+ from io import StringIO
454
+
455
+ # Check if input is a file path or mmCIF string content
456
+ if os.path.exists(inp):
457
+ # Input is a file path
458
+ mmcif_file = pdbx.CIFFile.read(inp)
459
+ else:
460
+ # Input is mmCIF string content
461
+ mmcif_file = pdbx.CIFFile.read(StringIO(inp))
462
+
463
+ # Get structure - handle missing model information gracefully
464
+ try:
465
+ structure = pdbx.get_structure(
466
+ mmcif_file, model=1, extra_fields=["b_factor"]
467
+ )
468
+ except (KeyError, ValueError):
469
+ # Fallback for mmCIF files without model information
470
+ try:
471
+ structure = pdbx.get_structure(mmcif_file)
472
+ except Exception:
473
+ # Last resort: use the first available model or all atoms
474
+ structure = pdbx.get_structure(mmcif_file, model=None)
475
+ # Type hint for pyright - structure is an AtomArray which is iterable
476
+ if TYPE_CHECKING:
477
+ structure: Any = structure
478
+
479
+ # Read label_asym_id from the raw CIF atom_site category.
480
+ # Biotite's atom.chain_id uses auth_asym_id, which collapses ligands
481
+ # onto their parent protein chain. label_asym_id gives each entity a
482
+ # distinct chain identifier.
483
+ block = mmcif_file.block
484
+ label_asym_ids: list[str] | None = None
485
+ if "atom_site" in block:
486
+ atom_site = block["atom_site"]
487
+ if "label_asym_id" in atom_site:
488
+ _col = atom_site["label_asym_id"]
489
+ _raw = (
490
+ _col.as_array(str)
491
+ if hasattr(_col, "as_array")
492
+ else np.array(list(_col), dtype=str) # type: ignore[arg-type]
493
+ )
494
+ # biotite's get_structure(model=1) filters to model 1 AND
495
+ # removes alternate conformations. We must apply the same
496
+ # filters to label_asym_id to keep arrays aligned.
497
+ keep = np.ones(len(_raw), dtype=bool)
498
+ if "pdbx_PDB_model_num" in atom_site:
499
+ _mc = atom_site["pdbx_PDB_model_num"]
500
+ _models = (
501
+ _mc.as_array(str)
502
+ if hasattr(_mc, "as_array")
503
+ else np.array(list(_mc), dtype=str) # type: ignore[arg-type]
504
+ )
505
+ keep &= _models == "1"
506
+ if "label_alt_id" in atom_site:
507
+ _ac = atom_site["label_alt_id"]
508
+ _alts = (
509
+ _ac.as_array(str)
510
+ if hasattr(_ac, "as_array")
511
+ else np.array(list(_ac), dtype=str) # type: ignore[arg-type]
512
+ )
513
+ keep &= np.isin(_alts, [".", "?", "", "A"])
514
+ filtered = _raw[keep]
515
+ if len(filtered) == len(structure):
516
+ label_asym_ids = filtered.tolist()
517
+ # If lengths still don't match, fall back to atom.chain_id
518
+
519
+ # Get entity information from mmCIF
520
+ entity_info = {}
521
+ try:
522
+ if "entity" in block:
523
+ entity_category = block["entity"]
524
+ if "id" in entity_category and "type" in entity_category:
525
+ entity_ids = entity_category["id"]
526
+ entity_types = entity_category["type"]
527
+ # Convert CIFColumn to list for iteration
528
+ if hasattr(entity_ids, "__iter__") and hasattr(
529
+ entity_types, "__iter__"
530
+ ):
531
+ # Type annotation to help pyright understand these are iterable
532
+ entity_ids_list = list(entity_ids) # type: ignore
533
+ entity_types_list = list(entity_types) # type: ignore
534
+ for eid, etype in zip(entity_ids_list, entity_types_list):
535
+ entity_info[eid] = etype
536
+ except Exception:
537
+ pass
538
+
539
+ # Initialize arrays for flat atom representation
540
+ sequence_tokens = []
541
+ flat_positions = []
542
+ flat_elements = []
543
+ flat_names = []
544
+ flat_hetero = []
545
+ token_to_atoms = []
546
+ confidence_scores = []
547
+ chain_ids = [] # Track chain IDs for each token
548
+
549
+ atom_idx = 0
550
+
551
+ # Group atoms by chain and residue.
552
+ # Use label_asym_id (distinct per entity) when available, otherwise
553
+ # fall back to biotite's chain_id (auth_asym_id).
554
+ chain_residue_groups: dict[str, dict[tuple[int, str], dict]] = {}
555
+ for atom_i, atom in enumerate(structure):
556
+ chain_id = (
557
+ label_asym_ids[atom_i] if label_asym_ids is not None else atom.chain_id
558
+ )
559
+ res_id = atom.res_id
560
+ res_name = atom.res_name
561
+
562
+ if chain_id not in chain_residue_groups:
563
+ chain_residue_groups[chain_id] = {}
564
+ # Key by (res_id, res_name) to distinguish residues that share
565
+ # the same res_id but have different res_name (e.g. a protein
566
+ # residue and a ligand that were on the same auth chain).
567
+ res_key = (res_id, res_name)
568
+ if res_key not in chain_residue_groups[chain_id]:
569
+ chain_residue_groups[chain_id][res_key] = {
570
+ "atoms": [],
571
+ "res_name": res_name,
572
+ "is_hetero": atom.hetero,
573
+ }
574
+ chain_residue_groups[chain_id][res_key]["atoms"].append(atom)
575
+
576
+ # Create a mapping from chain_id to numeric indices
577
+ chain_id_to_numeric = {
578
+ chain_id: idx
579
+ for idx, chain_id in enumerate(sorted(chain_residue_groups.keys()))
580
+ }
581
+
582
+ # Process each chain and residue
583
+ for chain_id in sorted(chain_residue_groups.keys()):
584
+ residues = chain_residue_groups[chain_id]
585
+ numeric_chain_id = chain_id_to_numeric[chain_id]
586
+
587
+ for res_key in sorted(residues.keys()):
588
+ residue_data = residues[res_key]
589
+ res_name = residue_data["res_name"]
590
+ atoms = residue_data["atoms"]
591
+ is_hetero = residue_data["is_hetero"]
592
+
593
+ # Skip water molecules
594
+ if res_name == "HOH":
595
+ continue
596
+
597
+ # Determine token name
598
+ if not is_hetero and res_name in residue_constants.restype_3to1:
599
+ # Standard amino acid
600
+ token_name = res_name
601
+ elif res_name in ["A", "T", "G", "C", "U", "DA", "DT", "DG", "DC"]:
602
+ # Nucleotide
603
+ token_name = res_name
604
+ else:
605
+ # Ligand or other molecule
606
+ token_name = res_name
607
+
608
+ sequence_tokens.append(token_name)
609
+ chain_ids.append(
610
+ numeric_chain_id
611
+ ) # Store the numeric chain ID for this token
612
+ token_start = atom_idx
613
+
614
+ # Add all atoms from this residue
615
+ for atom in atoms:
616
+ flat_positions.append(atom.coord)
617
+
618
+ # Get element character
619
+ element = atom.element
620
+ flat_elements.append(element)
621
+
622
+ # Get atom name
623
+ atom_name = atom.atom_name
624
+ flat_names.append(atom_name)
625
+
626
+ # Get hetero flag
627
+ hetero_flag = atom.hetero
628
+ flat_hetero.append(hetero_flag)
629
+
630
+ atom_idx += 1
631
+
632
+ # Record token-to-atom mapping
633
+ token_to_atoms.append([token_start, atom_idx])
634
+
635
+ # Add confidence score (first atom's B-factor / 100, capped at 1.0; 0.5 when no B-factor is present)
636
+ bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0
637
+ confidence_scores.append(min(bfactor / 100.0, 1.0))
638
+
639
+ # Convert to numpy arrays
640
+ if not flat_positions:
641
+ # Create minimal arrays if no atoms found
642
+ atom_positions = np.zeros((0, 3), dtype=np.float32)
643
+ atom_elements = np.zeros(0, dtype=object)
644
+ atom_names = np.zeros(0, dtype=object)
645
+ atom_hetero = np.zeros(0, dtype=bool)
646
+ token_to_atoms_array = np.zeros((len(sequence_tokens), 2), dtype=np.int32)
647
+ chain_id_array = (
648
+ np.array(chain_ids, dtype=np.int64)
649
+ if chain_ids
650
+ else np.zeros(len(sequence_tokens), dtype=np.int64)
651
+ )
652
+ else:
653
+ atom_positions = np.array(flat_positions, dtype=np.float32)
654
+ atom_elements = np.array(flat_elements, dtype=object)
655
+ atom_names = np.array(flat_names, dtype=object)
656
+ atom_hetero = np.array(flat_hetero, dtype=bool)
657
+ token_to_atoms_array = np.array(token_to_atoms, dtype=np.int32)
658
+ chain_id_array = np.array(chain_ids, dtype=np.int64)
659
+
660
+ confidence_array = np.array(confidence_scores, dtype=np.float32)
661
+
662
+ # Create metadata using the chain_id_to_numeric mapping
663
+ if chain_residue_groups:
664
+ chain_lookup = {
665
+ numeric_id: chain_id
666
+ for chain_id, numeric_id in chain_id_to_numeric.items()
667
+ }
668
+ else:
669
+ chain_lookup = {}
670
+
671
+ metadata = MolecularComplexMetadata(
672
+ entity_lookup=entity_info,
673
+ chain_lookup=chain_lookup,
674
+ assembly_composition=None,
675
+ )
676
+
677
+ # Set complex ID - if input was a path, use the stem; otherwise use default
678
+ if os.path.exists(inp):
679
+ complex_id = id or Path(inp).stem
680
+ else:
681
+ complex_id = id or "complex_from_string"
682
+
683
+ return cls(
684
+ id=complex_id,
685
+ sequence=sequence_tokens,
686
+ atom_positions=atom_positions,
687
+ atom_elements=atom_elements,
688
+ token_to_atoms=token_to_atoms_array,
689
+ chain_id=chain_id_array,
690
+ plddt=confidence_array,
691
+ metadata=metadata,
692
+ atom_names=atom_names,
693
+ atom_hetero=atom_hetero,
694
+ )
695
+
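The flat layout that `from_mmcif` builds can be illustrated with a minimal, standard-library-only sketch. It assumes a hypothetical `residues` list of `(residue name, coordinates)` pairs instead of a biotite structure, and the `flatten` helper is illustrative rather than part of this module. Each kept residue becomes one token, and `token_to_atoms` stores a half-open `[start, end)` range into the flat atom list.

```python
# Hypothetical per-residue input: (residue name, list of atom coordinates).
residues = [
    ("ALA", [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]),
    ("HOH", [(9.0, 9.0, 9.0)]),  # water, skipped as in from_mmcif
    ("ATP", [(2.0, 1.0, 0.0), (2.5, 1.5, 0.0), (3.0, 2.0, 0.5)]),
]


def flatten(residues):
    """Build token names, a flat atom list, and [start, end) atom ranges."""
    sequence, positions, token_to_atoms = [], [], []
    atom_idx = 0
    for res_name, coords in residues:
        if res_name == "HOH":
            continue  # waters contribute neither a token nor atoms
        start = atom_idx
        positions.extend(coords)
        atom_idx += len(coords)
        sequence.append(res_name)
        token_to_atoms.append((start, atom_idx))
    return sequence, positions, token_to_atoms


sequence, positions, token_to_atoms = flatten(residues)
```

With this layout, the atoms of token `i` are recovered as `positions[start:end]` for `start, end = token_to_atoms[i]`, which is the slicing pattern the methods in this class rely on.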
696
+ def _get_entity_mapping(
697
+ self,
698
+ ) -> tuple[dict[str, list[str]], dict[str, int], dict[int, tuple[str, ...]]]:
699
+ """Compute chain→sequence, chain→entity_id, and entity_id→sequence mappings.
700
+
701
+ Returns:
702
+ (chain_sequences, chain_to_entity, entity_sequences)
703
+ """
704
+ chain_sequences: dict[str, list[str]] = {}
705
+ for token_idx in range(len(self.token_to_atoms)):
706
+ chain_id_numeric = self.chain_id[token_idx]
707
+ chain_id_str = self.metadata.chain_lookup.get(
708
+ int(chain_id_numeric), chr(65 + int(chain_id_numeric))
709
+ )
710
+ if chain_id_str not in chain_sequences:
711
+ chain_sequences[chain_id_str] = []
712
+ chain_sequences[chain_id_str].append(self.sequence[token_idx])
713
+
714
+ sequence_to_entity: dict[tuple[str, ...], int] = {}
715
+ chain_to_entity: dict[str, int] = {}
716
+ entity_sequences: dict[int, tuple[str, ...]] = {}
717
+ entity_id_counter = 1
718
+ for chain_id_str, sequence in chain_sequences.items():
719
+ seq_tuple = tuple(sequence)
720
+ if seq_tuple not in sequence_to_entity:
721
+ sequence_to_entity[seq_tuple] = entity_id_counter
722
+ entity_sequences[entity_id_counter] = seq_tuple
723
+ entity_id_counter += 1
724
+ chain_to_entity[chain_id_str] = sequence_to_entity[seq_tuple]
725
+
726
+ return chain_sequences, chain_to_entity, entity_sequences
727
+
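The entity assignment in `_get_entity_mapping` amounts to grouping chains by identical token sequences. A minimal sketch of that idea, using plain dictionaries and a hypothetical `assign_entities` helper that is not part of this module:

```python
def assign_entities(chain_sequences):
    """Give chains with identical token sequences the same 1-based entity ID."""
    sequence_to_entity = {}
    chain_to_entity = {}
    for chain_id, seq in chain_sequences.items():
        key = tuple(seq)  # lists are unhashable, so use a tuple as the key
        if key not in sequence_to_entity:
            sequence_to_entity[key] = len(sequence_to_entity) + 1
        chain_to_entity[chain_id] = sequence_to_entity[key]
    return chain_to_entity


# Hypothetical input: two identical protein chains and one ligand chain.
chains = {"A": ["MET", "GLY"], "B": ["MET", "GLY"], "C": ["ATP"]}
chain_to_entity = assign_entities(chains)
```

Entity IDs follow the order in which distinct sequences are first seen, so the result depends on dictionary insertion order, as it does in the method above.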
728
+ def _add_entity_information(
729
+ self, cif_file: CIFFile, entity_sequences: dict[int, tuple[str, ...]]
730
+ ) -> None:
731
+ """Add _entity category to CIF file so OST can identify ligands vs polymers."""
732
+
733
+ entity_ids: list[str] = []
734
+ entity_types: list[str] = []
735
+ entity_descriptions: list[str] = []
736
+ for eid in sorted(entity_sequences.keys()):
737
+ seq = entity_sequences[eid]
738
+ entity_ids.append(str(eid))
739
+ has_protein = any(t in residue_constants.restype_3to1 for t in seq)
740
+ has_na = any(
741
+ t in ("A", "T", "G", "C", "U", "DA", "DT", "DG", "DC") for t in seq
742
+ )
743
+ if has_protein or has_na:
744
+ entity_types.append("polymer")
745
+ if has_protein:
746
+ entity_descriptions.append(f"Polymer entity {eid} (protein)")
747
+ else:
748
+ entity_descriptions.append(f"Polymer entity {eid} (nucleic acid)")
749
+ else:
750
+ entity_types.append("non-polymer")
751
+ entity_descriptions.append(f"Non-polymer entity {eid}")
752
+
753
+ if entity_ids:
754
+ cif_file.block["entity"] = CIFCategory(
755
+ name="entity",
756
+ columns={
757
+ "id": CIFColumn(
758
+ data=CIFData(array=np.array(entity_ids), dtype=np.str_)
759
+ ),
760
+ "type": CIFColumn(
761
+ data=CIFData(array=np.array(entity_types), dtype=np.str_)
762
+ ),
763
+ "pdbx_description": CIFColumn(
764
+ data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
765
+ ),
766
+ },
767
+ )
768
+
769
+ # Add _struct_asym to map chain IDs to entity IDs
770
+ _, chain_to_entity, _ = self._get_entity_mapping()
771
+ if chain_to_entity:
772
+ asym_ids = sorted(chain_to_entity.keys())
773
+ asym_entity_ids = [str(chain_to_entity[c]) for c in asym_ids]
774
+ cif_file.block["struct_asym"] = CIFCategory(
775
+ name="struct_asym",
776
+ columns={
777
+ "id": CIFColumn(
778
+ data=CIFData(array=np.array(asym_ids), dtype=np.str_)
779
+ ),
780
+ "entity_id": CIFColumn(
781
+ data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
782
+ ),
783
+ },
784
+ )
785
+
786
+ def to_mmcif(self) -> str:
787
+ """Write MolecularComplex to mmcif string using biotite.
788
+
789
+ Returns:
790
+ String representation of the complex in mmCIF format
791
+ """
792
+ # Pre-allocate AtomArray
793
+ n_atoms = len(self.atom_positions)
794
+ atom_array = bs.AtomArray(length=n_atoms)
795
+
796
+ # Set coordinates directly (already vectorized)
797
+ atom_array.coord = self.atom_positions
798
+
799
+ # Pre-allocate per-atom arrays
800
+ atom_res_ids = np.zeros(n_atoms, dtype=np.int32)
801
+ atom_chain_ids = np.empty(n_atoms, dtype=object)
802
+ atom_res_names = np.empty(n_atoms, dtype=object)
803
+ atom_hetero = np.zeros(n_atoms, dtype=bool)
804
+ atom_bfactors = np.zeros(n_atoms, dtype=np.float32)
805
+ atom_names = np.empty(n_atoms, dtype=object)
806
+
807
+ # Build entity mappings: chains with identical sequences share entity ID
808
+ _, chain_to_entity, entity_sequences = self._get_entity_mapping()
809
+
810
+ atom_entity_ids = np.zeros(n_atoms, dtype=np.int32)
811
+
812
+ # Track residue IDs per chain
813
+ chain_res_counters: dict[int, int] = {}
814
+
815
+ # Vectorized expansion of token-level to atom-level annotations
816
+ for token_idx, (start, end) in enumerate(self.token_to_atoms):
817
+ token = self.sequence[token_idx]
818
+ chain_id_numeric = self.chain_id[token_idx]
819
+ chain_id_str = self.metadata.chain_lookup.get(
820
+ int(chain_id_numeric), chr(65 + int(chain_id_numeric))
821
+ )
822
+
823
+ # Track residue numbering per chain
824
+ if chain_id_numeric not in chain_res_counters:
825
+ chain_res_counters[chain_id_numeric] = 1
826
+ res_id = chain_res_counters[chain_id_numeric]
827
+ chain_res_counters[chain_id_numeric] += 1
828
+
829
+ # Determine if protein
830
+ is_protein = token in residue_constants.restype_3to1
831
+
832
+ # Get atom names for this residue
833
+ if self.atom_names is not None:
834
+ # Use stored atom names (preserves original names from mmCIF)
835
+ names = list(self.atom_names[start:end])
836
+ elif is_protein:
837
+ # Fallback: use standard protein atom names
838
+ standard_names = residue_constants.residue_atoms.get(
839
+ token, ["N", "CA", "C", "O"]
840
+ )
841
+ names = standard_names[: end - start]
842
+ # Pad if needed
843
+ while len(names) < (end - start):
844
+ names.append(f"X{len(names)+1}")
845
+ else:
846
+ # Fallback: generate names for ligands/nucleic acids
847
+ names = [f"C{i+1}" for i in range(end - start)]
848
+
849
+ # Vectorized assignment for this token's atoms
850
+ atom_res_ids[start:end] = res_id
851
+ atom_chain_ids[start:end] = chain_id_str
852
+ atom_res_names[start:end] = token
853
+ # Use stored hetero flags if available, otherwise guess based on protein status
854
+ if self.atom_hetero is not None:
855
+ atom_hetero[start:end] = self.atom_hetero[start:end]
856
+ else:
857
+ atom_hetero[start:end] = not is_protein
858
+ atom_bfactors[start:end] = self.plddt[token_idx] * 100.0
859
+ atom_names[start:end] = names
860
+ atom_entity_ids[start:end] = chain_to_entity.get(chain_id_str, 1)
861
+
862
+ # Set all AtomArray attributes at once (convert object arrays to proper string arrays)
863
+ # res_name uses U8 to accommodate CCD codes up to 5 characters (e.g., A1AZ2);
864
+ # chain_id uses U16 because chain names like ``ligand_1`` / ``ligand_2`` /
865
+ # auth-asym ids of arbitrary length are possible.
866
+ atom_array.res_id = atom_res_ids
867
+ atom_array.chain_id = np.array(atom_chain_ids, dtype="U16")
868
+ atom_array.res_name = np.array(atom_res_names, dtype="U8")
869
+ atom_array.hetero = atom_hetero
870
+ atom_array.atom_name = np.array(atom_names, dtype="U4")
871
+ atom_array.add_annotation("b_factor", dtype=float)
872
+ atom_array.b_factor = atom_bfactors
873
+ atom_array.add_annotation("entity_id", dtype=int)
874
+ atom_array.entity_id = atom_entity_ids
875
+
876
+ # Use existing elements or infer them from atom names
877
+ if self.atom_elements is not None and len(self.atom_elements) == n_atoms:
878
+ # Convert object array to proper string array for biotite
879
+ atom_array.element = np.array(self.atom_elements, dtype="U4")
880
+ else:
881
+ # Use biotite's built-in element inference
882
+ atom_array.element = bs.infer_elements(atom_array)
883
+
884
+ # Create CIF file and set structure
885
+ cif_file = CIFFile()
886
+ set_structure(cif_file, atom_array, data_block=self.id)
887
+
888
+ # Manually fix label_entity_id (biotite doesn't use entity_id annotation correctly)
889
+ if "atom_site" in cif_file.block:
890
+ atom_site = cif_file.block["atom_site"]
891
+ if "label_asym_id" in atom_site and "label_entity_id" in atom_site:
892
+ label_asym_ids = atom_site["label_asym_id"]
893
+ if hasattr(label_asym_ids, "as_array"):
894
+ chain_ids_list = label_asym_ids.as_array(str).tolist()
895
+ elif hasattr(label_asym_ids, "__iter__"):
896
+ chain_ids_list = list(label_asym_ids) # type: ignore[arg-type]
897
+ else:
898
+ chain_ids_list = []
899
+ updated_entity_ids = [
900
+ str(chain_to_entity.get(cid, 1)) for cid in chain_ids_list
901
+ ]
902
+ if updated_entity_ids:
903
+ atom_site["label_entity_id"] = CIFColumn(
904
+ data=CIFData(array=np.array(updated_entity_ids), dtype=np.str_)
905
+ )
906
+
907
+ # Add _entity category for OST compatibility
908
+ self._add_entity_information(cif_file, entity_sequences)
909
+
910
+ # Convert to string
911
+ output = io.StringIO()
912
+ cif_file.write(output)
913
+ return output.getvalue()
914
+
915
+ def dockq(self, native: "MolecularComplex") -> Any:
916
+ """Compute DockQ score against native structure.
917
+
918
+ Args:
919
+ native: Native MolecularComplex to compute DockQ against
920
+
921
+ Returns:
922
+ DockQ result containing score and alignment information
923
+ """
924
+ # Imports moved to top of file
925
+
926
+ # Convert both complexes to ProteinComplex format for DockQ computation
927
+ # This extracts only the protein portion and converts to PDB format
928
+ try:
929
+ self_pc = self.to_protein_complex()
930
+ native_pc = native.to_protein_complex()
931
+ except ValueError as e:
932
+ raise ValueError(
933
+ f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
934
+ ) from e
935
+
936
+ # Normalize chain IDs for PDB compatibility
937
+ self_pc = self_pc.normalize_chain_ids_for_pdb()
938
+ native_pc = native_pc.normalize_chain_ids_for_pdb()
939
+
940
+ # Use the existing ProteinComplex.dockq() method
941
+ try:
942
+ dockq_result = self_pc.dockq(native_pc)
943
+ return dockq_result
944
+ except Exception:
945
+ # Fallback to manual DockQ computation if ProteinComplex.dockq() fails
946
+ return self._compute_dockq_manual(native)
947
+
948
+ def _compute_dockq_manual(self, native: "MolecularComplex") -> Any:
949
+ """Manual DockQ computation fallback."""
950
+ # Imports moved to top of file
951
+
952
+ # Convert both complexes to ProteinComplex format
953
+ try:
954
+ self_pc = self.to_protein_complex()
955
+ native_pc = native.to_protein_complex()
956
+ except ValueError as e:
957
+ raise ValueError(
958
+ f"Cannot convert MolecularComplex to ProteinComplex for DockQ: {e}"
959
+ ) from e
960
+
961
+ # Normalize chain IDs for PDB compatibility
962
+ self_pc = self_pc.normalize_chain_ids_for_pdb()
963
+ native_pc = native_pc.normalize_chain_ids_for_pdb()
964
+
965
+ # Write temporary PDB files and run DockQ
966
+ with TemporaryDirectory() as tdir:
967
+ dir_path = Path(tdir)
968
+ self_pdb = dir_path / "self.pdb"
969
+ native_pdb = dir_path / "native.pdb"
970
+
971
+ # Write PDB files
972
+ self_pc.to_pdb(self_pdb)
973
+ native_pc.to_pdb(native_pdb)
974
+
975
+ # Run DockQ
976
+ try:
977
+ output = check_output(["DockQ", str(self_pdb), str(native_pdb)])
978
+ output_text = output.decode()
979
+
980
+ # Parse DockQ output
981
+ lines = output_text.split("\n")
982
+
983
+ # Find the total DockQ score
984
+ dockq_score = None
985
+ for line in lines:
986
+ if "Total DockQ" in line:
987
+ match = re.search(r"Total DockQ.*: ([\d.]+)", line)
988
+ if match:
989
+ dockq_score = float(match.group(1))
990
+ break
991
+
992
+ if dockq_score is None:
993
+ # Try to find individual DockQ scores
994
+ for line in lines:
995
+ if line.startswith("DockQ") and ":" in line:
996
+ try:
997
+ dockq_score = float(line.split(":")[1].strip())
998
+ break
999
+ except (ValueError, IndexError):
1000
+ continue
1001
+
1002
+ if dockq_score is None:
1003
+ raise ValueError("Could not parse DockQ score from output")
1004
+
1005
+ # Return a simple result structure
1006
+ return {
1007
+ "total_dockq": dockq_score,
1008
+ "raw_output": output_text,
1009
+ "aligned": self, # Return self as aligned structure
1010
+ }
1011
+
1012
+ except FileNotFoundError:
1013
+ raise RuntimeError(
1014
+ "DockQ is not installed. Please install DockQ to use this method."
1015
+ )
1016
+ except Exception as e:
1017
+ raise RuntimeError(f"DockQ computation failed: {e}") from e
1018
+
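The two-stage score parsing in `_compute_dockq_manual` can be exercised in isolation. The sketch below mirrors that logic in a hypothetical `parse_dockq_score` helper. The sample strings are made up for illustration and are an assumption about the output format, not captured DockQ output.

```python
import re

TOTAL_RE = re.compile(r"Total DockQ.*: ([\d.]+)")


def parse_dockq_score(output_text):
    """Return the total DockQ score, falling back to a plain 'DockQ: x' line."""
    lines = output_text.split("\n")
    # First pass: look for a "Total DockQ ...: <score>" summary line.
    for line in lines:
        match = TOTAL_RE.search(line)
        if match:
            return float(match.group(1))
    # Second pass: accept the first line of the form "DockQ...: <score>".
    for line in lines:
        if line.startswith("DockQ") and ":" in line:
            try:
                return float(line.split(":")[1].strip())
            except (ValueError, IndexError):
                continue
    raise ValueError("Could not parse DockQ score from output")


# Illustrative inputs only; the real tool's output may differ.
SAMPLE_TOTAL = "header line\nTotal DockQ over 1 native interfaces: 0.853\n"
SAMPLE_SINGLE = "DockQ: 0.412\n"
```

Keeping the parsing in a pure function like this makes it testable without the `DockQ` executable installed.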
1019
+ def rmsd(self, target: "MolecularComplex", **kwargs) -> float:
1020
+ """Compute RMSD against target structure.
1021
+
1022
+ Args:
1023
+ target: Target MolecularComplex to compute RMSD against
1024
+ **kwargs: Additional arguments passed to compute_rmsd
1025
+
1026
+ Returns:
1027
+ float: RMSD value between the two structures
1028
+ """
1029
+ # Imports moved to top of file
1030
+
1031
+ # Ensure both complexes have the same number of tokens
1032
+ if len(self) != len(target):
1033
+ raise ValueError(
1034
+ f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
1035
+ )
1036
+
1037
+ # Extract center positions for each token (using centroid of atoms)
1038
+ mobile_coords = []
1039
+ target_coords = []
1040
+ atom_mask = []
1041
+
1042
+ for i in range(len(self)):
1043
+ # Get atom positions for this token
1044
+ mobile_start, mobile_end = self.token_to_atoms[i]
1045
+ target_start, target_end = target.token_to_atoms[i]
1046
+
1047
+ # Extract atom positions
1048
+ mobile_atoms = self.atom_positions[mobile_start:mobile_end]
1049
+ target_atoms = target.atom_positions[target_start:target_end]
1050
+
1051
+ # Check if both tokens have atoms
1052
+ if len(mobile_atoms) == 0 or len(target_atoms) == 0:
1053
+ # Skip tokens with no atoms
1054
+ continue
1055
+
1056
+ # For simplicity, use the centroid of atoms as the representative position
1057
+ mobile_center = mobile_atoms.mean(axis=0)
1058
+ target_center = target_atoms.mean(axis=0)
1059
+
1060
+ mobile_coords.append(mobile_center)
1061
+ target_coords.append(target_center)
1062
+ atom_mask.append(True)
1063
+
1064
+ if len(mobile_coords) == 0:
1065
+ raise ValueError("No valid atoms found for RMSD computation")
1066
+
1067
+ # Convert to tensors
1068
+ mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
1069
+ 0
1070
+ ) # [1, N, 3]
1071
+ target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
1072
+ 0
1073
+ ) # [1, N, 3]
1074
+ mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
1075
+
1076
+ # Compute RMSD using existing infrastructure
1077
+ rmsd_value = compute_rmsd(
1078
+ mobile=mobile_tensor,
1079
+ target=target_tensor,
1080
+ atom_exists_mask=mask_tensor,
1081
+ reduction="batch",
1082
+ **kwargs,
1083
+ )
1084
+
1085
+ return float(rmsd_value)
1086
+
1087
+ def lddt_ca(self, target: "MolecularComplex", **kwargs) -> float:
1088
+ """Compute LDDT on per-token atom centroids against target structure.
1089
+
1090
+ Args:
1091
+ target: Target MolecularComplex to compute LDDT against
1092
+ **kwargs: Additional arguments passed to compute_lddt
1093
+
1094
+ Returns:
1095
+ float: LDDT value between the two structures
1096
+ """
1097
+ # Imports moved to top of file
1098
+
1099
+ # Ensure both complexes have the same number of tokens
1100
+ if len(self) != len(target):
1101
+ raise ValueError(
1102
+ f"Complexes must have the same number of tokens: {len(self)} vs {len(target)}"
1103
+ )
1104
+
1105
+ # Extract center positions for each token (using centroid of atoms)
1106
+ mobile_coords = []
1107
+ target_coords = []
1108
+ atom_mask = []
1109
+
1110
+ for i in range(len(self)):
1111
+ # Get atom positions for this token
1112
+ mobile_start, mobile_end = self.token_to_atoms[i]
1113
+ target_start, target_end = target.token_to_atoms[i]
1114
+
1115
+ # Extract atom positions
1116
+ mobile_atoms = self.atom_positions[mobile_start:mobile_end]
1117
+ target_atoms = target.atom_positions[target_start:target_end]
1118
+
1119
+ # Check if both tokens have atoms
1120
+ if len(mobile_atoms) == 0 or len(target_atoms) == 0:
1121
+ # Mask out tokens with no atoms; NaN placeholders keep token indices aligned
1122
+ mobile_coords.append(np.full(3, np.nan))
1123
+ target_coords.append(np.full(3, np.nan))
1124
+ atom_mask.append(False)
1125
+ continue
1126
+
1127
+ # For simplicity, use the centroid of atoms as the representative position
1128
+ mobile_center = mobile_atoms.mean(axis=0)
1129
+ target_center = target_atoms.mean(axis=0)
1130
+
1131
+ mobile_coords.append(mobile_center)
1132
+ target_coords.append(target_center)
1133
+ atom_mask.append(True)
1134
+
1135
+ if not any(atom_mask):
1136
+ raise ValueError("No valid atoms found for LDDT computation")
1137
+
1138
+ # Convert to tensors
1139
+ mobile_tensor = torch.from_numpy(np.stack(mobile_coords, axis=0)).unsqueeze(
1140
+ 0
1141
+ ) # [1, N, 3]
1142
+ target_tensor = torch.from_numpy(np.stack(target_coords, axis=0)).unsqueeze(
1143
+ 0
1144
+ ) # [1, N, 3]
1145
+ mask_tensor = torch.tensor(atom_mask, dtype=torch.bool).unsqueeze(0) # [1, N]
1146
+
1147
+ # Compute LDDT using existing infrastructure
1148
+ lddt_value = compute_lddt(
1149
+ all_atom_pred_pos=mobile_tensor,
1150
+ all_atom_positions=target_tensor,
1151
+ all_atom_mask=mask_tensor,
1152
+ per_residue=False, # Return overall LDDT score
1153
+ **kwargs,
1154
+ )
1155
+
1156
+ return float(lddt_value)
1157
+
1158
+ def state_dict(self):
1159
+ """This state dict is optimized for storage, so it turns things to fp16 whenever
1160
+ possible and converts numpy arrays to lists for JSON serialization.
1161
+ """
1162
+ dct = {k: v for k, v in vars(self).items()}
1163
+ for k, v in dct.items():
1164
+ if isinstance(v, np.ndarray):
1165
+ match v.dtype:
1166
+ case np.int64:
1167
+ dct[k] = v.astype(np.int32).tolist()
1168
+ case np.float64 | np.float32:
1169
+ dct[k] = v.astype(np.float16).tolist()
1170
+ case _:
1171
+ dct[k] = v.tolist()
1172
+ elif isinstance(v, MolecularComplexMetadata):
1173
+ dct[k] = asdict(v)
1174
+
1175
+ return dct
1176
+
1177
+ def to_blob(self) -> bytes:
1178
+ return brotli.compress(msgpack.dumps(self.state_dict()), quality=5)
1179
+
1180
+ @classmethod
1181
+ def from_state_dict(cls, dct):
1182
+ for k, v in dct.items():
1183
+ if isinstance(v, list) and k in [
1184
+ "atom_positions",
1185
+ "atom_elements",
1186
+ "atom_names",
1187
+ "atom_hetero",
1188
+ "token_to_atoms",
1189
+ "chain_id",
1190
+ "plddt",
1191
+ ]:
1192
+ dct[k] = np.array(v)
1193
+
1194
+ for k, v in dct.items():
1195
+ if isinstance(v, np.ndarray):
1196
+ if k in ["atom_positions", "plddt"]:
1197
+ dct[k] = v.astype(np.float32)
1198
+ elif k in ["token_to_atoms", "chain_id"]:
1199
+ dct[k] = (
1200
+ v.astype(np.int32)
1201
+ if k == "token_to_atoms"
1202
+ else v.astype(np.int64)
1203
+ )
1204
+
1205
+ dct["metadata"] = MolecularComplexMetadata(**dct["metadata"])
1206
+
1207
+ # Backward compatibility: if chain_id is missing, create default array
1208
+ if "chain_id" not in dct:
1209
+ # Default all tokens to chain 0
1210
+ dct["chain_id"] = np.zeros(len(dct["sequence"]), dtype=np.int64)
1211
+
1212
+ return cls(**dct)
1213
+
1214
+ @classmethod
1215
+ def from_blob(cls, input: Path | str | io.BytesIO | bytes):
1216
+ match input:
1217
+ case Path() | str():
1218
+ data = Path(input).read_bytes()
1219
+ case io.BytesIO():
1220
+ data = input.getvalue()
1221
+ case _:
1222
+ data = input
1223
+ return cls.from_state_dict(
1224
+ msgpack.loads(brotli.decompress(data), strict_map_key=False)
1225
+ )
1226
+
esmfold2_msa.py ADDED
@@ -0,0 +1,507 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import string
5
+ from dataclasses import dataclass
6
+ from functools import cached_property
7
+ from itertools import islice
8
+ from typing import Sequence
9
+
10
+ import numpy as np
11
+ from Bio import SeqIO
12
+ from scipy.spatial.distance import cdist
13
+
14
+ from .esmfold2_misc import slice_any_object
15
+ from .esmfold2_msa_filter_sequences import greedy_select_indices, hhfilter
16
+ from .esmfold2_parsing import FastaEntry, read_sequences, write_sequences
17
+ from .esmfold2_sequential_dataclass import SequentialDataclass
18
+ from .esmfold2_system import PathOrBuffer
19
+
20
+ REMOVE_LOWERCASE_TRANSLATION = str.maketrans(dict.fromkeys(string.ascii_lowercase))
21
+
22
+
23
+ def remove_insertions_from_sequence(seq: str) -> str:
24
+ return seq.translate(REMOVE_LOWERCASE_TRANSLATION)
25
+
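A short illustration of the translation-table approach used by `remove_insertions_from_sequence`: a table that maps every lowercase ASCII letter to `None` makes `str.translate` delete those characters. The example row is a made-up A3M-style sequence and the function name here is illustrative.

```python
import string

# Mapping each lowercase letter to None tells str.translate to delete it.
REMOVE_LOWERCASE = str.maketrans(dict.fromkeys(string.ascii_lowercase))


def remove_insertions(seq: str) -> str:
    """Drop lowercase (insertion) columns, keeping uppercase letters and gaps."""
    return seq.translate(REMOVE_LOWERCASE)


# Hypothetical aligned row: lowercase letters mark insertions relative to the query.
cleaned = remove_insertions("MKt-LVaaG")
```

Gap characters such as `-` are left untouched, so rows keep one character per match column after cleaning.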
26
+
27
+ @dataclass(frozen=True)
28
+ class MSA(SequentialDataclass):
29
+ """Object-oriented interface to an MSA.
30
+
31
+ Args:
32
+ entries (list[FastaEntry]): Header/sequence pairs, one per MSA row.
33
+ The ``sequences`` and ``headers`` properties are derived from ``entries``.
34
+
35
+ """
36
+
37
+ entries: list[FastaEntry]
38
+
39
+ @cached_property
40
+ def sequences(self) -> list[str]:
41
+ return [entry.sequence for entry in self.entries]
42
+
43
+ @cached_property
44
+ def headers(self) -> list[str]:
45
+ return [entry.header for entry in self.entries]
46
+
47
+ def __repr__(self):
48
+ return (
49
+ f"MSA({self.entries[0].header}: Depth={self.depth}, Length={self.seqlen})"
50
+ )
51
+
52
+ def to_fast_msa(self) -> FastMSA:
53
+ return FastMSA(self.array, self.headers)
54
+
55
+ @classmethod
56
+ def from_a3m(
57
+ cls,
58
+ path: PathOrBuffer,
59
+ remove_insertions: bool = True,
60
+ max_sequences: int | None = None,
61
+ ) -> MSA:
62
+ entries = []
63
+ for header, seq in islice(read_sequences(path), max_sequences):
64
+ if remove_insertions:
65
+ seq = remove_insertions_from_sequence(seq)
66
+ if entries:
67
+ assert (
68
+ len(seq) == len(entries[0].sequence)
69
+ ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
70
+ entries.append(FastaEntry(header, seq))
71
+ return cls(entries)
72
+
73
+ def to_a3m(self, path: PathOrBuffer) -> None:
74
+ write_sequences(self.entries, path)
75
+
76
+ @classmethod
77
+ def from_stockholm(
78
+ cls,
79
+ path: PathOrBuffer,
80
+ remove_insertions: bool = True,
81
+ max_sequences: int | None = None,
82
+ ) -> MSA:
83
+ entries = []
84
+ for record in islice(SeqIO.parse(path, "stockholm"), max_sequences):
85
+ header = f"{record.id} {record.description}"
86
+ seq = str(record.seq)
87
+ if entries:
88
+ assert (
89
+ len(seq) == len(entries[0].sequence)
90
+ ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}"
91
+ entries.append(FastaEntry(header, seq))
92
+ msa = cls(entries)
93
+ if remove_insertions:
94
+ keep_inds = [i for i, aa in enumerate(msa.query) if aa != "-"]
95
+ msa = msa.select_positions(keep_inds)
96
+ return msa
97
+
98
+ def to_bytes(self) -> bytes:
99
+ version = 1
100
+ version_bytes = version.to_bytes(1, "little")
101
+ seqlen_bytes = self.seqlen.to_bytes(4, "little")
102
+ depth_bytes = self.depth.to_bytes(4, "little")
103
+ array_bytes = self.array.tobytes()
104
+ header_bytes = "\n".join(entry.header for entry in self.entries).encode()
105
+ all_bytes = (
106
+ version_bytes + seqlen_bytes + depth_bytes + array_bytes + header_bytes
107
+ )
108
+ return all_bytes
109
+
110
+ @classmethod
111
+ def from_bytes(cls, data: bytes) -> MSA:
112
+ version_bytes, seqlen_bytes, depth_bytes, data = (
113
+ data[:1],
114
+ data[1:5],
115
+ data[5:9],
116
+ data[9:],
117
+ )
118
+ version = int.from_bytes(version_bytes, "little")
119
+ if version != 1:
120
+ raise ValueError(f"Unsupported version: {version}")
121
+ seqlen = int.from_bytes(seqlen_bytes, "little")
122
+ depth = int.from_bytes(depth_bytes, "little")
123
+ array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
124
+ array = np.frombuffer(array_bytes, dtype="|S1")
125
+ array = array.reshape(depth, seqlen)
126
+ headers = header_bytes.decode().split("\n")
127
+ # Sometimes the separation is two newlines, which results in an empty header.
128
+ headers = [header for header in headers if header]
129
+ # If all headers were empty (e.g., saved from from_sequences), use empty headers
130
+ if len(headers) == 0 and depth > 0:
131
+ headers = [""] * depth
132
+ entries = [
133
+ FastaEntry(header, b"".join(row).decode())
134
+ for header, row in zip(headers, array)
135
+ ]
136
+ return cls(entries)
137
+
138
+ # TODO(jmaccarl): set remove_insertions to True by default here to match other utils
139
+ @classmethod
140
+ def from_sequences(
141
+ cls, sequences: list[str], remove_insertions: bool = False
142
+ ) -> MSA:
143
+ if remove_insertions:
144
+ entries = [
145
+ FastaEntry("", remove_insertions_from_sequence(seq))
146
+ for seq in sequences
147
+ ]
148
+ else:
149
+ entries = [FastaEntry("", seq) for seq in sequences]
150
+ return cls(entries)
151
+
152
+ def to_sequence_bytes(self) -> bytes:
153
+ """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost."""
154
+ seqlen_bytes = self.seqlen.to_bytes(4, "little")
155
+ array_bytes = self.array.tobytes()
156
+ all_bytes = seqlen_bytes + array_bytes
157
+ return all_bytes
158
+
159
+ @classmethod
160
+ def from_sequence_bytes(cls, data: bytes) -> MSA:
161
+ seqlen_bytes, array_bytes = data[:4], data[4:]
162
+ seqlen = int.from_bytes(seqlen_bytes, "little")
163
+ array = np.frombuffer(array_bytes, dtype="|S1")
164
+ array = array.reshape(-1, seqlen)
165
+ entries = [FastaEntry("", b"".join(row).decode()) for row in array]
166
+ return cls(entries)
167
+
168
+ @property
169
+ def depth(self) -> int:
170
+ return len(self.entries)
171
+
172
+ @property
173
+ def seqlen(self) -> int:
174
+ return len(self.entries[0].sequence)
175
+
176
+ @cached_property
177
+ def array(self) -> np.ndarray:
178
+ return np.array([list(seq) for seq in self.sequences], dtype="|S1")
179
+
180
+ @property
181
+ def query(self) -> str:
182
+ return self.entries[0].sequence
183
+
184
+ def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA:
185
+ """Subselect rows of the MSA."""
186
+ entries = [self.entries[idx] for idx in indices]
187
+ return dataclasses.replace(self, entries=entries)
188
+
189
+ def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA:
190
+ """Subselect columns of the MSA."""
191
+ entries = [
192
+ FastaEntry(header, "".join(seq[idx] for idx in indices))
193
+ for header, seq in self.entries
194
+ ]
195
+ return dataclasses.replace(self, entries=entries)
196
+
197
+ def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
198
+ if isinstance(indices, int):
199
+ indices = [indices]
200
+
201
+ entries = [
202
+ FastaEntry(header, slice_any_object(seq, indices))
203
+ for header, seq in self.entries
204
+ ]
205
+ return dataclasses.replace(self, entries=entries)
206
+
207
+ def __len__(self):
208
+ return self.seqlen
209
+
210
+ def greedy_select(self, num_seqs: int, mode: str = "max") -> MSA:
211
+ """Greedily select sequences that either maximize or minimize Hamming distance.
212
+
213
+ Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
214
+ iteratively add sequences to the list with the maximum (minimum) average Hamming
215
+ distance to the existing set of sequences.
216
+
217
+ Args:
218
+ num_seqs (int): Number of sequences to select.
219
+ mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
220
+ you're doing it to prove a point for a paper.
221
+
222
+ Returns:
223
+ MSA object with subselected sequences.
224
+ """
225
+ assert mode in ("max", "min")
226
+ if self.depth <= num_seqs:
227
+ return self
228
+
229
+ indices = greedy_select_indices(self.array, num_seqs, mode)
230
+ return self.select_sequences(indices)
231
+
232
+ def hhfilter(
233
+ self,
234
+ seqid: int = 90,
235
+ diff: int = 0,
236
+ cov: int = 0,
237
+ qid: int = 0,
238
+ qsc: float = -20.0,
239
+ binary: str = "hhfilter",
240
+ ) -> MSA:
241
+ """Apply hhfilter to the sequences in the MSA and return a filtered MSA."""
242
+
243
+ indices = hhfilter(
244
+ self.sequences,
245
+ seqid=seqid,
246
+ diff=diff,
247
+ cov=cov,
248
+ qid=qid,
249
+ qsc=qsc,
250
+ binary=binary,
251
+ )
252
+ return self.select_sequences(indices)
253
+
254
+ def select_random_sequences(self, num_seqs: int) -> MSA:
255
+ """Uses random sampling to subselect sequences from the MSA. Always
256
+ keeps the query sequence.
257
+ """
258
+ if num_seqs >= self.depth:
259
+ return self
260
+
261
+ # Subselect random, always keeping the query sequence.
262
+ indices = np.sort(
263
+ np.append(
264
+ 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
265
+ )
266
+ )
267
+ msa = self.select_sequences(indices) # type: ignore
268
+ return msa
269
+
270
+ def select_diverse_sequences(self, num_seqs: int) -> MSA:
271
+ """Applies hhfilter to select ~num_seqs sequences, then uses random sampling
272
+ to subselect if necessary.
273
+ """
274
+ if num_seqs >= self.depth:
275
+ return self
276
+
277
+ msa = self.hhfilter(diff=num_seqs)
278
+ if num_seqs < msa.depth:
279
+ msa = msa.select_random_sequences(num_seqs)
280
+ return msa
281
+
282
+ def pad_to_depth(self, depth: int) -> MSA:
283
+ if depth < self.depth:
284
+ raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
285
+ elif depth == self.depth:
286
+ return self
287
+
288
+ num_to_add = depth - self.depth
289
+ extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)]
290
+ return dataclasses.replace(self, entries=self.entries + extra_entries)
291
+
292
+ @classmethod
293
+ def stack(
294
+ cls, msas: Sequence[MSA], remove_query_from_later_msas: bool = True
295
+ ) -> MSA:
296
+ """Stack a series of MSAs. Optionally remove the query from msas after the first."""
297
+ all_entries = []
298
+ for i, msa in enumerate(msas):
299
+ entries = msa.entries
300
+ if i > 0 and remove_query_from_later_msas:
301
+ entries = entries[1:]
302
+ all_entries.extend(entries)
303
+ return cls(entries=all_entries)
304
+
305
+ @cached_property
306
+ def seqid(self) -> np.ndarray:
307
+ array = self.array.view(np.uint8)
308
+ seqid = 1 - cdist(array[0][None], array, "hamming")
309
+ return seqid[0]
310
+
311
+ @classmethod
312
+ def concat(
313
+ cls,
314
+ msas: Sequence[MSA],
315
+ join_token: str | None = "|",
316
+ allow_depth_mismatch: bool = False,
317
+ ) -> MSA:
318
+ """Concatenate a series of MSAs horizontally, along the sequence dimension."""
319
+ if not msas:
320
+ raise ValueError("Cannot concatenate an empty list of MSAs")
321
+ msa_depths = [msa.depth for msa in msas]
322
+ if len(set(msa_depths)) != 1:
323
+ if not allow_depth_mismatch:
324
+ raise ValueError("Depth mismatch in concatenating MSAs")
325
+ else:
326
+ max_depth = max(msa_depths)
327
+ msas = [msa.pad_to_depth(max_depth) for msa in msas]
328
+ headers = [
329
+ "|".join([str(h) for h in headers])
330
+ for headers in zip(*(msa.headers for msa in msas))
331
+ ]
332
+
333
+ if join_token is None:
334
+ join_token = ""
335
+
336
+ seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))]
337
+ entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)]
338
+ return cls(entries)
339
+
340
+
341
+ @dataclass(frozen=True)
342
+ class FastMSA(SequentialDataclass):
343
+ """Object-oriented interface to an MSA stored as a numpy uint8 array."""
344
+
345
+ array: np.ndarray
346
+ headers: list[str] | None = None
347
+
348
+ def __post_init__(self):
349
+ if self.headers is not None:
350
+ assert (
351
+ len(self.headers) == self.depth
352
+ ), "Number of headers must match depth."
353
+
354
+ @classmethod
355
+ def from_bytes(cls, data: bytes) -> FastMSA:
356
+ version_bytes, seqlen_bytes, depth_bytes, data = (
357
+ data[:1],
358
+ data[1:5],
359
+ data[5:9],
360
+ data[9:],
361
+ )
362
+ version = int.from_bytes(version_bytes, "little")
363
+ if version != 1:
364
+ raise ValueError(f"Unsupported version: {version}")
365
+ seqlen = int.from_bytes(seqlen_bytes, "little")
366
+ depth = int.from_bytes(depth_bytes, "little")
367
+ array_bytes, header_bytes = data[: seqlen * depth], data[seqlen * depth :]
368
+ array = np.frombuffer(array_bytes, dtype="|S1")
369
+ array = array.reshape(depth, seqlen)
370
+ headers = header_bytes.decode().split("\n")
371
+ # Sometimes the separation is two newlines, which results in an empty header.
372
+ headers = [header for header in headers if header]
373
+ # If all headers were empty (e.g., saved from from_sequences), use empty headers
374
+ if len(headers) == 0 and depth > 0:
375
+ headers = [""] * depth
376
+ return cls(array, headers)
377
+
378
+ @classmethod
379
+ def from_sequence_bytes(cls, data: bytes) -> FastMSA:
380
+ seqlen_bytes, array_bytes = data[:4], data[4:]
381
+ seqlen = int.from_bytes(seqlen_bytes, "little")
382
+ array = np.frombuffer(array_bytes, dtype="|S1")
383
+ array = array.reshape(-1, seqlen)
384
+ return cls(array)
385
+
386
+ @property
387
+ def depth(self) -> int:
388
+ return self.array.shape[0]
389
+
390
+ @property
391
+ def seqlen(self) -> int:
392
+ return self.array.shape[1]
393
+
394
+ def __len__(self):
395
+ return self.seqlen
396
+
397
+ def __getitem__(self, indices: int | list[int] | slice | np.ndarray):
398
+ if isinstance(indices, int):
399
+ indices = [indices]
400
+
401
+ return dataclasses.replace(self, array=self.array[:, indices])
402
+
403
+ def select_sequences(self, indices: Sequence[int] | np.ndarray) -> FastMSA:
404
+ """Subselect rows of the MSA."""
405
+ array = self.array[indices]
406
+ headers = (
407
+ [self.headers[idx] for idx in indices] if self.headers is not None else None
408
+ )
409
+ return dataclasses.replace(self, array=array, headers=headers)
410
+
411
+ def select_random_sequences(self, num_seqs: int) -> FastMSA:
412
+ """Uses random sampling to subselect sequences from the MSA. Always
413
+ keeps the query sequence.
414
+ """
415
+ if num_seqs >= self.depth:
416
+ return self
417
+
418
+ # Subselect random, always keeping the query sequence.
419
+ indices = np.sort(
420
+ np.append(
421
+ 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1
422
+ )
423
+ )
424
+ msa = self.select_sequences(indices) # type: ignore
425
+ return msa
426
+
427
+ def pad_to_depth(self, depth: int) -> FastMSA:
428
+ if depth < self.depth:
429
+ raise ValueError(f"Cannot pad to depth {depth} when depth is {self.depth}")
430
+ elif depth == self.depth:
431
+ return self
432
+
433
+ num_to_add = depth - self.depth
434
+ array = np.pad(
435
+ self.array,
436
+ [(0, num_to_add), (0, 0)],
437
+ constant_values=ord("-") if self.array.dtype == np.uint8 else b"-",
438
+ )
439
+ headers = self.headers
440
+ if headers is not None:
441
+ headers = headers + [""] * num_to_add
442
+ return dataclasses.replace(self, array=array, headers=headers)
443
+
444
+ @classmethod
445
+ def concat(
446
+ cls,
447
+ msas: Sequence[FastMSA],
448
+ join_token: str | None = None,
449
+ allow_depth_mismatch: bool = False,
450
+ ) -> FastMSA:
451
+ """Concatenate a series of MSAs horizontally, along the sequence dimension."""
452
+ if not msas:
453
+ raise ValueError("Cannot concatenate an empty list of MSAs")
454
+ if join_token is not None and join_token != "":
455
+ raise NotImplementedError("join_token is not supported for FastMSA")
456
+
457
+ msa_depths = [msa.depth for msa in msas]
458
+ if len(set(msa_depths)) != 1:
459
+ if not allow_depth_mismatch:
460
+ raise ValueError("Depth mismatch in concatenating MSAs")
461
+ else:
462
+ max_depth = max(msa_depths)
463
+ msas = [msa.pad_to_depth(max_depth) for msa in msas]
464
+ headers = [
465
+ "|".join([str(h) for h in headers])
466
+ for headers in zip(
467
+ *(
468
+ msa.headers if msa.headers is not None else [""] * msa.depth
469
+ for msa in msas
470
+ )
471
+ )
472
+ ]
473
+
474
+ array = np.concatenate([msa.array for msa in msas], axis=1)
475
+ return cls(array, headers)
476
+
477
+ def to_msa(self) -> MSA:
478
+ headers = (
479
+ self.headers
480
+ if self.headers is not None
481
+ else [f"seq{i}" for i in range(self.depth)]
482
+ )
483
+ entries = [
484
+ FastaEntry(header, b"".join(row).decode())
485
+ for header, row in zip(headers, self.array)
486
+ ]
487
+ return MSA(entries)
488
+
489
+ @classmethod
490
+ def stack(
491
+ cls, msas: Sequence[FastMSA], remove_query_from_later_msas: bool = True
492
+ ) -> FastMSA:
493
+ """Stack a series of MSAs. Optionally remove the query from msas after the first."""
494
+ arrays = []
495
+ all_headers = []
496
+ for i, msa in enumerate(msas):
497
+ array = msa.array
498
+ headers = msa.headers
499
+ if i > 0 and remove_query_from_later_msas:
500
+ array = array[1:]
501
+ if headers is not None:
502
+ headers = headers[1:]
503
+ arrays.append(array)
504
+ if headers is not None:
505
+ all_headers.extend(headers)
506
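+ # Pass None when no input carried headers; an empty list would fail the depth check in __post_init__.
+ return cls(np.concatenate(arrays, axis=0), all_headers or None)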
507
+
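The `to_bytes` / `from_bytes` pair above uses a small fixed header: one version byte, a 4-byte little-endian sequence length, a 4-byte little-endian depth, then the row-major character array, then newline-joined headers. The following is a minimal standard-library sketch of that layout, not the module's own code. The names `pack_msa` and `unpack_msa` are hypothetical, and unlike `from_bytes` this sketch keeps empty headers in place so that rows and headers stay aligned.

```python
def pack_msa(sequences: list[str], headers: list[str]) -> bytes:
    """Serialize equal-length sequences with the version-1 layout (hypothetical helper)."""
    seqlen = len(sequences[0])
    if any(len(seq) != seqlen for seq in sequences):
        raise ValueError("all sequences must have the same length")
    body = "".join(sequences).encode("ascii")  # row-major, one byte per residue
    return (
        (1).to_bytes(1, "little")                # format version
        + seqlen.to_bytes(4, "little")           # number of columns
        + len(sequences).to_bytes(4, "little")   # number of rows (depth)
        + body
        + "\n".join(headers).encode()            # headers, newline separated
    )


def unpack_msa(data: bytes) -> tuple[list[str], list[str]]:
    """Inverse of pack_msa; assumes headers contain no newline characters."""
    version = data[0]
    if version != 1:
        raise ValueError(f"Unsupported version: {version}")
    seqlen = int.from_bytes(data[1:5], "little")
    depth = int.from_bytes(data[5:9], "little")
    end = 9 + seqlen * depth
    body, header_bytes = data[9:end], data[end:]
    sequences = [
        body[i * seqlen : (i + 1) * seqlen].decode("ascii") for i in range(depth)
    ]
    # Empty headers are kept (not filtered) so the result still has `depth` entries.
    headers = header_bytes.decode().split("\n")
    return sequences, headers
```

Keeping empty headers is a deliberate difference in this sketch: filtering them, as the original does, can leave fewer headers than rows when only some headers are empty.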
esmfold2_msa_filter_sequences.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from scipy.spatial.distance import cdist
7
+
8
+ from .esmfold2_system import run_subprocess_with_errorcheck
9
+
10
+
11
+ def greedy_select_indices(array, num_seqs: int, mode: str = "max") -> list[int]:
12
+ """Greedily select sequences that either maximize or minimize Hamming distance.
13
+
14
+ Algorithm proposed in the MSA Transformer paper. Starting from the query sequence,
15
+ iteratively add sequences to the list with the maximum (minimum) average Hamming
16
+ distance to the existing set of sequences.
17
+
18
+ Args:
19
+ array (np.ndarray): Character array representing the sequences in the MSA
20
+ num_seqs (int): Number of sequences to select.
21
+ mode (str): Whether to maximize or minimize diversity. DO NOT pick 'min' unless
22
+ you're doing it to prove a point for a paper.
23
+
24
+ Returns:
25
+ list[int]: List of indices to select from the array
26
+ """
27
+ assert mode in ("max", "min")
28
+ depth = array.shape[0]
29
+ if depth <= num_seqs:
30
+ return list(range(depth))
31
+ array = array.view(np.uint8)
32
+
33
+ optfunc = np.argmax if mode == "max" else np.argmin
34
+ all_indices = np.arange(depth)
35
+ indices = [0]
36
+ pairwise_distances = np.zeros((0, depth))
37
+ for _ in range(num_seqs - 1):
38
+ dist = cdist(array[indices[-1:]], array, "hamming")
39
+ pairwise_distances = np.concatenate([pairwise_distances, dist])
40
+ shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
41
+ shifted_index = optfunc(shifted_distance)
42
+ index = np.delete(all_indices, indices)[shifted_index]
43
+ indices.append(int(index))
44
+ indices = sorted(indices)
45
+ return indices
46
+
47
+
48
+ def hhfilter(
49
+ sequences: list[str],
50
+ seqid: int = 90,
51
+ diff: int = 0,
52
+ cov: int = 0,
53
+ qid: int = 0,
54
+ qsc: float = -20.0,
55
+ binary: str = "hhfilter",
56
+ ) -> list[int]:
57
+ with tempfile.TemporaryDirectory(
58
+ dir="/dev/shm" if os.path.exists("/dev/shm") else None
59
+ ) as tempdirname:
60
+ tempdir = Path(tempdirname)
61
+ fasta_file = tempdir / "input.fasta"
62
+ fasta_file.write_text(
63
+ "\n".join(f">{i}\n{seq}" for i, seq in enumerate(sequences))
64
+ )
65
+ output_file = tempdir / "output.fasta"
66
+ command = " ".join(
67
+ [
68
+ f"{binary}",
69
+ f"-i {fasta_file}",
70
+ "-M a3m",
71
+ f"-o {output_file}",
72
+ f"-id {seqid}",
73
+ f"-diff {diff}",
74
+ f"-cov {cov}",
75
+ f"-qid {qid}",
76
+ f"-qsc {qsc}",
77
+ ]
78
+ ).split(" ")
79
+ run_subprocess_with_errorcheck(command, capture_output=True)
80
+ with output_file.open() as f:
81
+ indices = [int(line[1:].strip()) for line in f if line.startswith(">")]
82
+ return indices
83
+
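The greedy selection in `greedy_select_indices` can be hard to follow through the `np.delete` bookkeeping. Below is a dependency-free sketch of the same idea, assuming equal-length sequences given as plain strings. The names `hamming` and `greedy_select` are hypothetical and this is not the module's implementation. It keeps a running sum of distances to the selected set, which ranks candidates the same way as the mean because every candidate is divided by the same count.

```python
def hamming(a: str, b: str) -> float:
    """Fraction of positions at which two equal-length strings differ."""
    return sum(x != y for x, y in zip(a, b)) / len(a)


def greedy_select(sequences: list[str], num_seqs: int, mode: str = "max") -> list[int]:
    """Pick row indices greedily, starting from the query at index 0."""
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    depth = len(sequences)
    if depth <= num_seqs:
        return list(range(depth))

    pick = max if mode == "max" else min
    selected = [0]
    total = [0.0] * depth  # summed distance from each row to the selected set
    for _ in range(num_seqs - 1):
        last = sequences[selected[-1]]
        for j in range(depth):
            total[j] += hamming(last, sequences[j])
        candidates = [j for j in range(depth) if j not in selected]
        selected.append(pick(candidates, key=lambda j: total[j]))
    return sorted(selected)
```

For example, with rows `["AAAA", "AAAT", "TTTT", "AATT"]` and `num_seqs=2`, the `"max"` mode keeps the query plus `"TTTT"`, while `"min"` keeps the query plus `"AAAT"`.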
esmfold2_normalize_coordinates.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import TypeVar
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ from . import esmfold2_residue_constants as RC
8
+ from .esmfold2_affine3d import Affine3D
9
+
10
+ ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
11
+
12
+
13
+ def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
14
+ N, CA, C = bb_positions.unbind(dim=-2)
15
+ return Affine3D.from_graham_schmidt(C, CA, N)
16
+
17
+
18
+ def index_by_atom_name(
19
+ atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
20
+ ) -> ArrayOrTensor:
21
+ squeeze = False
22
+ if isinstance(atom_names, str):
23
+ atom_names = [atom_names]
24
+ squeeze = True
25
+ indices = [RC.atom_order[atom_name] for atom_name in atom_names]
26
+ dim = dim % atom37.ndim
27
+ index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
28
+ result = atom37[index] # type: ignore
29
+ if squeeze:
30
+ result = result.squeeze(dim)
31
+ return result
32
+
33
+
34
+ def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
35
+ """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates.
36
+ Specifically, we compute the average position of the N, CA, and C atoms and use those 3 points to construct a frame
37
+ using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.
38
+
39
+ Args:
40
+ coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
41
+
42
+ Returns:
43
+ Affine3D: the normalization frame
44
+ """
45
+ bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)
46
+ coord_mask = torch.all(torch.all(torch.isfinite(bb_coords), dim=-1), dim=-1)
47
+
48
+ average_position_per_n_ca_c = bb_coords.masked_fill(
49
+ ~coord_mask[..., None, None], 0
50
+ ).sum(-3) / (coord_mask.sum(-1)[..., None, None] + 1e-8)
51
+ frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float())
52
+
53
+ return frame
54
+
55
+
56
+ def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
57
+ """Given a set of coordinates and a single frame, apply the frame to the coordinates.
58
+
59
+ Args:
60
+ coords (torch.FloatTensor): [L, 37, 3] tensor of coordinates
61
+ frame (Affine3D): Affine3D frame
62
+
63
+ Returns:
64
+ torch.FloatTensor: [L, 37, 3] tensor of transformed coordinates
65
+ """
66
+ coords_trans_rot = frame[..., None, None].invert().apply(coords)
67
+
68
+ # Only transform coordinates whose frame is valid; validity is tested via a non-zero translation.
69
+ valid_frame = frame.trans.norm(dim=-1) > 0
70
+
71
+ is_inf = torch.isinf(coords)
72
+ coords = coords_trans_rot.where(valid_frame[..., None, None, None], coords)
73
+ coords.masked_fill_(is_inf, torch.inf)
74
+
75
+ return coords
76
+
77
+
78
+ def normalize_coordinates(coords: Tensor) -> Tensor:
79
+ return apply_frame_to_coords(coords, get_protein_normalization_frame(coords))
80
+
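`get_protein_normalization_frame` builds a single frame from the mean N, CA and C positions using Gram-Schmidt, with the mean CA as origin. The sketch below illustrates that construction with plain tuples. The axis convention is an assumption (x along CA to C, y from the orthogonalized CA to N direction, z as their cross product); the real `Affine3D.from_graham_schmidt` may order or orient its axes differently. `frame_from_three_points` and `to_local` are hypothetical names.

```python
import math

Vec = tuple[float, float, float]


def sub(a: Vec, b: Vec) -> Vec:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def scale(a: Vec, s: float) -> Vec:
    return (a[0] * s, a[1] * s, a[2] * s)


def normalize(a: Vec) -> Vec:
    return scale(a, 1.0 / math.sqrt(dot(a, a)))


def cross(a: Vec, b: Vec) -> Vec:
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def frame_from_three_points(n: Vec, ca: Vec, c: Vec) -> tuple[tuple[Vec, Vec, Vec], Vec]:
    """Return (orthonormal axes, origin) built by Gram-Schmidt (assumed convention)."""
    e1 = normalize(sub(c, ca))                            # first axis: CA -> C
    v2 = sub(n, ca)
    e2 = normalize(sub(v2, scale(e1, dot(v2, e1))))       # remove the e1 component
    e3 = cross(e1, e2)                                    # completes a right-handed basis
    return (e1, e2, e3), ca


def to_local(point: Vec, axes: tuple[Vec, Vec, Vec], origin: Vec) -> Vec:
    """Express a point in the frame, i.e. apply the inverse transform."""
    d = sub(point, origin)
    return (dot(d, axes[0]), dot(d, axes[1]), dot(d, axes[2]))
```

Applying `to_local` to every atom corresponds to what `apply_frame_to_coords` does with the inverted frame: the origin atom lands at zero and the structure is expressed in a pose-independent basis.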
esmfold2_output.py ADDED
@@ -0,0 +1,225 @@
1
+ from itertools import groupby
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL, MOL_TYPE_NONPOLYMER
8
+ from .esmfold2_molecular_complex import (
9
+ MolecularComplex,
10
+ MolecularComplexMetadata,
11
+ )
12
+
13
+
14
+ def get_element_symbol(atomic_num: int) -> str:
15
+ return ELEMENT_NUMBER_TO_SYMBOL.get(atomic_num, "X")
16
+
17
+
18
+ def build_molecular_complex_from_features(
19
+ coords: torch.Tensor,
20
+ plddt: torch.Tensor,
21
+ atom_mask: torch.Tensor,
22
+ ref_element: torch.Tensor,
23
+ ref_atom_name_chars: torch.Tensor,
24
+ chain_infos: list,
25
+ complex_id: str,
26
+ ) -> MolecularComplex:
27
+ """Construct a MolecularComplex from feature-dict tensors and chain metadata.
28
+
29
+ Non-polymer chains (ligands) collapse all per-atom tokens into a single
30
+ residue token whose pLDDT is the per-token average and whose hetero flag
31
+ is True.
32
+ """
33
+ mask_np = atom_mask.bool().cpu().numpy()
34
+ coords_np = coords.float().cpu().numpy()
35
+ name_chars_np = ref_atom_name_chars.cpu().numpy()
36
+ elements_np = ref_element.cpu().numpy()
37
+ plddt_np = plddt.float().cpu().numpy()
38
+
39
+ sequence_tokens: list[str] = []
40
+ chain_ids_per_token: list[int] = []
41
+ token_to_atoms: list[list[int]] = []
42
+ confidence: list[float] = []
43
+ flat_positions: list[list[float]] = []
44
+ flat_elements: list[str] = []
45
+ flat_names: list[str] = []
46
+ flat_hetero: list[bool] = []
47
+
48
+ chain_lookup: dict[int, str] = {}
49
+ entity_info: dict[int, str] = {}
50
+ out_atom_cursor = 0
51
+
52
+ for ci in chain_infos:
53
+ chain_lookup[ci.asym_id] = ci.chain_id
54
+ is_nonpolymer = ci.mol_type == MOL_TYPE_NONPOLYMER
55
+ entity_info[ci.entity_id] = "non-polymer" if is_nonpolymer else "polymer"
56
+
57
+ if is_nonpolymer:
58
+ residue_name = ci.tokens[0].residue_name if ci.tokens else "LIG"
59
+ sequence_tokens.append(residue_name)
60
+ chain_ids_per_token.append(ci.asym_id)
61
+ avg_plddt = (
62
+ float(np.mean([plddt_np[ti.token_index] for ti in ci.tokens]))
63
+ if ci.tokens
64
+ else 0.0
65
+ )
66
+ confidence.append(avg_plddt)
67
+ token_atom_start = out_atom_cursor
68
+ for ti in ci.tokens:
69
+ for atom_idx in range(ti.atom_start, ti.atom_start + ti.atom_count):
70
+ if not mask_np[atom_idx]:
71
+ continue
72
+ flat_positions.append(coords_np[atom_idx].tolist())
73
+ flat_elements.append(get_element_symbol(int(elements_np[atom_idx])))
74
+ chars = name_chars_np[atom_idx]
75
+ name = "".join(
76
+ chr(int(c) + 32) for c in chars if int(c) != 0
77
+ ).strip()
78
+ flat_names.append(name)
79
+ flat_hetero.append(True)
80
+ out_atom_cursor += 1
81
+ token_to_atoms.append([token_atom_start, out_atom_cursor])
82
+ continue
83
+
84
+ # Atom-tokenized modified residues (HYP, MSE, ...) span multiple
85
+ # tokens per residue; collapse them back to one mmCIF residue.
86
+ for _residue_index, ti_iter in groupby(
87
+ ci.tokens, key=lambda t: t.residue_index
88
+ ):
89
+ ti_group = list(ti_iter)
90
+ sequence_tokens.append(ti_group[0].residue_name)
91
+ chain_ids_per_token.append(ci.asym_id)
92
+ confidence.append(
93
+ float(np.mean([plddt_np[ti.token_index] for ti in ti_group]))
94
+ )
95
+ token_atom_start = out_atom_cursor
96
+ for ti in ti_group:
97
+ for atom_idx in range(ti.atom_start, ti.atom_start + ti.atom_count):
98
+ if not mask_np[atom_idx]:
99
+ continue
100
+ flat_positions.append(coords_np[atom_idx].tolist())
101
+ flat_elements.append(get_element_symbol(int(elements_np[atom_idx])))
102
+ chars = name_chars_np[atom_idx]
103
+ name = "".join(
104
+ chr(int(c) + 32) for c in chars if int(c) != 0
105
+ ).strip()
106
+ flat_names.append(name)
107
+ flat_hetero.append(False)
108
+ out_atom_cursor += 1
109
+ token_to_atoms.append([token_atom_start, out_atom_cursor])
110
+
111
+ return MolecularComplex(
112
+ id=complex_id,
113
+ sequence=sequence_tokens,
114
+ atom_positions=np.array(flat_positions, dtype=np.float32).reshape(-1, 3),
115
+ atom_elements=np.array(flat_elements, dtype=object),
116
+ token_to_atoms=np.array(token_to_atoms, dtype=np.int32).reshape(-1, 2),
117
+ chain_id=np.array(chain_ids_per_token, dtype=np.int64),
118
+ plddt=np.array(confidence, dtype=np.float32),
119
+ atom_names=np.array(flat_names, dtype=object),
120
+ atom_hetero=np.array(flat_hetero, dtype=bool),
121
+ metadata=MolecularComplexMetadata(
122
+ entity_lookup=entity_info,
123
+ chain_lookup=chain_lookup,
124
+ assembly_composition=None,
125
+ ),
126
+ )
127
+
128
+
129
+ def build_molecular_complex(
130
+ structure: Any, coords: torch.Tensor, plddt: torch.Tensor, complex_id: str
131
+ ) -> MolecularComplex:
132
+ """Directly constructs a MolecularComplex from model outputs without intermediate files.
133
+
134
+ Args:
135
+ structure: Object with .chains, .residues, .atoms numpy structured arrays.
136
+ coords: [N_atoms, 3] predicted atom coordinates.
137
+ plddt: [N_residues] per-residue confidence scores.
138
+ complex_id: Identifier string for the resulting complex.
139
+ """
140
+ flat_positions = []
141
+ flat_elements = []
142
+ flat_names = []
143
+ flat_hetero = []
144
+
145
+ sequence_tokens = []
146
+ token_to_atoms = []
147
+ chain_ids_per_token = []
148
+ confidence_scores = []
149
+
150
+ chain_lookup = {}
151
+ entity_info = {}
152
+
153
+ global_atom_cursor = 0
154
+ global_res_cursor = 0
155
+ atom_array_idx = 0
156
+
157
+ for chain in structure.chains:
158
+ chain_idx_numeric = chain["asym_id"]
159
+ chain_name_str = str(chain["name"])
160
+ mol_type = chain["mol_type"]
161
+
162
+ chain_lookup[chain_idx_numeric] = chain_name_str
163
+ entity_info[chain["entity_id"]] = (
164
+ "polymer" if mol_type != MOL_TYPE_NONPOLYMER else "non-polymer"
165
+ )
166
+
167
+ res_start = chain["res_idx"]
168
+ res_end = chain["res_idx"] + chain["res_num"]
169
+ residues = structure.residues[res_start:res_end]
170
+
171
+ for residue in residues:
172
+ res_name = str(residue["name"])
173
+
174
+ sequence_tokens.append(res_name)
175
+ chain_ids_per_token.append(chain_idx_numeric)
176
+
177
+ score = plddt[global_res_cursor].item()
178
+ confidence_scores.append(score)
179
+ token_start_idx = atom_array_idx
180
+
181
+ atom_start = residue["atom_idx"]
182
+ atom_end = residue["atom_idx"] + residue["atom_num"]
183
+ atoms = structure.atoms[atom_start:atom_end]
184
+
185
+ for atom in atoms:
186
+ if not atom["is_present"]:
187
+ continue
188
+
189
+ pos = coords[global_atom_cursor].tolist()
190
+ flat_positions.append(pos)
191
+
192
+ elem = get_element_symbol(atom["element"].item())
193
+ flat_elements.append(elem)
194
+
195
+ raw_name = atom["name"]
196
+ if hasattr(raw_name, "tolist"):
197
+ raw_name = raw_name.tolist()
198
+ name_str = "".join([chr(c + 32) for c in raw_name if c != 0])
199
+ flat_names.append(name_str)
200
+
201
+ flat_hetero.append(mol_type == MOL_TYPE_NONPOLYMER)
202
+
203
+ global_atom_cursor += 1
204
+ atom_array_idx += 1
205
+
206
+ token_to_atoms.append([token_start_idx, atom_array_idx])
207
+ global_res_cursor += 1
208
+
209
+ return MolecularComplex(
210
+ id=complex_id,
211
+ sequence=sequence_tokens,
212
+ atom_positions=np.array(flat_positions, dtype=np.float32),
213
+ atom_elements=np.array(flat_elements, dtype=object),
214
+ token_to_atoms=np.array(token_to_atoms, dtype=np.int32),
215
+ chain_id=np.array(chain_ids_per_token, dtype=np.int64),
216
+ plddt=np.array(confidence_scores, dtype=np.float32),
217
+ atom_names=np.array(flat_names, dtype=object),
218
+ atom_hetero=np.array(flat_hetero, dtype=bool),
219
+ metadata=MolecularComplexMetadata(
220
+ entity_lookup=entity_info,
221
+ chain_lookup=chain_lookup,
222
+ assembly_composition=None,
223
+ ),
224
+ )
225
+
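Both builders above decode atom names with `chr(c + 32)` and skip zeros, which implies that names are stored as character codes offset by 32 with `0` as padding. A small sketch of that round trip follows. The encoder and the fixed width of 4 are assumptions inferred from the decoding expression, not taken from the feature pipeline, and `encode_atom_name` / `decode_atom_name` are hypothetical names.

```python
def encode_atom_name(name: str, width: int = 4) -> list[int]:
    """Encode an atom name as offset character codes, zero-padded (assumed scheme)."""
    if len(name) > width:
        raise ValueError(f"atom name {name!r} is longer than {width} characters")
    codes = [ord(ch) - 32 for ch in name]
    return codes + [0] * (width - len(name))


def decode_atom_name(codes: list[int]) -> str:
    """Mirror of the decoding used above: drop zero padding, add 32, strip."""
    return "".join(chr(int(c) + 32) for c in codes if int(c) != 0).strip()
```

One consequence of this scheme is that a space character encodes to `0`, so it is indistinguishable from padding and is dropped on decode.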
esmfold2_paired_msa.py ADDED
@@ -0,0 +1,246 @@
1
+ """Taxonomy-paired MSA construction for ESMFold2 inference.
2
+
3
+ Taxonomy IDs are read from FASTA headers as ``key=N`` tokens. Rows
4
+ where any chain has ``key=-1`` (or no ``key=`` at all) are treated as
5
+ unpaired and assigned to that chain's block-diagonal section after
6
+ the paired rows.
7
+ """
8
+
9
+ import re
10
+
11
+ import numpy as np
12
+
13
+ from .esmfold2_constants import (
14
+ MSA_GAP_TOKEN_ID,
15
+ PROTEIN_3TO1,
16
+ PROTEIN_RESIDUE_TO_RES_TYPE,
17
+ PROTEIN_UNK_RES_TYPE,
18
+ )
19
+ from .esmfold2_msa import MSA
20
+
21
+ _KEY_RE = re.compile(r"key=(-?\d+)")
22
+
23
+
24
+ def protein_letter_to_res_type() -> dict[str, int]:
25
+ """Return the protein 1-letter → res_type mapping used by the MSA encoder."""
26
+ mapping: dict[str, int] = {}
27
+ for three, one in PROTEIN_3TO1.items():
28
+ if three in PROTEIN_RESIDUE_TO_RES_TYPE:
29
+ mapping[one] = PROTEIN_RESIDUE_TO_RES_TYPE[three]
30
+ mapping["-"] = MSA_GAP_TOKEN_ID
31
+ mapping["X"] = PROTEIN_UNK_RES_TYPE
32
+ return mapping
33
+
34
+
35
+ def _taxonomy_from_header(header: str) -> int:
36
+ if not header:
37
+ return -1
38
+ m = _KEY_RE.search(header)
39
+ return int(m.group(1)) if m else -1
40
+
41
+
42
+ def msa_to_res_type_and_deletions(
43
+ msa: MSA, letter_to_res_type: dict[str, int]
44
+ ) -> tuple[np.ndarray, np.ndarray]:
45
+ """Convert an :class:`MSA` to ``(res_type[M, L], deletion_count[M, L])``.
46
+
47
+ Handles a3m insertion convention: lowercase letters and ``.`` are
48
+ insertions and are not emitted; their count is accumulated into the
49
+ next non-insertion position's deletion value. ``L`` is the query
50
+ length after stripping insertions from row 0.
51
+ """
52
+ query = msa.entries[0].sequence
53
+ L = sum(1 for ch in query if not (ch.islower() or ch == "."))
54
+ M = msa.depth
55
+
56
+ res_type = np.full((M, L), MSA_GAP_TOKEN_ID, dtype=np.int64)
57
+ deletions = np.zeros((M, L), dtype=np.float32)
58
+
59
+ for r, entry in enumerate(msa.entries):
60
+ col = 0
61
+ ins = 0
62
+ for ch in entry.sequence:
63
+ if ch == "." or (ch.islower() and ch != "-"):
64
+ ins += 1
65
+ continue
66
+ if col >= L:
67
+ break
68
+ if ch == "-":
69
+ res_type[r, col] = MSA_GAP_TOKEN_ID
70
+ else:
71
+ res_type[r, col] = letter_to_res_type.get(
72
+ ch.upper(), PROTEIN_UNK_RES_TYPE
73
+ )
74
+ if ins > 0:
75
+ deletions[r, col] = float(ins)
76
+ ins = 0
77
+ col += 1
78
+ return res_type, deletions
79
+
80
+
81
+ def _dummy_msa_residues(query_res_types: np.ndarray) -> np.ndarray:
82
+ """Single-row 'MSA' for chains without one — just the query."""
83
+ return query_res_types[None, :] # [1, L]
84
+
85
+
86
+ def construct_paired_msa(
87
+ chain_msas: dict[int, MSA | None],
88
+ chain_query_res_types: dict[int, np.ndarray],
89
+ token_asym_ids: np.ndarray,
90
+ token_res_ids: np.ndarray,
91
+ letter_to_res_type: dict[str, int] | None = None,
92
+ *,
93
+ max_pairs: int = 8192,
94
+ max_total: int = 16384,
95
+ max_seqs: int = 16384,
96
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
97
+ """Build paired MSA features.
98
+
99
+ Parameters
100
+ ----------
101
+ chain_msas
102
+ ``asym_id -> MSA`` (or ``None`` for chains without an MSA).
103
+ chain_query_res_types
104
+ ``asym_id -> np.ndarray[L_c]`` of res-type ids for the chain's
105
+ query. Used to build dummy MSAs when a chain has no MSA.
106
+ token_asym_ids
107
+ Per-token asym_id, length ``T``. Must be non-decreasing.
108
+ token_res_ids
109
+ Per-token residue index within chain, length ``T``.
110
+ letter_to_res_type
111
+ 1-letter → res-type mapping. Defaults to
112
+ :func:`protein_letter_to_res_type`.
113
+
114
+ Returns
115
+ -------
116
+ msa_residues : ``np.ndarray[M, T]`` int64
117
+ deletion_value : ``np.ndarray[M, T]`` float32 (raw deletion counts; the
118
+ ``arctan(x / 3) * pi / 2`` transform is applied by the caller)
119
+ is_paired : ``np.ndarray[M, T]`` float32 broadcast of per-row,
120
+ per-chain paired flags.
121
+ """
122
+ if letter_to_res_type is None:
123
+ letter_to_res_type = protein_letter_to_res_type()
124
+
125
+ chain_ids: list[int] = sorted(chain_msas.keys())
126
+
127
+ # Build per-chain (res_type, deletions, taxonomy) tables.
128
+ chain_res_type: dict[int, np.ndarray] = {}
129
+ chain_deletions: dict[int, np.ndarray] = {}
130
+ chain_taxonomies: dict[int, list[int]] = {}
131
+ for c in chain_ids:
132
+ m = chain_msas.get(c)
133
+ if m is None or m.depth == 0:
134
+ qres = chain_query_res_types[c]
135
+ chain_res_type[c] = _dummy_msa_residues(qres)
136
+ chain_deletions[c] = np.zeros((1, qres.shape[0]), dtype=np.float32)
137
+ chain_taxonomies[c] = [-1]
138
+ continue
139
+ rt, dl = msa_to_res_type_and_deletions(m, letter_to_res_type)
140
+ chain_res_type[c] = rt
141
+ chain_deletions[c] = dl
142
+ chain_taxonomies[c] = [_taxonomy_from_header(e.header) for e in m.entries]
143
+
144
+ # Group by taxonomy, skip query row and unpaired (-1) entries.
145
+ taxonomy_map: dict[int, list[tuple[int, int]]] = {}
146
+ for c in chain_ids:
147
+ for seq_idx, taxon in enumerate(chain_taxonomies[c]):
148
+ if seq_idx == 0 or taxon == -1:
149
+ continue
150
+ taxonomy_map.setdefault(taxon, []).append((c, seq_idx))
151
+ taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
152
+ # Order taxonomies by number of distinct chains, descending.
153
+ sorted_taxa = sorted(
154
+ taxonomy_map.items(), key=lambda kv: len({c for c, _ in kv[1]}), reverse=True
155
+ )
156
+
157
+ visited = {s for _, items in taxonomy_map.items() for s in items}
158
+ available: dict[int, list[int]] = {
159
+ c: [i for i in range(1, len(chain_taxonomies[c])) if (c, i) not in visited]
160
+ for c in chain_ids
161
+ }
162
+
163
+ pairing: list[dict[int, int]] = [{c: 0 for c in chain_ids}]
164
+ is_paired: list[dict[int, int]] = [{c: 1 for c in chain_ids}]
165
+
166
+ for _, pairs in sorted_taxa:
167
+ per_chain: dict[int, list[int]] = {}
168
+ for c, seq_idx in pairs:
169
+ per_chain.setdefault(c, []).append(seq_idx)
170
+ max_occ = max(len(v) for v in per_chain.values())
171
+ for i in range(max_occ):
172
+ row_pairing: dict[int, int] = {}
173
+ row_is_paired: dict[int, int] = {}
174
+ for c, seq_idxs in per_chain.items():
175
+ row_pairing[c] = seq_idxs[i % len(seq_idxs)]
176
+ row_is_paired[c] = 1
177
+ for c in chain_ids:
178
+ if c in row_pairing:
179
+ continue
180
+ row_is_paired[c] = 0
181
+ if available[c]:
182
+ row_pairing[c] = available[c].pop(0)
183
+ else:
184
+ row_pairing[c] = -1
185
+ pairing.append(row_pairing)
186
+ is_paired.append(row_is_paired)
187
+ if len(pairing) >= max_pairs:
188
+ break
189
+ if len(pairing) >= max_pairs:
190
+ break
191
+
192
+ max_left = max((len(v) for v in available.values()), default=0)
193
+ for _ in range(min(max_total - len(pairing), max_left)):
194
+ row_pairing = {}
195
+ row_is_paired = {}
196
+ for c in chain_ids:
197
+ row_is_paired[c] = 0
198
+ if available[c]:
199
+ row_pairing[c] = available[c].pop(0)
200
+ else:
201
+ row_pairing[c] = -1
202
+ pairing.append(row_pairing)
203
+ is_paired.append(row_is_paired)
204
+ if len(pairing) >= max_total:
205
+ break
206
+
207
+ pairing = pairing[:max_seqs]
208
+ is_paired = is_paired[:max_seqs]
209
+ M = len(pairing)
210
+ T = len(token_asym_ids)
211
+
212
+ msa_residues = np.full((M, T), MSA_GAP_TOKEN_ID, dtype=np.int64)
213
+ deletion_value = np.zeros((M, T), dtype=np.float32)
214
+ paired_mask = np.zeros((M, T), dtype=np.float32)
215
+
216
+ # Vectorize per chain: gather chain rows according to pairing[c], then
217
+ # index into them by the chain's token residue ids.
218
+ for c in chain_ids:
219
+ rt = chain_res_type[c]
220
+ dl = chain_deletions[c]
221
+ Lc = rt.shape[1]
222
+ chain_pairing = np.array([row[c] for row in pairing], dtype=np.int64)
223
+ chain_paired = np.array([row[c] for row in is_paired], dtype=np.float32)
224
+
225
+ token_mask = token_asym_ids == c
226
+ if not token_mask.any():
227
+ continue
228
+ token_res_in_chain = token_res_ids[token_mask]
229
+ # Clamp residue indices to the MSA's column range. Modified-residue
230
+ # tokens that exceed the query length fall back to the last column.
231
+ cols = np.minimum(token_res_in_chain, Lc - 1)
232
+
233
+ # Rows where pairing == -1 fall back to gap (already initialized).
234
+ valid_rows = chain_pairing >= 0
235
+ if valid_rows.any():
236
+ gathered_rt = rt[chain_pairing[valid_rows]][:, cols]
237
+ gathered_dl = dl[chain_pairing[valid_rows]][:, cols]
238
+ valid_idx = np.where(valid_rows)[0]
239
+ token_idx = np.where(token_mask)[0]
240
+ msa_residues[np.ix_(valid_idx, token_idx)] = gathered_rt
241
+ deletion_value[np.ix_(valid_idx, token_idx)] = gathered_dl
242
+
243
+ paired_mask[:, token_mask] = chain_paired[:, None]
244
+
245
+ return msa_residues, deletion_value, paired_mask
246
+
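The taxonomy-pairing step in `construct_paired_msa` can be sketched in plain Python as follows. This is a simplified illustration, not the committed implementation: the function name `pair_rows_by_taxonomy` is hypothetical, chains without a hit for a taxon get `-1` instead of being filled from leftover unpaired sequences, and the `max_pairs` / `max_total` caps are omitted.

```python
from collections import defaultdict


def pair_rows_by_taxonomy(chain_taxa: dict[int, list[int]]) -> list[dict[int, int]]:
    """Build paired rows as {chain_id: sequence_index} dicts.

    Hypothetical sketch. chain_taxa maps a chain id to the taxonomy id of
    each MSA row (index 0 is the query, -1 means "no taxonomy").
    """
    chains = sorted(chain_taxa)
    groups: dict[int, list[tuple[int, int]]] = defaultdict(list)
    for chain in chains:
        for idx, taxon in enumerate(chain_taxa[chain]):
            if idx == 0 or taxon == -1:
                continue  # skip the query row and unlabeled rows
            groups[taxon].append((chain, idx))
    # Keep taxa with more than one hit, as the source does.
    shared = {t: hits for t, hits in groups.items() if len(hits) > 1}
    # Taxa covering more distinct chains come first.
    ordered = sorted(
        shared.values(), key=lambda hits: len({c for c, _ in hits}), reverse=True
    )
    rows = [{c: 0 for c in chains}]  # row 0 pairs every chain's query
    for hits in ordered:
        per_chain: dict[int, list[int]] = defaultdict(list)
        for chain, idx in hits:
            per_chain[chain].append(idx)
        depth = max(len(v) for v in per_chain.values())
        for i in range(depth):
            row = {c: -1 for c in chains}
            for chain, idxs in per_chain.items():
                row[chain] = idxs[i % len(idxs)]  # cycle shorter lists
            rows.append(row)
    return rows


rows = pair_rows_by_taxonomy({0: [-1, 9606, 10090], 1: [-1, 10090, 9606, 562]})
print(rows)
```

Taxon `562` appears in only one row overall, so it never forms a paired row; in the real function such sequences end up in the unpaired block instead.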
esmfold2_parsing.py ADDED
@@ -0,0 +1,113 @@
1
+ import io
2
+ from pathlib import Path
3
+ from typing import Generator, Iterable, NamedTuple
4
+
5
+ PathOrBuffer = str | Path | io.TextIOBase
6
+ FastaEntry = NamedTuple("FastaEntry", [("header", str), ("sequence", str)])
7
+
8
+
9
+ def parse_fasta(fasta_string: str) -> Generator[FastaEntry, None, None]:
10
+ """
11
+ Parses a fasta file and yields FastaEntry objects
12
+
13
+ Args:
14
+ fasta_string: The fasta file as a string
15
+ Yields:
16
+ FastaEntry objects, one per record
17
+ """
18
+ header = None
19
+ seq = []
20
+ num_sequences = 0
21
+ for line in fasta_string.splitlines():
22
+ if not line or line[0] == "#":
23
+ continue
24
+ if line.startswith(">"):
25
+ if header is not None:
26
+ yield FastaEntry(header, "".join(seq))
27
+ seq = []
28
+ header = line[1:].strip()
29
+ else:
30
+ seq.append(line)
31
+ if header is not None:
32
+ num_sequences += 1
33
+ yield FastaEntry(header, "".join(seq))
34
+
35
+ if num_sequences == 0:
36
+ raise ValueError("Found no sequences in input")
37
+
38
+
39
+ def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]:
40
+ # Uses duck typing to try and call the right method
41
+ # Doesn't use explicit isinstance check to support
42
+ # inputs that are not explicitly str/Path/TextIOBase but
43
+ # may support similar functionality
44
+ data = None # type: ignore
45
+ try:
46
+ if str(path).endswith(".gz"):
47
+ import gzip
48
+
49
+ data = gzip.open(path, "rt") # type: ignore
50
+ else:
51
+ try:
52
+ data = open(path) # type: ignore
53
+ except TypeError:
54
+ data: io.TextIOBase = path # type: ignore
55
+
56
+ yield from parse_fasta(data.read())
57
+ finally:
58
+ if data is not None:
59
+ data.close()
60
+
61
+
62
+ def read_first_sequence(path: PathOrBuffer) -> FastaEntry:
63
+ return next(iter(read_sequences(path)))
64
+
65
+
66
+ def count_fasta_sequences(path: str | Path) -> int:
67
+ """Count sequences in a FASTA file by counting header lines.
68
+
69
+ Faster than parsing the full file — only scans for '>' prefixes.
70
+ Returns 0 if the file does not exist.
71
+ """
72
+ path = Path(path)
73
+ if not path.exists():
74
+ return 0
75
+ with open(path) as f:
76
+ return sum(1 for line in f if line.startswith(">"))
77
+
78
+
79
+ def append_fasta_sequence(header: str, sequence: str, path: str | Path) -> None:
80
+ """Append a single sequence to a FASTA file (creating it if needed)."""
81
+ path = Path(path)
82
+ path.parent.mkdir(parents=True, exist_ok=True)
83
+ # The existing file may not end with a newline (e.g., write_sequences()
84
+ # explicitly avoids writing a newline at the end), so we insert one before
85
+ # appending to avoid merging with the last line.
86
+ needs_newline = (
87
+ path.exists() and path.stat().st_size > 0 and path.read_bytes()[-1:] != b"\n"
88
+ )
89
+ with open(path, "a") as f:
90
+ if needs_newline:
91
+ f.write("\n")
92
+ f.write(f">{header}\n{sequence}\n")
93
+
94
+
95
+ def write_sequences(sequences: Iterable[tuple[str, str]], path: PathOrBuffer) -> None:
96
+ needs_closing = False
97
+ handle = None
98
+ try:
99
+ try:
100
+ handle = open(path, "w") # type: ignore
101
+ needs_closing = True
102
+ except TypeError:
103
+ handle = path
104
+ has_prev = False
105
+ for header, seq in sequences:
106
+ if has_prev:
107
+ handle.write("\n") # type: ignore
108
+ handle.write(f">{header}\n{seq}") # type: ignore
109
+ has_prev = True
110
+ finally:
111
+ if needs_closing:
112
+ handle.close() # type: ignore
113
+
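For readers who want to see the parsing behaviour in isolation, here is a small self-contained sketch in the spirit of `parse_fasta`. The names `Record` and `iter_fasta` are hypothetical stand-ins and are not part of the module. It assumes the same conventions: blank lines and lines starting with `#` are skipped, and sequence lines are concatenated until the next header.

```python
from typing import Iterator, NamedTuple


class Record(NamedTuple):
    header: str
    sequence: str


def iter_fasta(text: str) -> Iterator[Record]:
    """Yield one Record per '>' header (hypothetical sketch)."""
    header = None
    chunks: list[str] = []
    for line in text.splitlines():
        if not line or line.startswith("#"):
            continue  # skip blanks and comment lines
        if line.startswith(">"):
            if header is not None:
                yield Record(header, "".join(chunks))
            header, chunks = line[1:].strip(), []
        else:
            chunks.append(line)
    if header is None:
        raise ValueError("Found no sequences in input")
    yield Record(header, "".join(chunks))


records = list(iter_fasta("# comment\n>a desc\nMKT\nAYI\n>b\nGG"))
print(records)
```

Because this is a generator, the `ValueError` for empty input is raised only once the caller starts iterating, which is also true of `parse_fasta`.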
esmfold2_predicted_aligned_error.py ADDED
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from .esmfold2_affine3d import Affine3D
5
+
6
+
7
+ def masked_mean(
8
+ mask: torch.Tensor,
9
+ value: torch.Tensor,
10
+ dim: int | None | tuple[int, ...] = None,
11
+ eps=1e-10,
12
+ ) -> torch.Tensor:
13
+ """Compute the mean of `value` over the positions where `mask` is true;
14
+ masked-out positions are ignored.
15
+ """
16
+ mask = mask.expand(*value.shape)
17
+ return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
18
+
19
+
20
+ def _pae_bins(
21
+ max_bin: float = 31, num_bins: int = 64, device: torch.device = torch.device("cpu")
22
+ ):
23
+ bins = torch.linspace(0, max_bin, steps=(num_bins - 1), device=device)
24
+ step = max_bin / (num_bins - 2)
25
+ bin_centers = bins + step / 2
26
+ bin_centers = torch.cat(
27
+ [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
28
+ )
29
+ return bin_centers
30
+
31
+
32
+ def _compute_pae_masks(mask: torch.Tensor):
33
+ square_mask = (mask.unsqueeze(-1) * mask.unsqueeze(-2)).bool()
34
+ return square_mask
35
+
36
+
37
+ def compute_predicted_aligned_error(
38
+ logits: torch.Tensor,
39
+ aa_mask: torch.Tensor,
40
+ sequence_id: torch.Tensor | None = None,
41
+ max_bin: float = 31,
42
+ ) -> torch.Tensor:
43
+ bins = _pae_bins(max_bin, logits.shape[-1], logits.device)
44
+ square_mask = _compute_pae_masks(aa_mask)
45
+ min_v = torch.finfo(logits.dtype).min
46
+ probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1)
47
+
48
+ return (probs * bins).sum(dim=-1)
49
+
50
+
51
+ @torch.no_grad()
52
+ def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0):
53
+ square_mask = _compute_pae_masks(aa_mask)
54
+ seqlens = aa_mask.sum(-1, keepdim=True)
55
+ bins = _pae_bins(max_bin, logits.shape[-1], logits.device)
56
+ d0 = 1.24 * (seqlens.clamp_min(19) - 15) ** (1 / 3) - 1.8
57
+ f_d = 1.0 / (1 + (bins / d0.unsqueeze(-1)) ** 2)
58
+
59
+ min_v = torch.finfo(logits.dtype).min
60
+ probs = logits.masked_fill(~square_mask.unsqueeze(-1), min_v).softmax(dim=-1)
61
+ # This is the sum over bins
62
+ ptm = (probs * f_d.unsqueeze(-2)).sum(dim=-1)
63
+ # This is the mean over residues j
64
+ ptm = masked_mean(square_mask, ptm, dim=-1)
65
+ # Then we take the max over residues i
66
+ return ptm.max(dim=-1).values
67
+
68
+
69
+ def tm_loss(
70
+ logits: torch.Tensor,
71
+ pred_affine: torch.Tensor,
72
+ targ_affine: torch.Tensor,
73
+ targ_mask: torch.Tensor,
74
+ tm_mask: torch.Tensor | None = None,
75
+ sequence_id: torch.Tensor | None = None,
76
+ max_bin: float = 31,
77
+ ):
78
+ pred = Affine3D.from_tensor(pred_affine)
79
+ targ = Affine3D.from_tensor(targ_affine)
80
+
81
+ def transform(affine: Affine3D):
82
+ pts = affine.trans[..., None, :, :]
83
+ return affine.invert()[..., None].apply(pts)
84
+
85
+ with torch.no_grad():
86
+ sq_diff = (transform(pred) - transform(targ)).square().sum(dim=-1)
87
+
88
+ num_bins = logits.shape[-1]
89
+ sq_bins = torch.linspace(
90
+ 0, max_bin, num_bins - 1, device=logits.device
91
+ ).square()
92
+ # Gets the bin id by using a sum.
93
+ true_bins = (sq_diff[..., None] > sq_bins).sum(dim=-1).long()
94
+
95
+ errors = F.cross_entropy(logits.movedim(3, 1), true_bins, reduction="none")
96
+ square_mask = _compute_pae_masks(targ_mask)
97
+ loss = masked_mean(square_mask, errors, dim=(-1, -2))
98
+
99
+ if tm_mask is not None:
100
+ loss = masked_mean(tm_mask, loss, dim=None)
101
+ else:
102
+ loss = loss.mean()
103
+
104
+ return loss
105
+
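The bin arithmetic used by `_pae_bins` and the `d0` term in `compute_tm` can be checked without PyTorch. The sketch below is a plain-Python restatement under the default `max_bin=31` and 64 bins; the function names `pae_bin_centers`, `expected_error` and `tm_d0` are hypothetical and chosen only for this illustration.

```python
def pae_bin_centers(max_bin: float = 31.0, num_bins: int = 64) -> list[float]:
    """Bin centers: num_bins - 1 evenly spaced edges plus one overflow bin."""
    step = max_bin / (num_bins - 2)
    # Same points as linspace(0, max_bin, num_bins - 1).
    edges = [i * step for i in range(num_bins - 1)]
    centers = [e + step / 2 for e in edges]
    centers.append(centers[-1] + step)  # overflow bin past max_bin
    return centers


def expected_error(probs: list[float], centers: list[float]) -> float:
    """Expected aligned error for one residue pair: sum of p_k * center_k."""
    return sum(p * c for p, c in zip(probs, centers))


def tm_d0(num_residues: int) -> float:
    """TM-score normalisation distance, with the length clamped to >= 19."""
    return 1.24 * (max(num_residues, 19) - 15) ** (1 / 3) - 1.8


centers = pae_bin_centers()
uniform = [1 / len(centers)] * len(centers)
print(len(centers), centers[0], centers[-1], expected_error(uniform, centers))
```

The clamp in `tm_d0` keeps `d0` positive for very short chains, which is why `compute_tm` applies `clamp_min(19)` before the cube root.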
esmfold2_prepare_input.py ADDED
@@ -0,0 +1,1464 @@
1
+ """Prepare ESMFold2 model inputs from sequence-level StructurePredictionInput.
2
+
3
+ This module converts StructurePredictionInput (protein/DNA/RNA/ligand sequences)
4
+ into the tensor dict expected by the ESMFold2 model forward pass.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+ import warnings
11
+ from collections import defaultdict
12
+ from dataclasses import dataclass, field
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+ from .esmfold2_conformers import (
18
+ get_ccd_leaving_atoms,
19
+ get_idealized_atom_pos,
20
+ get_ligand_ccd_atoms_with_charges,
21
+ get_ligand_ccd_bonds,
22
+ get_ligand_idealized_atom_pos,
23
+ )
24
+ from .esmfold2_constants import (
25
+ CHARGED_ATOMS,
26
+ DNA_1TO3,
27
+ DNA_BACKBONE_ATOMS,
28
+ DNA_HEAVY_ATOMS,
29
+ DNA_RESIDUE_TO_RES_TYPE,
30
+ DNA_RNA_LIGAND_INPUT_ID,
31
+ DNA_UNK_RES_TYPE,
32
+ ELEMENT_TO_ATOMIC_NUM,
33
+ ESM_PROTEIN_VOCAB,
34
+ MOL_TYPE_DNA,
35
+ MOL_TYPE_NONPOLYMER,
36
+ MOL_TYPE_PROTEIN,
37
+ MOL_TYPE_RNA,
38
+ MSA_GAP_TOKEN_ID,
39
+ PROTEIN_1TO3,
40
+ PROTEIN_3TO1,
41
+ PROTEIN_HEAVY_ATOMS,
42
+ PROTEIN_RESIDUE_TO_RES_TYPE,
43
+ PROTEIN_UNK_RES_TYPE,
44
+ RNA_1TO3,
45
+ RNA_BACKBONE_ATOMS,
46
+ RNA_HEAVY_ATOMS,
47
+ RNA_RESIDUE_TO_RES_TYPE,
48
+ RNA_UNK_RES_TYPE,
49
+ )
50
+ from .esmfold2_types import (
51
+ MSA,
52
+ DNAInput,
53
+ LigandInput,
54
+ Modification,
55
+ ProteinInput,
56
+ RNAInput,
57
+ StructurePredictionInput,
58
+ )
59
+
60
+ # =============================================================================
61
+ # Lightweight data model
62
+ # =============================================================================
63
+
64
+ _ZERO_POS = np.array([0.0, 0.0, 0.0], dtype=np.float32)
65
+
66
+
67
+ @dataclass
68
+ class AtomInfo:
69
+ name: str
70
+ element: str
71
+ charge: int
72
+ ref_pos: np.ndarray # Idealized position from CCD [3]
73
+ pos: np.ndarray # Experimental position [3] (zeros for inference)
74
+ token_index: int = -1
75
+ atom_index: int = -1
76
+ space_uid: int = -1
77
+ is_valid: bool = True
78
+
79
+
80
+ @dataclass
81
+ class TokenInfo:
82
+ token_index: int
83
+ residue_index: int # Within chain (0-based)
84
+ residue_name: str # 3-letter code
85
+ mol_type: int # 0=protein, 1=DNA, 2=RNA, 3=nonpolymer
86
+ res_type: int # Residue type index (2-32)
87
+ input_id: int # ESM vocab ID
88
+ asym_id: int
89
+ sym_id: int
90
+ entity_id: int
91
+ atom_start: int # Index into atoms list
92
+ atom_count: int
93
+
94
+
95
+ @dataclass
96
+ class ChainInfo:
97
+ chain_id: str
98
+ asym_id: int
99
+ entity_id: int
100
+ sym_id: int
101
+ mol_type: int
102
+ tokens: list[TokenInfo] = field(default_factory=list)
103
+
104
+
105
+ # =============================================================================
106
+ # Helper functions
107
+ # =============================================================================
108
+
109
+ # Caches for hot-path functions
110
+ _ENCODE_ATOM_NAME_CACHE: dict[str, list[int]] = {}
111
+ _ELEMENT_ATOMIC_NUM_CACHE: dict[str, int] = {}
112
+
113
+
114
+ def encode_atom_name(name: str) -> list[int]:
115
+ """Encode atom name as 4 character indices (offset by 32 from ASCII)."""
116
+ if name in _ENCODE_ATOM_NAME_CACHE:
117
+ return _ENCODE_ATOM_NAME_CACHE[name]
118
+ padded = name.ljust(4)[:4]
119
+ result = [ord(c) - 32 if c != " " else 0 for c in padded]
120
+ _ENCODE_ATOM_NAME_CACHE[name] = result
121
+ return result
122
+
123
+
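As a quick illustration of the encoding in `encode_atom_name`, the following stand-alone sketch reproduces the same arithmetic without the cache. The name `encode_name` is hypothetical. Names are padded or truncated to four characters and each character maps to `ord(c) - 32`, so a space naturally becomes 0.

```python
def encode_name(name: str) -> list[int]:
    """Encode an atom name as four integers (hypothetical, uncached sketch)."""
    padded = name.ljust(4)[:4]  # pad with spaces, truncate to 4 chars
    return [ord(c) - 32 for c in padded]  # ' ' (ASCII 32) maps to 0


print(encode_name("CA"))
print(encode_name("OXT"))
```

One thing to keep in mind with the cached original: it returns the same list object on every hit, so callers should treat the result as read-only.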
124
+ def get_element_atomic_num(element: str) -> int:
125
+ """Get atomic number for an element symbol."""
126
+ if element in _ELEMENT_ATOMIC_NUM_CACHE:
127
+ return _ELEMENT_ATOMIC_NUM_CACHE[element]
128
+ result = ELEMENT_TO_ATOMIC_NUM.get(element.upper(), 0)
129
+ _ELEMENT_ATOMIC_NUM_CACHE[element] = result
130
+ return result
131
+
132
+
133
+ def _infer_element(atom_name: str) -> str:
134
+ """Infer element from atom name."""
135
+ name = atom_name.strip()
136
+ if not name:
137
+ return "C"
138
+ if name[0].isdigit():
139
+ return name[1] if len(name) > 1 else "H"
140
+ if len(name) == 2 and name in (
141
+ "FE",
142
+ "ZN",
143
+ "MG",
144
+ "MN",
145
+ "CO",
146
+ "NI",
147
+ "CU",
148
+ "SE",
149
+ "BR",
150
+ ):
151
+ return name
152
+ return name[0]
153
+
154
+
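The heuristic in `_infer_element` is easiest to understand from a few inputs. The sketch below restates the same rules under a hypothetical name, `infer_element`, so it can be run on its own. Note the limitation it exposes: a two-letter name that is not in the allow-list (for example `CL`) falls through to its first letter.

```python
TWO_LETTER = ("FE", "ZN", "MG", "MN", "CO", "NI", "CU", "SE", "BR")


def infer_element(atom_name: str) -> str:
    """Guess an element symbol from a PDB-style atom name (sketch)."""
    name = atom_name.strip()
    if not name:
        return "C"  # empty names default to carbon
    if name[0].isdigit():
        # Names like "1HB": the element follows the leading digit.
        return name[1] if len(name) > 1 else "H"
    if len(name) == 2 and name in TWO_LETTER:
        return name
    return name[0]


for atom in ("CA", "ZN", "1HB", "OXT", "CL"):
    print(atom, "->", infer_element(atom))
```

Because `CA` is treated as an alpha carbon, this heuristic suits polymer atom names; ligand atoms in the source take their element from the CCD entry or RDKit instead.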
155
+ def _compute_res_type(name: str, mol_type: int) -> int:
156
+ """Compute residue type index from residue name and mol_type."""
157
+ if mol_type == MOL_TYPE_PROTEIN:
158
+ return PROTEIN_RESIDUE_TO_RES_TYPE.get(name, PROTEIN_UNK_RES_TYPE)
159
+ elif mol_type == MOL_TYPE_DNA:
160
+ if name in DNA_RESIDUE_TO_RES_TYPE:
161
+ return DNA_RESIDUE_TO_RES_TYPE[name]
162
+ if name in RNA_RESIDUE_TO_RES_TYPE:
163
+ return RNA_RESIDUE_TO_RES_TYPE[name]
164
+ return DNA_UNK_RES_TYPE
165
+ elif mol_type == MOL_TYPE_RNA:
166
+ if name in RNA_RESIDUE_TO_RES_TYPE:
167
+ return RNA_RESIDUE_TO_RES_TYPE[name]
168
+ if name in DNA_RESIDUE_TO_RES_TYPE:
169
+ return DNA_RESIDUE_TO_RES_TYPE[name]
170
+ return RNA_UNK_RES_TYPE
171
+ return PROTEIN_UNK_RES_TYPE
172
+
173
+
174
+ def _compute_esm_input_id(name: str, mol_type: int) -> int:
175
+ """Compute ESM vocabulary input ID."""
176
+ if mol_type == MOL_TYPE_PROTEIN:
177
+ letter = PROTEIN_3TO1.get(name)
178
+ if letter is None:
179
+ return DNA_RNA_LIGAND_INPUT_ID
180
+ return ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"])
181
+ return DNA_RNA_LIGAND_INPUT_ID
182
+
183
+
184
+ # =============================================================================
185
+ # Tokenization functions — build tokens and atoms from sequences
186
+ # =============================================================================
187
+
188
+
189
+ def tokenize_protein(
190
+ sequence: str,
191
+ modifications: list[Modification] | None,
192
+ entity_id: int,
193
+ asym_id: int,
194
+ sym_id: int,
195
+ token_offset: int,
196
+ atom_offset: int,
197
+ space_uid_offset: int,
198
+ ) -> tuple[list[TokenInfo], list[AtomInfo]]:
199
+ """Tokenize a protein sequence into tokens and atoms.
200
+
201
+ Standard residues produce 1 token with all heavy atoms.
202
+ Modified residues (from modifications) are atom-tokenized (1 token per atom).
203
+ """
204
+ tokens: list[TokenInfo] = []
205
+ atoms: list[AtomInfo] = []
206
+
207
+ # Build 3-letter sequence, applying modifications
208
+ seq_3letter = [PROTEIN_1TO3.get(c, "UNK") for c in sequence]
209
+ modified_positions: set[int] = set()
210
+ if modifications:
211
+ for mod in modifications:
212
+ seq_3letter[mod.position] = mod.ccd
213
+ modified_positions.add(mod.position)
214
+
215
+ token_idx = token_offset
216
+ atom_idx = atom_offset
217
+ space_uid = space_uid_offset
218
+
219
+ for res_idx, res_name in enumerate(seq_3letter):
220
+ # MSE → MET for atom lookup
221
+ res_corrected = "MET" if res_name == "MSE" else res_name
222
+ is_modified = res_idx in modified_positions
223
+
224
+ # Check if standard residue (has predefined atom list)
225
+ if not is_modified and res_corrected in PROTEIN_HEAVY_ATOMS:
226
+ # Standard residue: 1 token, multiple atoms
227
+ atom_names = PROTEIN_HEAVY_ATOMS[res_corrected]
228
+ res_type = _compute_res_type(res_corrected, MOL_TYPE_PROTEIN)
229
+ input_id = _compute_esm_input_id(res_corrected, MOL_TYPE_PROTEIN)
230
+
231
+ atom_start = atom_idx
232
+ for a_name in atom_names:
233
+ ref_pos = get_idealized_atom_pos(res_type, a_name)
234
+ atoms.append(
235
+ AtomInfo(
236
+ name=a_name,
237
+ element=_infer_element(a_name),
238
+ charge=CHARGED_ATOMS.get((res_corrected, a_name), 0),
239
+ ref_pos=ref_pos.copy()
240
+ if ref_pos is not None
241
+ else _ZERO_POS.copy(),
242
+ pos=_ZERO_POS.copy(),
243
+ token_index=token_idx,
244
+ atom_index=atom_idx,
245
+ space_uid=space_uid,
246
+ )
247
+ )
248
+ atom_idx += 1
249
+
250
+ tokens.append(
251
+ TokenInfo(
252
+ token_index=token_idx,
253
+ residue_index=res_idx,
254
+ residue_name=res_corrected,
255
+ mol_type=MOL_TYPE_PROTEIN,
256
+ res_type=res_type,
257
+ input_id=input_id,
258
+ asym_id=asym_id,
259
+ sym_id=sym_id,
260
+ entity_id=entity_id,
261
+ atom_start=atom_start,
262
+ atom_count=len(atom_names),
263
+ )
264
+ )
265
+ token_idx += 1
266
+ space_uid += 1
267
+
268
+ else:
269
+ # Modified or unknown residue: atom-tokenized
270
+ ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name)
271
+ if ccd_atoms is None:
272
+ # Fallback: backbone only
273
+ ccd_atoms = [
274
+ (n, _infer_element(n), 0)
275
+ for n in ["N", "CA", "C", "O"]
276
+ ]
277
+
278
+ # Filter leaving atoms if not terminal
279
+ is_terminal = res_idx == len(seq_3letter) - 1
280
+ leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name)
281
+ kept_atoms = [a for a in ccd_atoms if a[0] not in leaving_atoms]
282
+ # Single-atom residues (e.g. NH2 cap): the local frame is
283
+ # ill-defined with one atom; place at origin.
284
+ single_atom_residue = len(kept_atoms) == 1
285
+
286
+ for a_name, a_element, a_charge in kept_atoms:
287
+ ref_pos = get_ligand_idealized_atom_pos(res_name, a_name)
288
+ atoms.append(
289
+ AtomInfo(
290
+ name=a_name,
291
+ element=a_element,
292
+ charge=a_charge,
293
+ ref_pos=_ZERO_POS.copy()
294
+ if single_atom_residue
295
+ else (
296
+ ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy()
297
+ ),
298
+ pos=_ZERO_POS.copy(),
299
+ token_index=token_idx,
300
+ atom_index=atom_idx,
301
+ space_uid=space_uid,
302
+ )
303
+ )
304
+ tokens.append(
305
+ TokenInfo(
306
+ token_index=token_idx,
307
+ residue_index=res_idx,
308
+ residue_name=res_name,
309
+ mol_type=MOL_TYPE_PROTEIN,
310
+ res_type=PROTEIN_UNK_RES_TYPE,
311
+ input_id=DNA_RNA_LIGAND_INPUT_ID,
312
+ asym_id=asym_id,
313
+ sym_id=sym_id,
314
+ entity_id=entity_id,
315
+ atom_start=atom_idx,
316
+ atom_count=1,
317
+ )
318
+ )
319
+ token_idx += 1
320
+ atom_idx += 1
321
+
322
+ space_uid += 1
323
+
324
+ return tokens, atoms
325
+
326
+
327
+ def tokenize_nucleotide(
328
+ sequence: str,
329
+ modifications: list[Modification] | None,
330
+ mol_type: int,
331
+ entity_id: int,
332
+ asym_id: int,
333
+ sym_id: int,
334
+ token_offset: int,
335
+ atom_offset: int,
336
+ space_uid_offset: int,
337
+ ) -> tuple[list[TokenInfo], list[AtomInfo]]:
338
+ """Tokenize a DNA or RNA sequence into tokens and atoms."""
339
+ tokens: list[TokenInfo] = []
340
+ atoms: list[AtomInfo] = []
341
+
342
+ letter_to_3 = DNA_1TO3 if mol_type == MOL_TYPE_DNA else RNA_1TO3
343
+ heavy_atoms = DNA_HEAVY_ATOMS if mol_type == MOL_TYPE_DNA else RNA_HEAVY_ATOMS
344
+ backbone_atoms = (
345
+ DNA_BACKBONE_ATOMS if mol_type == MOL_TYPE_DNA else RNA_BACKBONE_ATOMS
346
+ )
347
+ unk_res_type = DNA_UNK_RES_TYPE if mol_type == MOL_TYPE_DNA else RNA_UNK_RES_TYPE
348
+
349
+ seq_3letter = [letter_to_3.get(c, "UNK") for c in sequence]
350
+ modified_positions: set[int] = set()
351
+ if modifications:
352
+ for mod in modifications:
353
+ seq_3letter[mod.position] = mod.ccd
354
+ modified_positions.add(mod.position)
355
+
356
+ token_idx = token_offset
357
+ atom_idx = atom_offset
358
+ space_uid = space_uid_offset
359
+
360
+ for res_idx, res_name in enumerate(seq_3letter):
361
+ is_modified = res_idx in modified_positions
362
+
363
+ if not is_modified and res_name in heavy_atoms:
364
+ # Standard nucleotide
365
+ atom_names = heavy_atoms[res_name]
366
+ res_type = _compute_res_type(res_name, mol_type)
367
+ input_id = DNA_RNA_LIGAND_INPUT_ID
368
+
369
+ atom_start = atom_idx
370
+ for a_name in atom_names:
371
+ ref_pos = get_idealized_atom_pos(res_type, a_name)
372
+ atoms.append(
373
+ AtomInfo(
374
+ name=a_name,
375
+ element=_infer_element(a_name),
376
+ charge=CHARGED_ATOMS.get((res_name, a_name), 0),
377
+ ref_pos=ref_pos.copy()
378
+ if ref_pos is not None
379
+ else _ZERO_POS.copy(),
380
+ pos=_ZERO_POS.copy(),
381
+ token_index=token_idx,
382
+ atom_index=atom_idx,
383
+ space_uid=space_uid,
384
+ )
385
+ )
386
+ atom_idx += 1
387
+
388
+ tokens.append(
389
+ TokenInfo(
390
+ token_index=token_idx,
391
+ residue_index=res_idx,
392
+ residue_name=res_name,
393
+ mol_type=mol_type,
394
+ res_type=res_type,
395
+ input_id=input_id,
396
+ asym_id=asym_id,
397
+ sym_id=sym_id,
398
+ entity_id=entity_id,
399
+ atom_start=atom_start,
400
+ atom_count=len(atom_names),
401
+ )
402
+ )
403
+ token_idx += 1
404
+ space_uid += 1
405
+
406
+ elif not is_modified and res_name == "UNK":
407
+ # Unknown nucleotide: backbone only
408
+ atom_names = backbone_atoms
409
+ atom_start = atom_idx
410
+ for a_name in atom_names:
411
+ ref_pos = None # No idealized positions for UNK
412
+ atoms.append(
413
+ AtomInfo(
414
+ name=a_name,
415
+ element=_infer_element(a_name),
416
+ charge=0,
417
+ ref_pos=_ZERO_POS.copy(),
418
+ pos=_ZERO_POS.copy(),
419
+ token_index=token_idx,
420
+ atom_index=atom_idx,
421
+ space_uid=space_uid,
422
+ )
423
+ )
424
+ atom_idx += 1
425
+
426
+ tokens.append(
427
+ TokenInfo(
428
+ token_index=token_idx,
429
+ residue_index=res_idx,
430
+ residue_name=res_name,
431
+ mol_type=mol_type,
432
+ res_type=unk_res_type,
433
+ input_id=DNA_RNA_LIGAND_INPUT_ID,
434
+ asym_id=asym_id,
435
+ sym_id=sym_id,
436
+ entity_id=entity_id,
437
+ atom_start=atom_start,
438
+ atom_count=len(atom_names),
439
+ )
440
+ )
441
+ token_idx += 1
442
+ space_uid += 1
443
+
444
+ else:
445
+ # Modified nucleotide: atom-tokenized
446
+ ccd_atoms = get_ligand_ccd_atoms_with_charges(res_name)
447
+ if ccd_atoms is None:
448
+ ccd_atoms = [
449
+ (n, _infer_element(n), 0) for n in backbone_atoms
450
+ ]
451
+
452
+ is_terminal = res_idx == len(seq_3letter) - 1
453
+ leaving_atoms = set() if is_terminal else get_ccd_leaving_atoms(res_name)
454
+
455
+ for a_name, a_element, a_charge in ccd_atoms:
456
+ if a_name in leaving_atoms:
457
+ continue
458
+ ref_pos = get_ligand_idealized_atom_pos(res_name, a_name)
459
+ atoms.append(
460
+ AtomInfo(
461
+ name=a_name,
462
+ element=a_element,
463
+ charge=a_charge,
464
+ ref_pos=ref_pos.copy()
465
+ if ref_pos is not None
466
+ else _ZERO_POS.copy(),
467
+ pos=_ZERO_POS.copy(),
468
+ token_index=token_idx,
469
+ atom_index=atom_idx,
470
+ space_uid=space_uid,
471
+ )
472
+ )
473
+ tokens.append(
474
+ TokenInfo(
475
+ token_index=token_idx,
476
+ residue_index=res_idx,
477
+ residue_name=res_name,
478
+ mol_type=mol_type,
479
+ res_type=PROTEIN_UNK_RES_TYPE,
480
+ input_id=DNA_RNA_LIGAND_INPUT_ID,
481
+ asym_id=asym_id,
482
+ sym_id=sym_id,
483
+ entity_id=entity_id,
484
+ atom_start=atom_idx,
485
+ atom_count=1,
486
+ )
487
+ )
488
+ token_idx += 1
489
+ atom_idx += 1
490
+
491
+ space_uid += 1
492
+
493
+ return tokens, atoms
494
+
495
+
496
+ def tokenize_ligand_ccd(
497
+ ccd_codes: list[str],
498
+ entity_id: int,
499
+ asym_id: int,
500
+ sym_id: int,
501
+ token_offset: int,
502
+ atom_offset: int,
503
+ space_uid_offset: int,
504
+ has_covalent_bond: bool,
505
+ ) -> tuple[list[TokenInfo], list[AtomInfo]]:
506
+ """Tokenize a ligand from CCD codes (1 token per atom)."""
507
+ tokens: list[TokenInfo] = []
508
+ atoms: list[AtomInfo] = []
509
+
510
+ token_idx = token_offset
511
+ atom_idx = atom_offset
512
+ space_uid = space_uid_offset
513
+
514
+ for res_idx, code in enumerate(ccd_codes):
515
+ ccd_atoms = get_ligand_ccd_atoms_with_charges(code)
516
+ if ccd_atoms is None:
517
+ raise ValueError(f"CCD component {code} not found")
518
+
519
+ leaving_atoms = get_ccd_leaving_atoms(code) if has_covalent_bond else set()
520
+
521
+ for a_name, a_element, a_charge in ccd_atoms:
522
+ if a_name in leaving_atoms:
523
+ continue
524
+ ref_pos = get_ligand_idealized_atom_pos(code, a_name)
525
+ atoms.append(
526
+ AtomInfo(
527
+ name=a_name,
528
+ element=a_element,
529
+ charge=a_charge,
530
+ ref_pos=ref_pos.copy() if ref_pos is not None else _ZERO_POS.copy(),
531
+ pos=_ZERO_POS.copy(),
532
+ token_index=token_idx,
533
+ atom_index=atom_idx,
534
+ space_uid=space_uid,
535
+ )
536
+ )
537
+ tokens.append(
538
+ TokenInfo(
539
+ token_index=token_idx,
540
+ residue_index=res_idx,
541
+ residue_name=code,
542
+ mol_type=MOL_TYPE_NONPOLYMER,
543
+ res_type=PROTEIN_UNK_RES_TYPE,
544
+ input_id=DNA_RNA_LIGAND_INPUT_ID,
545
+ asym_id=asym_id,
546
+ sym_id=sym_id,
547
+ entity_id=entity_id,
548
+ atom_start=atom_idx,
549
+ atom_count=1,
550
+ )
551
+ )
552
+ token_idx += 1
553
+ atom_idx += 1
554
+
555
+ space_uid += 1
556
+
557
+ return tokens, atoms
558
+
559
+
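The atom-level tokenization used for ligands and modified residues follows a simple bookkeeping pattern: one token per kept atom, with token and atom indices advancing together. The sketch below illustrates only that pattern. `atom_tokenize` and the dict layout are hypothetical simplifications of `TokenInfo` and `AtomInfo`, and the atom names are example data rather than a real CCD lookup.

```python
def atom_tokenize(
    atom_names: list[str],
    leaving: set[str],
    token_offset: int = 0,
    atom_offset: int = 0,
) -> list[dict]:
    """Emit one token per atom, skipping leaving atoms (hypothetical sketch)."""
    tokens = []
    for name in atom_names:
        if name in leaving:
            continue  # dropped, e.g. when a covalent bond replaces the atom
        k = len(tokens)
        tokens.append(
            {
                "name": name,
                "token_index": token_offset + k,
                "atom_start": atom_offset + k,
                "atom_count": 1,  # atom-tokenized: exactly one atom per token
            }
        )
    return tokens


tokens = atom_tokenize(["N", "CA", "C", "O", "OXT"], {"OXT"}, 10, 100)
print([(t["name"], t["token_index"], t["atom_start"]) for t in tokens])
```

In the source, every atom of one residue or CCD component also shares a single `space_uid`, which is incremented once per component rather than once per atom.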
560
+ def tokenize_ligand_smiles(
561
+ smiles: str,
562
+ entity_id: int,
563
+ asym_id: int,
564
+ sym_id: int,
565
+ token_offset: int,
566
+ atom_offset: int,
567
+ space_uid_offset: int,
568
+ seed: int | None = None,
569
+ ) -> tuple[list[TokenInfo], list[AtomInfo]]:
570
+ """Tokenize a ligand from SMILES (1 token per heavy atom)."""
571
+ from rdkit import Chem
572
+ from rdkit.Chem import AllChem
573
+
574
+ mol = Chem.MolFromSmiles(smiles)
575
+ if mol is None:
576
+ raise ValueError(f"Failed to parse SMILES: {smiles}")
577
+ mol = Chem.AddHs(mol)
578
+
579
+ # Assign atom names using canonical ranking
580
+ canonical_order = AllChem.CanonicalRankAtoms(mol) # type: ignore[attr-defined]
581
+ for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
582
+ atom_name = atom.GetSymbol().upper() + str(can_idx + 1)
583
+ if len(atom_name) > 4:
584
+ raise ValueError(
585
+ f"SMILES {smiles} has atom name longer than 4 chars: {atom_name}"
586
+ )
587
+ atom.SetProp("name", atom_name)
588
+
589
+ # Generate 3D conformer
590
+ options = AllChem.ETKDGv3() # type: ignore[attr-defined]
591
+ options.clearConfs = False
592
+ if seed is not None:
593
+ options.randomSeed = seed
594
+ conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined]
595
+ if conf_id == -1:
596
+ options.useRandomCoords = True
597
+ conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined]
598
+ if conf_id != -1:
599
+ try:
600
+ AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000) # type: ignore[attr-defined]
601
+ except (RuntimeError, ValueError):
602
+ pass
603
+
604
+ # Remove hydrogens
605
+ mol_no_h = Chem.RemoveHs(mol)
606
+ if mol_no_h.GetNumConformers() == 0:
607
+ raise ValueError(f"Failed to generate conformer for SMILES: {smiles}")
608
+
609
+ conformer = mol_no_h.GetConformer(0)
610
+
611
+ tokens: list[TokenInfo] = []
612
+ atoms_list: list[AtomInfo] = []
613
+ token_idx = token_offset
614
+ atom_idx = atom_offset
615
+ space_uid = space_uid_offset
616
+
617
+ for atom in mol_no_h.GetAtoms():
618
+ a_name = atom.GetProp("name")
619
+ a_element = atom.GetSymbol()
620
+ a_charge = atom.GetFormalCharge()
621
+ pos_3d = conformer.GetAtomPosition(atom.GetIdx())
622
+ ref_pos = np.array([pos_3d.x, pos_3d.y, pos_3d.z], dtype=np.float32)
623
+
624
+ atoms_list.append(
625
+ AtomInfo(
626
+ name=a_name,
627
+ element=a_element,
628
+ charge=a_charge,
629
+ ref_pos=ref_pos,
630
+ pos=_ZERO_POS.copy(),
631
+ token_index=token_idx,
632
+ atom_index=atom_idx,
633
+ space_uid=space_uid,
634
+ )
635
+ )
636
+ tokens.append(
637
+ TokenInfo(
638
+ token_index=token_idx,
639
+ residue_index=0,
640
+ residue_name="LIG",
641
+ mol_type=MOL_TYPE_NONPOLYMER,
642
+ res_type=PROTEIN_UNK_RES_TYPE,
643
+ input_id=DNA_RNA_LIGAND_INPUT_ID,
644
+ asym_id=asym_id,
645
+ sym_id=sym_id,
646
+ entity_id=entity_id,
647
+ atom_start=atom_idx,
648
+ atom_count=1,
649
+ )
650
+ )
651
+ token_idx += 1
652
+ atom_idx += 1
653
+
654
+ return tokens, atoms_list
655
+
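The atom-naming rule in `tokenize_ligand_smiles` (upper-cased element symbol plus the 1-based canonical rank, capped at four characters) can be sketched without RDKit. `make_atom_name` is a hypothetical helper and the ranks below are made-up values, not the output of a real canonical ranking.

```python
def make_atom_name(symbol: str, rank: int) -> str:
    """Build a name such as 'C1' or 'CL12' from an element symbol and a 0-based rank."""
    name = symbol.upper() + str(rank + 1)
    if len(name) > 4:
        # Same limit as the source: names must fit in four characters.
        raise ValueError(f"atom name longer than 4 chars: {name}")
    return name


names = [make_atom_name(s, r) for s, r in [("C", 0), ("Cl", 11), ("N", 998)]]

# A two-letter element leaves room for two digits only, so rank 99 overflows.
try:
    make_atom_name("Cl", 99)
    overflowed = False
except ValueError:
    overflowed = True
```

One consequence of this rule is that large ligands containing two-letter elements can fail at the naming step before any conformer is generated.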
656
+
657
+ # =============================================================================
658
+ # Build chains from StructurePredictionInput
659
+ # =============================================================================
660
+
661
+
662
+ def _get_sequence_key(item) -> str:
663
+ """Get a hashable key for entity deduplication."""
664
+ if isinstance(item, ProteinInput):
665
+ return f"PROTEIN:{item.sequence}"
666
+ elif isinstance(item, DNAInput):
667
+ return f"DNA:{item.sequence}"
668
+ elif isinstance(item, RNAInput):
669
+ return f"RNA:{item.sequence}"
670
+ elif isinstance(item, LigandInput):
671
+ if item.ccd:
672
+ return f"LIGAND_CCD:{','.join(item.ccd)}"
673
+ return f"LIGAND_SMILES:{item.smiles}"
674
+ raise ValueError(f"Unknown input type: {type(item)}")
675
+
676
+
677
+ def build_chains_from_input(
678
+ input: StructurePredictionInput, seed: int | None = None
679
+ ) -> tuple[list[ChainInfo], list[TokenInfo], list[AtomInfo]]:
680
+ """Build chains, tokens, and atoms from StructurePredictionInput.
681
+
682
+ Handles entity deduplication (identical sequences get same entity_id),
683
+ sym_id assignment, and delegates to type-specific tokenization functions.
684
+ """
685
+ chains: list[ChainInfo] = []
686
+ all_tokens: list[TokenInfo] = []
687
+ all_atoms: list[AtomInfo] = []
688
+
689
+ # Entity deduplication
690
+ sequence_to_entity: dict[str, int] = {}
691
+ entity_sym_count: dict[int, int] = {}
692
+ next_entity_id = 0
693
+
694
+ # Gather chain IDs involved in covalent bonds
695
+ covalent_chain_ids: set[str] = set()
696
+ if input.covalent_bonds:
697
+ for cb in input.covalent_bonds:
698
+ covalent_chain_ids.update([cb.chain_id1, cb.chain_id2])
699
+
700
+ token_offset = 0
701
+ atom_offset = 0
702
+ space_uid_offset = 0
703
+ asym_id = 0
704
+
705
+ for item in input.sequences:
706
+ # Entity deduplication
707
+ seq_key = _get_sequence_key(item)
708
+ if seq_key in sequence_to_entity:
709
+ entity_id = sequence_to_entity[seq_key]
710
+ else:
711
+ entity_id = next_entity_id
712
+ sequence_to_entity[seq_key] = entity_id
713
+ next_entity_id += 1
714
+
715
+ # Get all chain IDs for this item
716
+ ids = [item.id] if isinstance(item.id, str) else item.id
717
+
718
+ for chain_id_str in ids:
719
+ # sym_id is the per-entity copy index; increment per chain so
720
+ # ProteinInput(id=['A','B']) gives chain A sym_id=0, chain B sym_id=1.
721
+ sym_id = entity_sym_count.get(entity_id, 0)
722
+ entity_sym_count[entity_id] = sym_id + 1
723
+ if isinstance(item, ProteinInput):
724
+ if item.msa is None:
725
+ warnings.warn(
726
+ f"No MSA provided for {item.id}, using single sequence mode"
727
+ )
728
+
729
+ new_tokens, new_atoms = tokenize_protein(
730
+ sequence=item.sequence,
731
+ modifications=item.modifications,
732
+ entity_id=entity_id,
733
+ asym_id=asym_id,
734
+ sym_id=sym_id,
735
+ token_offset=token_offset,
736
+ atom_offset=atom_offset,
737
+ space_uid_offset=space_uid_offset,
738
+ )
739
+
740
+ elif isinstance(item, (DNAInput, RNAInput)):
741
+ mol_type = MOL_TYPE_DNA if isinstance(item, DNAInput) else MOL_TYPE_RNA
742
+ new_tokens, new_atoms = tokenize_nucleotide(
743
+ sequence=item.sequence,
744
+ modifications=item.modifications,
745
+ mol_type=mol_type,
746
+ entity_id=entity_id,
747
+ asym_id=asym_id,
748
+ sym_id=sym_id,
749
+ token_offset=token_offset,
750
+ atom_offset=atom_offset,
751
+ space_uid_offset=space_uid_offset,
752
+ )
753
+
754
+ elif isinstance(item, LigandInput):
755
+ has_cov = chain_id_str in covalent_chain_ids
756
+ if item.ccd is not None:
757
+ if item.smiles is not None:
758
+ warnings.warn("Both ccd and smiles provided, using ccd")
759
+ new_tokens, new_atoms = tokenize_ligand_ccd(
760
+ ccd_codes=item.ccd,
761
+ entity_id=entity_id,
762
+ asym_id=asym_id,
763
+ sym_id=sym_id,
764
+ token_offset=token_offset,
765
+ atom_offset=atom_offset,
766
+ space_uid_offset=space_uid_offset,
767
+ has_covalent_bond=has_cov,
768
+ )
769
+ elif item.smiles is not None:
770
+ new_tokens, new_atoms = tokenize_ligand_smiles(
771
+ smiles=item.smiles,
772
+ entity_id=entity_id,
773
+ asym_id=asym_id,
774
+ sym_id=sym_id,
775
+ token_offset=token_offset,
776
+ atom_offset=atom_offset,
777
+ space_uid_offset=space_uid_offset,
778
+ seed=seed,
779
+ )
780
+ else:
781
+ raise ValueError("LigandInput must have either ccd or smiles")
782
+ else:
783
+ raise ValueError(f"Unknown input type: {type(item)}")
784
+
785
+ chain = ChainInfo(
786
+ chain_id=chain_id_str,
787
+ asym_id=asym_id,
788
+ entity_id=entity_id,
789
+ sym_id=sym_id,
790
+ mol_type=new_tokens[0].mol_type if new_tokens else MOL_TYPE_PROTEIN,
791
+ tokens=new_tokens,
792
+ )
793
+ chains.append(chain)
794
+ all_tokens.extend(new_tokens)
795
+ all_atoms.extend(new_atoms)
796
+
797
+ token_offset += len(new_tokens)
798
+ atom_offset += len(new_atoms)
799
+ space_uid_offset += len(set(a.space_uid for a in new_atoms))
800
+ asym_id += 1
801
+
802
+ return chains, all_tokens, all_atoms
803
+
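The bookkeeping in `build_chains_from_input` (one `entity_id` per distinct sequence key, one `sym_id` per copy of that entity, one `asym_id` per chain) can be sketched as follows. `assign_ids` and the `(key, chain_ids)` input format are assumptions made for this example; the sequence keys only imitate the strings produced by `_get_sequence_key`.

```python
def assign_ids(items):
    """items: list of (sequence_key, chain_ids).

    Returns rows of (chain_id, asym_id, entity_id, sym_id).
    """
    key_to_entity = {}   # sequence key -> entity_id
    entity_copies = {}   # entity_id -> number of chains seen so far
    rows = []
    asym_id = 0
    for key, chain_ids in items:
        # Identical keys share an entity; a new key gets the next free id.
        entity_id = key_to_entity.setdefault(key, len(key_to_entity))
        for chain_id in chain_ids:
            sym_id = entity_copies.get(entity_id, 0)
            entity_copies[entity_id] = sym_id + 1
            rows.append((chain_id, asym_id, entity_id, sym_id))
            asym_id += 1
    return rows


rows = assign_ids([
    ("PROTEIN:MKV", ["A", "B"]),   # homodimer declared in one item
    ("LIGAND_CCD:ATP", ["C"]),
    ("PROTEIN:MKV", ["D"]),        # same sequence declared again later
])
```

Chain `D` reuses entity 0 and receives `sym_id` 2, because the copy counter is kept per entity rather than per input item.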
804
+
805
+ # =============================================================================
806
+ # Feature tensor building
807
+ # =============================================================================
808
+
809
+
810
+ def compute_frame_indices(
811
+ tokens: list[TokenInfo], atoms: list[AtomInfo]
812
+ ) -> tuple[np.ndarray, np.ndarray]:
813
+ """Compute backbone frame indices for each token.
814
+
815
+ Protein: [N, CA, C]; DNA/RNA: [C1', C3', C4']; Ligand: distance-based.
816
+ """
817
+ # Build atom name -> atom_index lookup per token
818
+ token_atoms: dict[int, dict[str, int]] = defaultdict(dict)
819
+ for atom in atoms:
820
+ if atom.is_valid:
821
+ token_atoms[atom.token_index][atom.name] = atom.atom_index
822
+
823
+ # Ligand-token frames come from reference-conformer (ref_pos) geometry,
824
+ # grouped per residue. For each token, the frame is the 3 atoms nearest
825
+ # to its own atom in the residue's ref-pos space, ordered
826
+ # (1st-nearest, self, 2nd-nearest).
827
+ ligand_token_to_atom: dict[int, int] = {}
828
+ ligand_tokens_by_res: dict[tuple[int, int], list[int]] = defaultdict(list)
829
+ for t in tokens:
830
+ if t.mol_type == MOL_TYPE_NONPOLYMER:
831
+ ad = token_atoms.get(t.token_index)
832
+ if ad:
833
+ ligand_token_to_atom[t.token_index] = next(iter(ad.values()))
834
+ ligand_tokens_by_res[(t.asym_id, t.residue_index)].append(t.token_index)
835
+
836
+ ligand_token_frames: dict[int, tuple[int, int, int]] = {}
837
+ for tok_indices in ligand_tokens_by_res.values():
838
+ atom_indices = [
839
+ ligand_token_to_atom[ti] for ti in tok_indices if ti in ligand_token_to_atom
840
+ ]
841
+ if len(atom_indices) < 3:
842
+ for ti in tok_indices:
843
+ if ti in ligand_token_to_atom:
844
+ ai = ligand_token_to_atom[ti]
845
+ ligand_token_frames[ti] = (ai, ai, ai)
846
+ continue
847
+
848
+ ref_pos_chain = np.array([atoms[ai].ref_pos for ai in atom_indices])
849
+ dist_mat = np.sqrt(
850
+ ((ref_pos_chain[:, None] - ref_pos_chain[None]) ** 2).sum(-1)
851
+ )
852
+ sort_indices = np.argsort(dist_mat, axis=1)
853
+ local_frames = np.column_stack(
854
+ [sort_indices[:, 1], sort_indices[:, 0], sort_indices[:, 2]]
855
+ )
856
+
857
+ for ti in tok_indices:
858
+ if ti not in ligand_token_to_atom:
859
+ continue
860
+ ai = ligand_token_to_atom[ti]
861
+ local_idx = atom_indices.index(ai)
862
+ fl = local_frames[local_idx]
863
+ ligand_token_frames[ti] = (
864
+ atom_indices[fl[0]],
865
+ atom_indices[fl[1]],
866
+ atom_indices[fl[2]],
867
+ )
868
+
869
+ # Build frames for all tokens
870
+ frames_list: list[tuple[int, int, int]] = []
871
+ for t in tokens:
872
+ ad = token_atoms.get(t.token_index, {})
873
+ fallback = list(ad.values())[0] if ad else 0
874
+
875
+ if t.mol_type == MOL_TYPE_PROTEIN:
876
+ if t.res_type == PROTEIN_UNK_RES_TYPE:
877
+ frames_list.append((fallback, fallback, fallback))
878
+ else:
879
+ frames_list.append((ad.get("N", 0), ad.get("CA", 0), ad.get("C", 0)))
880
+ elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA):
881
+ if t.res_type == PROTEIN_UNK_RES_TYPE:
882
+ frames_list.append((fallback, fallback, fallback))
883
+ else:
884
+ frames_list.append(
885
+ (ad.get("C1'", 0), ad.get("C3'", 0), ad.get("C4'", 0))
886
+ )
887
+ elif t.mol_type == MOL_TYPE_NONPOLYMER:
888
+ if t.token_index in ligand_token_frames:
889
+ frames_list.append(ligand_token_frames[t.token_index])
890
+ else:
891
+ frames_list.append((fallback, fallback, fallback))
892
+ else:
893
+ frames_list.append((fallback, fallback, fallback))
894
+
895
+ frames = np.array(frames_list, dtype=np.int64)
896
+
897
+ # Compute resolved mask (vectorized)
898
+ n_atoms = len(atoms)
899
+ atom_positions = (
900
+ np.array([a.pos for a in atoms], dtype=np.float32)
901
+ if atoms
902
+ else np.zeros((0, 3), dtype=np.float32)
903
+ )
904
+ atom_is_valid = (
905
+ np.array([a.is_valid for a in atoms], dtype=bool)
906
+ if atoms
907
+ else np.zeros(0, dtype=bool)
908
+ )
909
+ atom_is_resolved = (
910
+ atom_is_valid & np.any(atom_positions != 0, axis=1)
911
+ if n_atoms > 0
912
+ else np.zeros(0, dtype=bool)
913
+ )
914
+
915
+ n_tokens = len(tokens)
916
+ if n_tokens == 0:
917
+ return frames, np.zeros(0, dtype=bool)
918
+
919
+ pos1 = atom_positions[frames[:, 0]]
920
+ pos2 = atom_positions[frames[:, 1]]
921
+ pos3 = atom_positions[frames[:, 2]]
922
+
923
+ all_resolved = (
924
+ atom_is_resolved[frames[:, 0]]
925
+ & atom_is_resolved[frames[:, 1]]
926
+ & atom_is_resolved[frames[:, 2]]
927
+ )
928
+ all_same = (frames[:, 0] == frames[:, 1]) & (frames[:, 1] == frames[:, 2])
929
+
930
+ v1 = pos1 - pos2
931
+ v2 = pos3 - pos2
932
+ norm1 = np.linalg.norm(v1, axis=1)
933
+ norm2 = np.linalg.norm(v2, axis=1)
934
+ valid_norms = (norm1 >= 1e-6) & (norm2 >= 1e-6)
935
+
936
+ cos_angle = np.zeros(n_tokens, dtype=np.float32)
937
+ mask = valid_norms
938
+ if np.any(mask):
939
+ cos_angle[mask] = np.sum(v1[mask] * v2[mask], axis=1) / (
940
+ norm1[mask] * norm2[mask]
941
+ )
942
+ cos_angle = np.clip(cos_angle, -1, 1)
943
+ angle_deg = np.degrees(np.arccos(np.abs(cos_angle)))
944
+ not_colinear = angle_deg >= 25
945
+
946
+ resolved_mask = all_resolved & ~all_same & valid_norms & not_colinear
947
+ return frames, resolved_mask
948
+
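Two geometric ideas from `compute_frame_indices` are sketched below in plain Python: choosing a ligand frame as (nearest neighbour, self, second-nearest neighbour), and rejecting frames whose three atoms are nearly colinear. Both function names and the toy coordinates are assumptions; the source does the same work with vectorized NumPy on `ref_pos` and `pos`.

```python
import math


def nearest_neighbour_frames(points):
    """For each point i return (nearest, i, second_nearest) as indices.

    Assumes at least three distinct points, as the source does before it
    falls back to a degenerate (i, i, i) frame.
    """
    frames = []
    for i, p in enumerate(points):
        order = sorted(range(len(points)), key=lambda j: math.dist(p, points[j]))
        # order[0] is i itself (distance 0), so the neighbours follow it.
        frames.append((order[1], i, order[2]))
    return frames


def frame_is_usable(a, b, c, min_deg=25.0):
    """True if vectors b->a and b->c are non-degenerate and not nearly colinear."""
    v1 = [x - y for x, y in zip(a, b)]
    v2 = [x - y for x, y in zip(c, b)]
    n1, n2 = math.hypot(*v1), math.hypot(*v2)
    if n1 < 1e-6 or n2 < 1e-6:
        return False
    cos = sum(x * y for x, y in zip(v1, v2)) / (n1 * n2)
    cos = max(-1.0, min(1.0, cos))
    # abs() folds angles near 180 degrees onto angles near 0 degrees.
    return math.degrees(math.acos(abs(cos))) >= min_deg


pts = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0), (5.0, 5.0, 0.0)]
frames = nearest_neighbour_frames(pts)
right_angle_ok = frame_is_usable(pts[1], pts[0], pts[2])
straight_line_ok = frame_is_usable((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (-2.0, 0.0, 0.0))
```

Taking the absolute value of the cosine means a frame is rejected both when the two arms point the same way and when they point in opposite directions.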
949
+
950
+ def compute_token_bonds(
951
+ tokens: list[TokenInfo],
952
+ atoms: list[AtomInfo],
953
+ input: StructurePredictionInput,
954
+ chains: list[ChainInfo],
955
+ ) -> torch.Tensor:
956
+ """Compute dense token bond matrix [L, L, 1].
957
+
958
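+ Includes intra-residue bonds for ligands and atom-tokenized residues (from CCD), covalent bonds from the input, and peptide bonds at the boundaries of atom-tokenized residues.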
959
+ """
960
+ n_tokens = len(tokens)
961
+ edge_set: set[tuple[int, int]] = set()
962
+
963
+ def add_bond(i: int, j: int) -> None:
964
+ if i != j:
965
+ edge_set.add((min(i, j), max(i, j)))
966
+
967
+ # Build per-residue atom name -> token_index mapping for ligands and modified residues
968
+ # Key: (asym_id, residue_index, atom_name) -> token_index
969
+ atom_name_to_token: dict[tuple[int, int, str], int] = {}
970
+ for atom in atoms:
971
+ if atom.is_valid:
972
+ t = tokens[atom.token_index] if atom.token_index < len(tokens) else None
973
+ if t and (
974
+ t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE
975
+ ):
976
+ atom_name_to_token[(t.asym_id, t.residue_index, atom.name)] = (
977
+ atom.token_index
978
+ )
979
+
980
+ # Group atom-tokenized tokens by (asym_id, residue_index)
981
+ residue_tokens: dict[tuple[int, int], list[tuple[str, int]]] = defaultdict(list)
982
+ for atom in atoms:
983
+ if not atom.is_valid:
984
+ continue
985
+ t = tokens[atom.token_index] if atom.token_index < len(tokens) else None
986
+ if t and (
987
+ t.mol_type == MOL_TYPE_NONPOLYMER or t.res_type == PROTEIN_UNK_RES_TYPE
988
+ ):
989
+ residue_tokens[(t.asym_id, t.residue_index)].append(
990
+ (atom.name, atom.token_index)
991
+ )
992
+
993
+ # Add intra-residue bonds from CCD
994
+ for (asym_id_val, res_idx), atom_list in residue_tokens.items():
995
+ if not atom_list:
996
+ continue
997
+ res_name = tokens[atom_list[0][1]].residue_name
998
+ ccd_bonds = get_ligand_ccd_bonds(res_name)
999
+ atom_to_tok = {name: ti for name, ti in atom_list}
1000
+
1001
+ if ccd_bonds:
1002
+ for a1, a2 in ccd_bonds:
1003
+ if a1 in atom_to_tok and a2 in atom_to_tok:
1004
+ add_bond(atom_to_tok[a1], atom_to_tok[a2])
1005
+ else:
1006
+ # Fallback: fully connected within residue
1007
+ tok_indices = [ti for _, ti in atom_list]
1008
+ for i_idx in tok_indices:
1009
+ for j_idx in tok_indices:
1010
+ add_bond(i_idx, j_idx)
1011
+
1012
+ # Add covalent bonds from input
1013
+ if input.covalent_bonds:
1014
+ # Build chain_id -> chain mapping
1015
+ chain_by_id: dict[str, ChainInfo] = {c.chain_id: c for c in chains}
1016
+ # Build (asym_id, residue_index) -> list of atoms for atom index lookup
1017
+ chain_res_atoms: dict[tuple[int, int], list[AtomInfo]] = defaultdict(list)
1018
+ for atom in atoms:
1019
+ if atom.is_valid and atom.token_index < len(tokens):
1020
+ t = tokens[atom.token_index]
1021
+ chain_res_atoms[(t.asym_id, t.residue_index)].append(atom)
1022
+
1023
+ for cb in input.covalent_bonds:
1024
+ c1 = chain_by_id.get(cb.chain_id1)
1025
+ c2 = chain_by_id.get(cb.chain_id2)
1026
+ if c1 is None or c2 is None:
1027
+ continue
1028
+
1029
+ atoms_1 = chain_res_atoms.get((c1.asym_id, cb.res_idx1), [])
1030
+ atoms_2 = chain_res_atoms.get((c2.asym_id, cb.res_idx2), [])
1031
+
1032
+ if cb.atom_idx1 < len(atoms_1) and cb.atom_idx2 < len(atoms_2):
1033
+ add_bond(
1034
+ atoms_1[cb.atom_idx1].token_index, atoms_2[cb.atom_idx2].token_index
1035
+ )
1036
+
1037
+ # Add peptide bonds at modified-residue boundaries: an atom-tokenized
1038
+ # residue's N atom connects to the prev residue's C atom (and same for
1039
+ # the C side to the next residue's N).
1040
+ tokens_by_chain_res: dict[tuple[int, int], list[TokenInfo]] = defaultdict(list)
1041
+ for t in tokens:
1042
+ if t.mol_type == MOL_TYPE_PROTEIN:
1043
+ tokens_by_chain_res[(t.asym_id, t.residue_index)].append(t)
1044
+
1045
+ def _backbone_token(res_tokens: list[TokenInfo], atom_name: str) -> int | None:
1046
+ # Standard residue (single token wrapping all atoms): return that token.
1047
+ if len(res_tokens) == 1 and res_tokens[0].res_type != PROTEIN_UNK_RES_TYPE:
1048
+ return res_tokens[0].token_index
1049
+ for t in res_tokens:
1050
+ for a_idx in range(t.atom_start, t.atom_start + t.atom_count):
1051
+ if a_idx < len(atoms) and atoms[a_idx].name == atom_name:
1052
+ return t.token_index
1053
+ # Atom-tokenized residue without an atom of that name (e.g. ACE has
1054
+ # no N, NH2 has no C). Fall back to the first atom-tokenized token.
1055
+ return res_tokens[0].token_index if res_tokens else None
1056
+
1057
+ for (asym_id_val, res_idx), res_tokens in tokens_by_chain_res.items():
1058
+ is_atom_tokenized = any(t.res_type == PROTEIN_UNK_RES_TYPE for t in res_tokens)
1059
+ if not is_atom_tokenized:
1060
+ continue # Standard residue — no peptide bond added here.
1061
+ n_tok = _backbone_token(res_tokens, "N")
1062
+ c_tok = _backbone_token(res_tokens, "C")
1063
+ prev_tokens = tokens_by_chain_res.get((asym_id_val, res_idx - 1))
1064
+ if prev_tokens and n_tok is not None:
1065
+ prev_c = _backbone_token(prev_tokens, "C")
1066
+ if prev_c is not None:
1067
+ add_bond(prev_c, n_tok)
1068
+ next_tokens = tokens_by_chain_res.get((asym_id_val, res_idx + 1))
1069
+ if next_tokens and c_tok is not None:
1070
+ next_n = _backbone_token(next_tokens, "N")
1071
+ if next_n is not None:
1072
+ add_bond(c_tok, next_n)
1073
+
1074
+ # Expand to dense matrix
1075
+ bonds = torch.zeros(n_tokens, n_tokens, 1, dtype=torch.float32)
1076
+ for i, j in edge_set:
1077
+ bonds[i, j, 0] = 1.0
1078
+ bonds[j, i, 0] = 1.0
1079
+ return bonds
1080
+
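The way `compute_token_bonds` collects undirected edges and then expands them into a symmetric matrix can be shown with nested lists instead of a tensor. `dense_bond_matrix` is a hypothetical helper written for this note; it omits the trailing singleton dimension of the `[L, L, 1]` tensor.

```python
def dense_bond_matrix(n_tokens, bonds):
    """Deduplicate undirected bonds and expand them to a symmetric 0/1 matrix."""
    edges = set()
    for i, j in bonds:
        if i != j:  # self-bonds are ignored, as in add_bond
            edges.add((min(i, j), max(i, j)))
    matrix = [[0.0] * n_tokens for _ in range(n_tokens)]
    for i, j in edges:
        matrix[i][j] = 1.0
        matrix[j][i] = 1.0
    return matrix, edges


# (0, 1) and (1, 0) are the same bond; (2, 2) is a self-bond and is dropped.
matrix, edges = dense_bond_matrix(4, [(0, 1), (1, 0), (2, 2), (3, 1)])
```

Storing each edge as a sorted pair keeps the set free of duplicates regardless of the order in which the two tokens were supplied.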
1081
+
1082
+ def compute_representative_atoms(
1083
+ tokens: list[TokenInfo], atoms: list[AtomInfo]
1084
+ ) -> torch.Tensor:
1085
+ """Compute representative atom index per token (for token_to_rep_atom).
1086
+
1087
+ Returns:
1088
+ distogram_atom_idx: [L] — representative atom per token
1089
+ Protein: CB (or CA for GLY), DNA/RNA: C4/C2/C1', Ligand: first atom.
1090
+ """
1091
+ n_tokens = len(tokens)
1092
+
1093
+ # Build atom name -> index lookup per token
1094
+ token_atoms: dict[int, dict[str, int]] = defaultdict(dict)
1095
+ for atom in atoms:
1096
+ if atom.is_valid:
1097
+ token_atoms[atom.token_index][atom.name] = atom.atom_index
1098
+
1099
+ distogram_atom_idx = torch.zeros(n_tokens, dtype=torch.int64)
1100
+
1101
+ for t in tokens:
1102
+ ad = token_atoms.get(t.token_index, {})
1103
+ fallback_idx = list(ad.values())[0] if ad else 0
1104
+
1105
+ if t.mol_type == MOL_TYPE_PROTEIN:
1106
+ rep_idx = ad.get("CB", ad.get("CA", fallback_idx))
1107
+ elif t.mol_type in (MOL_TYPE_DNA, MOL_TYPE_RNA):
1108
+ if t.res_type in (27, 32): # Unknown nucleotides
1109
+ rep_idx = ad.get("C1'", fallback_idx)
1110
+ elif t.res_type in (23, 24, 28, 29): # Purines (A, G)
1111
+ rep_idx = ad.get("C4", ad.get("C1'", fallback_idx))
1112
+ else: # Pyrimidines (C, U, T)
1113
+ rep_idx = ad.get("C2", ad.get("C1'", fallback_idx))
1114
+ else:
1115
+ rep_idx = fallback_idx
1116
+
1117
+ distogram_atom_idx[t.token_index] = rep_idx
1118
+
1119
+ return distogram_atom_idx
1120
+
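The fallback chain in `compute_representative_atoms` (preferred atom name, then a secondary name, then the token's first atom, then index 0) can be sketched with a generic helper. `representative_atom` and the toy residues are assumptions for illustration; the source hard-codes the preference order per molecule type.

```python
def representative_atom(atom_index_by_name, preferred):
    """Return the index of the first preferred name present, else the first atom, else 0."""
    fallback = next(iter(atom_index_by_name.values()), 0)
    for name in preferred:
        if name in atom_index_by_name:
            return atom_index_by_name[name]
    return fallback


ala = {"N": 0, "CA": 1, "C": 2, "O": 3, "CB": 4}
gly = {"N": 5, "CA": 6, "C": 7, "O": 8}      # glycine has no CB
ligand_atom = {"C12": 9}                     # atom-level token: a single atom

picks = [
    representative_atom(ala, ("CB", "CA")),
    representative_atom(gly, ("CB", "CA")),
    representative_atom(ligand_atom, ()),
    representative_atom({}, ("CB", "CA")),
]
```

The helper relies on dictionaries preserving insertion order, which is also what makes "first atom" well defined in the source.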
1121
+
1122
+ def compute_msa_features(
1123
+ input: StructurePredictionInput,
1124
+ chains: list[ChainInfo],
1125
+ tokens: list[TokenInfo],
1126
+ max_seqs: int = 16384,
1127
+ ) -> dict[str, torch.Tensor]:
1128
+ """Compute MSA features from protein MSAs.
1129
+
1130
+ Uses taxonomy-based pairing across chains
1131
+ (:func:`esmfold2_paired_msa.construct_paired_msa`): rows whose FASTA header
1132
+ contains ``key=N`` get paired across chains sharing the same ``N``.
1133
+
1134
+ Output: msa [M, L], deletion_value [M, L], has_deletion [M, L],
1135
+ deletion_mean [L], msa_attention_mask [M, L]
1136
+ """
1137
+ from .esmfold2_paired_msa import (
1138
+ construct_paired_msa,
1139
+ protein_letter_to_res_type,
1140
+ )
1141
+
1142
+ n_tokens = len(tokens)
1143
+
1144
+ # A single ProteinInput with id=['A','B','C',...] yields one item but
1145
+ # multiple chains (one per id); broadcast the MSA across all of them.
1146
+ chain_msas: dict[int, MSA | None] = {}
1147
+ item_idx = 0
1148
+ for item in input.sequences:
1149
+ ids = [item.id] if isinstance(item.id, str) else list(item.id)
1150
+ for _ in ids:
1151
+ chain = chains[item_idx]
1152
+ if isinstance(item, ProteinInput):
1153
+ msa = item.msa
1154
+ if msa is None:
1155
+ msa = MSA.from_sequences([item.sequence])
1156
+ chain_msas[chain.asym_id] = msa
1157
+ else:
1158
+ chain_msas[chain.asym_id] = None
1159
+ item_idx += 1
1160
+
1161
+ letter_to_res_type = protein_letter_to_res_type()
1162
+
1163
+ # Build per-chain query res_types (used for chains without an MSA).
1164
+ chain_query_res_types: dict[int, np.ndarray] = {}
1165
+ for chain in chains:
1166
+ chain_tokens = [t for t in tokens if t.asym_id == chain.asym_id]
1167
+ chain_query_res_types[chain.asym_id] = np.array(
1168
+ [t.res_type for t in chain_tokens], dtype=np.int64
1169
+ )
1170
+
1171
+ token_asym_ids = np.array([t.asym_id for t in tokens], dtype=np.int64)
1172
+ token_res_ids = np.array([t.residue_index for t in tokens], dtype=np.int64)
1173
+
1174
+ msa_res, del_counts, paired = construct_paired_msa(
1175
+ chain_msas,
1176
+ chain_query_res_types,
1177
+ token_asym_ids,
1178
+ token_res_ids,
1179
+ letter_to_res_type=letter_to_res_type,
1180
+ max_seqs=max_seqs,
1181
+ )
1182
+
1183
+ # Tokens for chains without an MSA get their res_type at row 0 and gap
1184
+ # elsewhere; this mirrors the prior non-protein-token branch.
1185
+ for t in tokens:
1186
+ if chain_msas.get(t.asym_id) is None:
1187
+ msa_res[:, t.token_index] = MSA_GAP_TOKEN_ID
1188
+ msa_res[0, t.token_index] = t.res_type
1189
+
1190
+ if msa_res.shape[0] == 0:
1191
+ msa_res = np.full((1, n_tokens), MSA_GAP_TOKEN_ID, dtype=np.int64)
1192
+ del_counts = np.zeros((1, n_tokens), dtype=np.float32)
1193
+
1194
+ msa_data = torch.from_numpy(msa_res)
1195
+ del_data = torch.from_numpy(del_counts)
1196
+
1197
+ has_deletion = del_data > 0
1198
+ deletion_value = (np.pi / 2) * torch.arctan(del_data / 3)
1199
+ deletion_mean = deletion_value.mean(dim=0)
1200
+
1201
+ msa_mask = torch.ones_like(msa_data, dtype=torch.bool)
1202
+
1203
+ return {
1204
+ "msa": msa_data,
1205
+ "deletion_value": deletion_value,
1206
+ "has_deletion": has_deletion,
1207
+ "deletion_mean": deletion_mean,
1208
+ "msa_attention_mask": msa_mask,
1209
+ }
1210
+
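The deletion features returned by `compute_msa_features` can be reproduced on a tiny example in plain Python. The sketch mirrors the expression used in the source, `(pi / 2) * arctan(d / 3)`; `deletion_features` is a hypothetical name and the deletion counts are made up.

```python
import math


def deletion_features(del_counts):
    """del_counts: list of MSA rows, each a list of per-column deletion counts."""
    value = [[(math.pi / 2) * math.atan(d / 3) for d in row] for row in del_counts]
    has_deletion = [[d > 0 for d in row] for row in del_counts]
    n_rows = len(value)
    # Column-wise mean over MSA rows, matching mean(dim=0) in the source.
    mean = [sum(col) / n_rows for col in zip(*value)]
    return value, has_deletion, mean


value, has_deletion, mean = deletion_features([[0, 3], [0, 0]])
```

With this prefactor the value tends to pi^2 / 4 (about 2.47) for very large counts. AlphaFold-style pipelines commonly scale by 2 / pi instead, which bounds the feature below 1; which convention the released weights expect is not stated in this diff, so the sketch simply follows the code as written.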
1211
+
1212
+ def compute_distogram_conditioning(
1213
+ input: StructurePredictionInput,
1214
+ chains: list[ChainInfo],
1215
+ tokens: list[TokenInfo],
1216
+ disto_center: torch.Tensor,
1217
+ min_dist: float = 2.0,
1218
+ max_dist: float = 22.0,
1219
+ num_bins: int = 64,
1220
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1221
+ """Compute distogram conditioning from user-provided distograms.
1222
+
1223
+ Returns:
1224
+ disto_cond: [L, L] int64 (bin indices)
1225
+ disto_cond_mask: [L, L] bool
1226
+ """
1227
+ n_tokens = len(tokens)
1228
+ disto_cond = torch.zeros(n_tokens, n_tokens, dtype=torch.long)
1229
+ disto_cond_mask = torch.zeros(n_tokens, n_tokens, dtype=torch.bool)
1230
+
1231
+ if not input.distogram_conditioning:
1232
+ return disto_cond, disto_cond_mask
1233
+
1234
+ # Build chain_id -> asym_id mapping
1235
+ chain_id_to_asym: dict[str, int] = {c.chain_id: c.asym_id for c in chains}
1236
+
1237
+ # Build asym_id -> token indices mapping
1238
+ asym_to_tokens: dict[int, list[int]] = defaultdict(list)
1239
+ for t in tokens:
1240
+ asym_to_tokens[t.asym_id].append(t.token_index)
1241
+
1242
+ boundaries = torch.linspace(min_dist, max_dist, num_bins + 1)
1243
+
1244
+ for dc in input.distogram_conditioning:
1245
+ asym_id_val = chain_id_to_asym.get(dc.chain_id)
1246
+ if asym_id_val is None:
1247
+ continue
1248
+ tok_indices = asym_to_tokens[asym_id_val]
1249
+ n_chain = len(tok_indices)
1250
+ distogram = torch.tensor(dc.distogram, dtype=torch.float32)
1251
+
1252
+ if distogram.shape != (n_chain, n_chain):
1253
+ raise ValueError(
1254
+ f"Distogram shape {distogram.shape} doesn't match chain length {n_chain}"
1255
+ )
1256
+
1257
+ # Bin the distogram
1258
+ binned = torch.bucketize(distogram, boundaries[:-1]) - 1
1259
+ binned = binned.clamp(0, num_bins - 1)
1260
+
1261
+ for i, ti in enumerate(tok_indices):
1262
+ for j, tj in enumerate(tok_indices):
1263
+ disto_cond[ti, tj] = binned[i, j]
1264
+ disto_cond_mask[ti, tj] = True
1265
+
1266
+ return disto_cond, disto_cond_mask
1267
+
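The binning step in `compute_distogram_conditioning` can be sketched with the standard library. Here `bisect_left` stands in for `torch.bucketize` with its default `right=False`, and the list of lower bin edges stands in for `boundaries[:-1]`; `bin_distance` itself is a hypothetical helper.

```python
from bisect import bisect_left


def bin_distance(d, min_dist=2.0, max_dist=22.0, num_bins=64):
    """Map a distance to a bin index in [0, num_bins - 1]."""
    step = (max_dist - min_dist) / num_bins
    # Lower edge of every bin: linspace(min, max, num_bins + 1) without the last point.
    lower_edges = [min_dist + i * step for i in range(num_bins)]
    idx = bisect_left(lower_edges, d) - 1
    # Distances below the first edge or beyond the last one are clamped.
    return max(0, min(num_bins - 1, idx))


bins = [bin_distance(d) for d in (0.0, 2.0, 2.1, 2.4, 21.9, 30.0)]
```

With the defaults each bin is 0.3125 wide, so everything up to and including 2.3125 lands in bin 0 and everything above 21.6875 lands in bin 63.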
1268
+
1269
+ def build_feature_tensors(
1270
+ chains: list[ChainInfo],
1271
+ tokens: list[TokenInfo],
1272
+ atoms: list[AtomInfo],
1273
+ input: StructurePredictionInput,
1274
+ ) -> dict[str, torch.Tensor]:
1275
+ """Build all model input tensors from tokens and atoms."""
1276
+ n_tokens = len(tokens)
1277
+ n_real_atoms = len(atoms)
1278
+
1279
+ # Pad atoms to nearest multiple of 32
1280
+ target_atoms = math.ceil(n_real_atoms / 32) * 32 if n_real_atoms > 0 else 32
1281
+ n_padding = target_atoms - n_real_atoms
1282
+ padding_atoms = [
1283
+ AtomInfo(
1284
+ name="",
1285
+ element="",
1286
+ charge=0,
1287
+ ref_pos=_ZERO_POS.copy(),
1288
+ pos=_ZERO_POS.copy(),
1289
+ token_index=0,
1290
+ atom_index=n_real_atoms + i,
1291
+ space_uid=0,
1292
+ is_valid=False,
1293
+ )
1294
+ for i in range(n_padding)
1295
+ ]
1296
+ all_atoms = atoms + padding_atoms
1297
+ n_atoms = len(all_atoms)
1298
+
1299
+ # --- Token-level tensors ---
1300
+ token_index_arr = np.empty(n_tokens, dtype=np.int64)
1301
+ residue_index_arr = np.empty(n_tokens, dtype=np.int64)
1302
+ asym_id_arr = np.empty(n_tokens, dtype=np.int64)
1303
+ sym_id_arr = np.empty(n_tokens, dtype=np.int64)
1304
+ entity_id_arr = np.empty(n_tokens, dtype=np.int64)
1305
+ mol_type_arr = np.empty(n_tokens, dtype=np.int64)
1306
+ res_type_arr = np.empty(n_tokens, dtype=np.int64)
1307
+ input_ids_arr = np.empty(n_tokens, dtype=np.int64)
1308
+
1309
+ for i, t in enumerate(tokens):
1310
+ token_index_arr[i] = t.token_index
1311
+ residue_index_arr[i] = t.residue_index
1312
+ asym_id_arr[i] = t.asym_id
1313
+ sym_id_arr[i] = t.sym_id
1314
+ entity_id_arr[i] = t.entity_id
1315
+ mol_type_arr[i] = t.mol_type
1316
+ res_type_arr[i] = t.res_type
1317
+ input_ids_arr[i] = t.input_id
1318
+
1319
+ token_index = torch.from_numpy(token_index_arr)
1320
+ residue_index = torch.from_numpy(residue_index_arr)
1321
+ asym_id = torch.from_numpy(asym_id_arr)
1322
+ sym_id = torch.from_numpy(sym_id_arr)
1323
+ entity_id = torch.from_numpy(entity_id_arr)
1324
+ mol_type = torch.from_numpy(mol_type_arr)
1325
+ res_type = torch.from_numpy(res_type_arr)
1326
+ input_ids = torch.from_numpy(input_ids_arr)
1327
+ token_pad_mask = torch.ones(n_tokens, dtype=torch.bool)
1328
+
1329
+ # --- Atom-level tensors ---
1330
+ ref_pos_arr = np.zeros((n_atoms, 3), dtype=np.float32)
1331
+ ref_element_arr = np.zeros(n_atoms, dtype=np.int64)
1332
+ ref_charge_arr = np.zeros(n_atoms, dtype=np.int8)
1333
+ ref_atom_name_chars_arr = np.zeros((n_atoms, 4), dtype=np.int64)
1334
+ ref_space_uid_arr = np.zeros(n_atoms, dtype=np.int64)
1335
+ atom_pad_mask_arr = np.zeros(n_atoms, dtype=np.bool_)
1336
+ atom_to_token_arr = np.zeros(n_atoms, dtype=np.int64)
1337
+ all_positions = np.zeros((n_atoms, 3), dtype=np.float64)
1338
+ is_valid_arr = np.zeros(n_atoms, dtype=np.bool_)
1339
+
1340
+ for i, atom in enumerate(all_atoms):
1341
+ if atom.ref_pos is not None:
1342
+ ref_pos_arr[i] = atom.ref_pos
1343
+ ref_charge_arr[i] = atom.charge
1344
+ ref_space_uid_arr[i] = (
1345
+ atom.space_uid if atom.space_uid >= 0 else atom.token_index
1346
+ )
1347
+ atom_pad_mask_arr[i] = atom.is_valid
1348
+ is_valid_arr[i] = atom.is_valid
1349
+ all_positions[i] = atom.pos
1350
+
1351
+ if atom.is_valid:
1352
+ ref_element_arr[i] = get_element_atomic_num(atom.element)
1353
+ name_indices = encode_atom_name(atom.name)
1354
+ ref_atom_name_chars_arr[i] = name_indices
1355
+ atom_to_token_arr[i] = atom.token_index
1356
+
1357
+ ref_pos = torch.from_numpy(ref_pos_arr)
1358
+ ref_element = torch.from_numpy(ref_element_arr)
1359
+ ref_charge = torch.from_numpy(ref_charge_arr)
1360
+ ref_atom_name_chars = torch.from_numpy(ref_atom_name_chars_arr)
1361
+ ref_space_uid = torch.from_numpy(ref_space_uid_arr)
1362
+ atom_pad_mask = torch.from_numpy(atom_pad_mask_arr)
1363
+ atom_to_token = torch.from_numpy(atom_to_token_arr)
1364
+
1365
+ # Coordinates — center on resolved atoms
1366
+ raw_coords = torch.from_numpy(all_positions)
1367
+ is_nonzero = np.any(all_positions != 0, axis=1)
1368
+ atom_resolved_arr = is_valid_arr & is_nonzero
1369
+ resolved_mask = torch.from_numpy(atom_resolved_arr)
1370
+ valid_mask = torch.from_numpy(is_valid_arr)
1371
+
1372
+ if resolved_mask.any():
1373
+ centroid = raw_coords[resolved_mask].mean(dim=0, keepdim=True)
1374
+ raw_coords = raw_coords - centroid
1375
+ raw_coords[~valid_mask] = 0.0
1376
+
1377
+ coords = raw_coords.float().unsqueeze(0) # [1, A, 3]
1378
+ atom_resolved_mask = torch.tensor(atom_resolved_arr, dtype=torch.bool)
1379
+
1380
+ # --- Frames ---
1381
+ frames, _ = compute_frame_indices(tokens, atoms)
1382
+ frames_idx = torch.from_numpy(frames).to(torch.int64)
1383
+
1384
+ # --- Token bonds ---
1385
+ token_bonds = compute_token_bonds(tokens, atoms, input, chains)
1386
+
1387
+ # --- Representative atoms ---
1388
+ distogram_atom_idx = compute_representative_atoms(tokens, atoms)
1389
+
1390
+ # --- MSA features ---
1391
+ msa_features = compute_msa_features(input, chains, tokens)
1392
+
1393
+ # --- Distogram conditioning ---
1394
+ # disto_center is not needed for inference (no experimental coords)
1395
+ disto_center = torch.zeros(n_tokens, 3, dtype=torch.float32)
1396
+ disto_cond, disto_cond_mask = compute_distogram_conditioning(
1397
+ input, chains, tokens, disto_center
1398
+ )
1399
+
1400
+ # ref_pos: CCD conformer positions, used as-is for inference.
1401
+ # No random rotation or masking — at inference there are no resolved
1402
+ # experimental coordinates, so atom_resolved_mask is all False.
1403
+ # The model uses ref_pos for atom feature embedding.
1404
+
1405
+ # --- Pocket (dropped) ---
1406
+ pocket_feature = torch.zeros(n_tokens, dtype=torch.long)
1407
+
1408
+ return {
1409
+ # Token-level
1410
+ "token_index": token_index,
1411
+ "residue_index": residue_index,
1412
+ "asym_id": asym_id,
1413
+ "entity_id": entity_id,
1414
+ "sym_id": sym_id,
1415
+ "mol_type": mol_type,
1416
+ "res_type": res_type,
1417
+ "input_ids": input_ids,
1418
+ "token_bonds": token_bonds,
1419
+ "token_attention_mask": token_pad_mask,
1420
+ "pocket_feature": pocket_feature,
1421
+ # Atom-level
1422
+ "ref_pos": ref_pos,
1423
+ "ref_element": ref_element,
1424
+ "ref_charge": ref_charge,
1425
+ "ref_atom_name_chars": ref_atom_name_chars,
1426
+ "ref_space_uid": ref_space_uid,
1427
+ "gt_coords": coords,
1428
+ "atom_attention_mask": atom_pad_mask,
1429
+ "atom_to_token": atom_to_token,
1430
+ "is_resolved": atom_resolved_mask,
1431
+ "distogram_atom_idx": distogram_atom_idx,
1432
+ # Frames
1433
+ "frames_idx": frames_idx,
1434
+ # Distogram
1435
+ "disto_cond": disto_cond,
1436
+ "disto_cond_mask": disto_cond_mask,
1437
+ # MSA
1438
+ **msa_features,
1439
+ }
1440
+
1441
+
1442
+ # =============================================================================
1443
+ # Top-level entry point
1444
+ # =============================================================================
1445
+
1446
+
1447
+ def prepare_esmfold2_input(
1448
+ input: StructurePredictionInput, seed: int | None = None
1449
+ ) -> tuple[dict[str, torch.Tensor], list[ChainInfo]]:
1450
+ """Prepare ESMFold2 model inputs from StructurePredictionInput.
1451
+
1452
+ Args:
1453
+ input: The structure prediction input (sequences, conditioning, etc.)
1454
+ seed: Random seed for SMILES conformer generation and augmentation.
1455
+
1456
+ Returns:
1457
+ Tuple of (feature_dict, chain_infos) where feature_dict contains
1458
+ all tensors for the model forward pass, and chain_infos contains
1459
+ metadata for output processing.
1460
+ """
1461
+ chains, tokens, atoms = build_chains_from_input(input, seed)
1462
+ features = build_feature_tensors(chains, tokens, atoms, input)
1463
+ return features, chains
1464
+
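The `seed` argument documented above is meant to make input preparation reproducible. As a minimal sketch of one way to do that without disturbing the caller's global RNG state, the block below seeds Python's `random` module temporarily and restores the previous state on exit. The names `temporary_seed` and `sample_under_seed` are hypothetical and exist only for this illustration; the processor file that follows in this commit applies the same save/seed/restore idea to NumPy and PyTorch as well.

```python
import random
from contextlib import contextmanager


@contextmanager
def temporary_seed(seed=None):
    """Seed Python's RNG inside the block, then restore the previous state."""
    if seed is None:
        # No seed requested: leave the global RNG untouched.
        yield
        return
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Restore the caller's RNG state even if the body raised.
        random.setstate(saved_state)


def sample_under_seed(seed, n=3):
    """Draw n floats under a temporary seed (hypothetical helper)."""
    with temporary_seed(seed):
        return [random.random() for _ in range(n)]
```

Calling `sample_under_seed(7)` twice should give the same list, while draws made outside the context continue from wherever the caller's RNG stream was.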
esmfold2_processor.py ADDED
@@ -0,0 +1,356 @@
1
+ import random
2
+ from contextlib import contextmanager, nullcontext
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from .esmfold2_conformers import load_ccd
10
+ from .esmfold2_output import build_molecular_complex_from_features
11
+ from .esmfold2_prepare_input import ChainInfo, prepare_esmfold2_input
12
+ from .esmfold2_types import (
13
+ MSA,
14
+ Modification,
15
+ ProteinInput,
16
+ StructurePredictionInput,
17
+ )
18
+ from .esmfold2_molecular_complex import MolecularComplexResult
19
+
20
+
21
+ @contextmanager
22
+ def _seed_context(seed: int | None):
23
+ if seed is None:
24
+ yield
25
+ return
26
+ py_state = random.getstate()
27
+ np_state = np.random.get_state()
28
+ torch_state = torch.random.get_rng_state()
29
+ cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
30
+ random.seed(seed)
31
+ np.random.seed(seed)
32
+ torch.manual_seed(seed)
33
+ if torch.cuda.is_available():
34
+ torch.cuda.manual_seed_all(seed)
35
+ try:
36
+ yield
37
+ finally:
38
+ random.setstate(py_state)
39
+ np.random.set_state(np_state)
40
+ torch.random.set_rng_state(torch_state)
41
+ if cuda_state is not None:
42
+ torch.cuda.set_rng_state_all(cuda_state)
43
+
44
+
45
+ def clean_esmfold2_input(input: StructurePredictionInput) -> StructurePredictionInput:
46
+ """Group identical protein sequences into the same ProteinInput with multiple ids.
47
+
48
+ Example: Passing a tetramer like [ProteinInput(id=["0"], seq="AAA|AAA|BBB|BBB")]
49
+ gets converted into [ProteinInput(id=["0_0", "0_1"], seq="AAA"),
50
+ ProteinInput(id=["0_2", "0_3"], seq="BBB")]
51
+
52
+ Preserves the original order of unique sequences. Also converts "|" chainbreak
53
+ tokens to ":" in the sequence.
54
+ """
55
+ cleaned_sequences: list = []
56
+ chain_to_ids: dict[str, list[str]] = {}
57
+ chain_to_modifications: dict[str, list] = {}
58
+ chain_to_msa: dict[str, MSA | None] = {}
59
+
60
+ for item in input.sequences:
61
+ if isinstance(item, ProteinInput):
62
+ sequence = ":".join(item.sequence.split("|"))
63
+ if ":" not in sequence:
64
+ cleaned_sequences.append(item)
65
+ continue
66
+
67
+ if ":" in sequence and input.covalent_bonds is not None:
68
+ raise ValueError(
69
+ "Covalent bonds are not supported when using chainbreaks. "
70
+ "Chains must be separated into multiple ProteinInput objects."
71
+ )
72
+
73
+ base_id = item.id[0] if isinstance(item.id, list) else item.id
74
+ chain_to_ids = {}
75
+ chain_to_modifications = {}
76
+ chain_to_msa = {}
77
+ chains = sequence.split(":")
78
+
79
+ chain_start_positions = []
80
+ pos = 0
81
+ for chain in chains:
82
+ chain_start_positions.append(pos)
83
+ pos += len(chain) + 1
84
+
85
+ if item.modifications is not None:
86
+ for chain_idx, chain in enumerate(chains):
87
+ chain_start = chain_start_positions[chain_idx]
88
+ chain_end = chain_start + len(chain)
89
+ chain_modifications = []
90
+ for mod in item.modifications:
91
+ if chain_start <= mod.position < chain_end:
92
+ adjusted_mod = Modification(
93
+ position=mod.position - chain_start, ccd=mod.ccd
94
+ )
95
+ chain_modifications.append(adjusted_mod)
96
+ if chain not in chain_to_modifications:
97
+ chain_to_modifications[chain] = chain_modifications
98
+ else:
99
+ chain_to_modifications[chain].extend(chain_modifications)
100
+
101
+ if item.msa is not None:
102
+ for chain_idx, chain in enumerate(chains):
103
+ if chain not in chain_to_msa:
104
+ chain_start = chain_start_positions[chain_idx]
105
+ chain_end = chain_start + len(chain)
106
+ chain_msa = item.msa.select_positions( # type: ignore
107
+ np.arange(chain_start, chain_end)
108
+ )
109
+ chain_to_msa[chain] = chain_msa
110
+
111
+ for i, chain in enumerate(chains):
112
+ chain_id = base_id + "_" + str(i)
113
+ if chain in chain_to_ids:
114
+ chain_to_ids[chain].append(chain_id)
115
+ else:
116
+ chain_to_ids[chain] = [chain_id]
117
+ cleaned_sequences.append((item, chain))
118
+ else:
119
+ cleaned_sequences.append(item)
120
+
121
+ for i in range(len(cleaned_sequences)):
122
+ if isinstance(cleaned_sequences[i], tuple):
123
+ item, chain = cleaned_sequences[i]
124
+ chain_ids = chain_to_ids[chain]
125
+ chain_modifications = (
126
+ chain_to_modifications.get(chain) if item.modifications else None
127
+ )
128
+ chain_msa = chain_to_msa.get(chain) if item.msa else None
129
+ cleaned_sequences[i] = ProteinInput(
130
+ id=chain_ids,
131
+ sequence=chain,
132
+ msa=chain_msa,
133
+ modifications=chain_modifications,
134
+ )
135
+
136
+ return StructurePredictionInput(
137
+ sequences=cleaned_sequences,
138
+ distogram_conditioning=input.distogram_conditioning,
139
+ covalent_bonds=input.covalent_bonds,
140
+ )
141
+
142
+
143
+ class ESMFold2InputBuilder:
144
+ def __init__(self, ccd_cache: Path | None = None):
145
+ load_ccd(ccd_cache)
146
+
147
+ def prepare_input(
148
+ self,
149
+ input: StructurePredictionInput,
150
+ seed: int | None = None,
151
+ device: torch.device | str | None = None,
152
+ ) -> tuple[dict, list[ChainInfo]]:
153
+ """Prepare raw input for the folding model.
154
+
155
+ Converts user-provided StructurePredictionInput into batched tensors
156
+ ready for model inference.
157
+
158
+ Parameters
159
+ ----------
160
+ input : StructurePredictionInput
161
+ Input specification (sequences, structures, constraints, etc.).
162
+ seed : int, optional
163
+ Random seed for reproducibility.
164
+ device : torch.device or str, optional
165
+ Target device for the returned tensors. Defaults to CPU; pass
166
+ ``model.device`` to skip a separate ``.to(...)`` step. ``fold()``
167
+ forwards ``model.device`` automatically.
168
+
169
+ Returns
170
+ -------
171
+ tuple[dict, list[ChainInfo]]
172
+ Batched input tensors and chain metadata for output processing.
173
+ """
174
+ structure_prediction_input = clean_esmfold2_input(input)
175
+ with _seed_context(seed) if seed is not None else nullcontext():
176
+ features, chain_infos = prepare_esmfold2_input(
177
+ structure_prediction_input, seed=seed
178
+ )
179
+ features = {
180
+ k: (v[None].to(device) if device is not None else v[None])
181
+ if isinstance(v, torch.Tensor)
182
+ else v
183
+ for k, v in features.items()
184
+ }
185
+
186
+ return features, chain_infos
187
+
188
+ def __call__(
189
+ self,
190
+ input: StructurePredictionInput,
191
+ seed: int | None = None,
192
+ device: torch.device | str | None = None,
193
+ ) -> tuple[dict, list[ChainInfo]]:
194
+ return self.prepare_input(input, seed=seed, device=device)
195
+
196
+ def decode(
197
+ self,
198
+ output: dict[str, torch.Tensor],
199
+ features: dict[str, torch.Tensor],
200
+ chain_infos: list[ChainInfo],
201
+ *,
202
+ num_diffusion_samples: int = 1,
203
+ complex_id: str = "pred",
204
+ ) -> MolecularComplexResult | list[MolecularComplexResult]:
205
+ """Convert raw model outputs into one MolecularComplexResult per sample.
206
+
207
+ Parameters
208
+ ----------
209
+ output : dict[str, Tensor]
210
+ Output dict returned by ESMFold2Model.forward.
211
+ features : dict[str, Tensor]
212
+ Feature dict from :meth:`prepare_input` (batched, on the model device).
213
+ chain_infos : list[ChainInfo]
214
+ Chain metadata returned alongside `features`.
215
+ num_diffusion_samples : int
216
+ Number of diffusion samples present in the output (Bm = B * num_diffusion_samples).
217
+ complex_id : str
218
+ Identifier assigned to each MolecularComplex.
219
+
220
+ Returns
221
+ -------
222
+ MolecularComplexResult or list[MolecularComplexResult]
223
+ A single result when num_diffusion_samples == 1 and the output holds exactly one sample, otherwise a list of length Bm.
224
+ """
225
+ atom_mask = features["atom_attention_mask"][0]
226
+ ref_element = features["ref_element"][0]
227
+ ref_atom_name_chars = features["ref_atom_name_chars"][0]
228
+
229
+ sample_coords = output["sample_atom_coords"]
230
+ plddts = output["plddt"]
231
+ Bm = sample_coords.shape[0]
232
+
233
+ ptm_t = output.get("ptm")
234
+ iptm_t = output.get("iptm")
235
+ pae_t = output.get("pae")
236
+ distogram_t = output.get("distogram_logits")
237
+ pair_chains_t = output.get("pair_chains_iptm")
238
+ residue_index_t = output.get("residue_index")
239
+ entity_id_t = output.get("entity_id")
240
+
241
+ results: list[MolecularComplexResult] = []
242
+ for i in range(Bm):
243
+ mc = build_molecular_complex_from_features(
244
+ coords=sample_coords[i],
245
+ plddt=plddts[i],
246
+ atom_mask=atom_mask,
247
+ ref_element=ref_element,
248
+ ref_atom_name_chars=ref_atom_name_chars,
249
+ chain_infos=chain_infos,
250
+ complex_id=complex_id,
251
+ )
252
+ results.append(
253
+ MolecularComplexResult(
254
+ complex=mc,
255
+ plddt=plddts[i].detach().cpu(),
256
+ ptm=float(ptm_t[i].item()) if ptm_t is not None else None,
257
+ iptm=float(iptm_t[i].item()) if iptm_t is not None else None,
258
+ pae=pae_t[i].detach().cpu() if pae_t is not None else None,
259
+ distogram=(
260
+ distogram_t[0].detach().cpu()
261
+ if distogram_t is not None
262
+ else None
263
+ ),
264
+ pair_chains_iptm=(
265
+ pair_chains_t[i].detach().cpu()
266
+ if pair_chains_t is not None
267
+ else None
268
+ ),
269
+ residue_index=(
270
+ residue_index_t[0].detach().cpu()
271
+ if residue_index_t is not None
272
+ else None
273
+ ),
274
+ entity_id=(
275
+ entity_id_t[0].detach().cpu()
276
+ if entity_id_t is not None
277
+ else None
278
+ ),
279
+ )
280
+ )
281
+
282
+ if num_diffusion_samples == 1 and len(results) == 1:
283
+ return results[0]
284
+ return results
285
+
286
+ def fold(
287
+ self,
288
+ model: Any,
289
+ input: StructurePredictionInput,
290
+ *,
291
+ num_loops: int = 3,
292
+ num_sampling_steps: int = 200,
293
+ num_diffusion_samples: int = 1,
294
+ seed: int | None = None,
295
+ noise_scale: float | None = None,
296
+ step_scale: float | None = None,
297
+ max_inference_sigma: int | None = None,
298
+ early_exit: bool = False,
299
+ complex_id: str = "pred",
300
+ ) -> MolecularComplexResult | list[MolecularComplexResult]:
301
+ """Fold a structure end-to-end: encode → model → decode.
302
+
303
+ Parameters
304
+ ----------
305
+ model : ESMFold2Model
306
+ The folding model. Must already be on the target device and in eval mode.
307
+ input : StructurePredictionInput
308
+ User-facing input specification.
309
+ num_loops, num_sampling_steps, num_diffusion_samples : int
310
+ Inference knobs forwarded to the model.
311
+ seed : int, optional
312
+ Seeds both input prep (SMILES conformer generation) and diffusion sampling.
313
+ noise_scale, step_scale, max_inference_sigma, early_exit
314
+ Optional sampler overrides forwarded to the model when not None.
315
+ complex_id : str
316
+ Identifier assigned to the predicted MolecularComplex(es).
317
+
318
+ Returns
319
+ -------
320
+ MolecularComplexResult or list[MolecularComplexResult]
321
+ A single result when num_diffusion_samples == 1, otherwise a list.
322
+ """
323
+ features, chain_infos = self.prepare_input(
324
+ input, seed=seed, device=model.device
325
+ )
326
+
327
+ sampler_kwargs: dict[str, Any] = {}
328
+ if noise_scale is not None:
329
+ sampler_kwargs["noise_scale"] = noise_scale
330
+ if step_scale is not None:
331
+ sampler_kwargs["step_scale"] = step_scale
332
+ if max_inference_sigma is not None:
333
+ sampler_kwargs["max_inference_sigma"] = max_inference_sigma
334
+
335
+ with torch.no_grad():
336
+ with _seed_context(seed) if seed is not None else nullcontext():
337
+ output = model(
338
+ **features,
339
+ num_loops=num_loops,
340
+ num_sampling_steps=num_sampling_steps,
341
+ num_diffusion_samples=num_diffusion_samples,
342
+ early_exit=early_exit,
343
+ **sampler_kwargs,
344
+ )
345
+
346
+ return self.decode(
347
+ output,
348
+ features,
349
+ chain_infos,
350
+ num_diffusion_samples=num_diffusion_samples,
351
+ complex_id=complex_id,
352
+ )
353
+
354
+
355
+ __all__ = ["ESMFold2InputBuilder", "clean_esmfold2_input"]
356
+
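The chainbreak handling in `clean_esmfold2_input` above can be sketched in isolation. The block below is a simplified illustration, not the code from this commit: it works on plain strings instead of `ProteinInput` objects, ignores MSAs and modifications, and the helper names `group_chains` and `chain_start_positions` are hypothetical. It shows the two ideas the function relies on: grouping identical chains under a shared list of ids in first-seen order, and tracking each chain's start offset in the original string so that per-position annotations can be re-indexed.

```python
def group_chains(sequence: str, base_id: str) -> list[tuple[list[str], str]]:
    """Split on chainbreaks and group identical chains, keeping first-seen order."""
    chains = sequence.replace("|", ":").split(":")
    ids_by_chain: dict[str, list[str]] = {}
    for i, chain in enumerate(chains):
        # Every copy of a chain gets its own id, e.g. "0_0", "0_1".
        ids_by_chain.setdefault(chain, []).append(f"{base_id}_{i}")
    # Dicts preserve insertion order, so unique chains stay in input order.
    return [(ids, chain) for chain, ids in ids_by_chain.items()]


def chain_start_positions(sequence: str) -> list[int]:
    """Offset of each chain in the original string, one separator per break."""
    starts: list[int] = []
    pos = 0
    for chain in sequence.replace("|", ":").split(":"):
        starts.append(pos)
        pos += len(chain) + 1  # +1 skips the chainbreak character
    return starts
```

With these offsets, a modification at global position `p` inside a chain starting at `s` maps to local position `p - s`, which is the same adjustment the function above applies when it rebuilds each `Modification`.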
esmfold2_protein_chain.py ADDED
@@ -0,0 +1,1376 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import warnings
5
+ from dataclasses import asdict, dataclass, replace
6
+ from functools import cached_property
7
+ from pathlib import Path
8
+ from typing import Any, Mapping, Sequence
9
+
10
+ import biotite.structure as bs
11
+ import brotli
12
+ import msgpack
13
+ import msgpack_numpy
14
+ import numpy as np
15
+ import torch
16
+ from biotite.database import rcsb
17
+ from biotite.structure.io.pdb import PDBFile
18
+ from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
19
+ from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
20
+ from scipy.spatial import ConvexHull, KDTree
21
+ from scipy.spatial.distance import cdist, pdist, squareform
22
+
23
+ from . import esmfold2_residue_constants as residue_constants
24
+ from .esmfold2_misc import slice_python_object_as_numpy
25
+ from .esmfold2_affine3d import Affine3D
26
+ from .esmfold2_aligner import Aligner
27
+ from .esmfold2_atom_indexer import AtomIndexer
28
+ from .esmfold2_metrics import compute_gdt_ts, compute_lddt_ca
29
+ from .esmfold2_mmcif_parsing import MmcifWrapper, Residue
30
+ from .esmfold2_normalize_coordinates import (
31
+ apply_frame_to_coords,
32
+ get_protein_normalization_frame,
33
+ )
34
+ from .esmfold2_protein_structure import index_by_atom_name
35
+ from .esmfold2_utils_types import PathOrBuffer
36
+
37
+ msgpack_numpy.patch()
38
+ CHAIN_ID_CONST = "A"
39
+
40
+
41
+ def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict:
42
+ new_dict = {}
43
+ for k, v in dct.items():
44
+ v_new = v
45
+ if k not in (ignore_keys or []) and isinstance(v, dict):
46
+ v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys)
47
+ # Note assembly_composition is *supposed* to have string keys.
48
+ if isinstance(k, str) and k.isdigit():
49
+ new_dict[int(k)] = v_new
50
+ else:
51
+ new_dict[k] = v_new
52
+ return new_dict
53
+
54
+
55
+ def _num_non_null_residues(seqres_to_structure_chain: Mapping[int, Residue]) -> int:
56
+ return sum(
57
+ residue.residue_number is not None
58
+ for residue in seqres_to_structure_chain.values()
59
+ )
60
+
61
+
62
+ def infer_CB(C, N, Ca, L: float = 1.522, A: float = 1.927, D: float = -2.143):
63
+ """
64
+ Inspired by a util in trDesign:
65
+ https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92
66
+
67
+ input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
68
+ output: 4th coord
69
+ """
70
+ norm = lambda x: x / np.sqrt(np.square(x).sum(-1, keepdims=True) + 1e-8)
71
+ with np.errstate(invalid="ignore"): # inf - inf = nan is ok here
72
+ vec_bc = N - Ca
73
+ vec_ba = N - C
74
+ bc = norm(vec_bc)
75
+ n = norm(np.cross(vec_ba, bc))
76
+ m = [bc, np.cross(n, bc), n]
77
+ d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
78
+ return Ca + sum([m * d for m, d in zip(m, d)])
79
+
80
+
81
+ def chain_to_ndarray(
82
+ atom_array: bs.AtomArray, mmcif: MmcifWrapper, chain_id: str, is_predicted=False
83
+ ):
84
+ entity_id = None
85
+ for entity, chains in mmcif.entities.items():
86
+ if chain_id in chains:
87
+ entity_id = entity
88
+ num_res = len(mmcif.chain_to_seqres[chain_id])
89
+ sequence = mmcif.chain_to_seqres[chain_id]
90
+
91
+ atom_positions = np.full([num_res, residue_constants.atom_type_num, 3], np.nan)
92
+ atom_mask = np.full([num_res, residue_constants.atom_type_num], False, dtype=bool)
93
+ residue_index = np.full([num_res], -1, dtype=np.int64)
94
+ insertion_code = np.full([num_res], "", dtype="<U4")
95
+
96
+ confidence = np.ones([num_res], dtype=np.float32)
97
+
98
+ for res_index in range(num_res):
99
+ chain = atom_array[atom_array.chain_id == chain_id]
100
+ assert isinstance(chain, bs.AtomArray)
101
+ res_at_position = mmcif.seqres_to_structure[chain_id][res_index]
102
+
103
+ if res_at_position.residue_number is None:
104
+ continue
105
+
106
+ residue_index[res_index] = res_at_position.residue_number
107
+ insertion_code[res_index] = res_at_position.insertion_code
108
+ res = chain[
109
+ (chain.res_id == res_at_position.residue_number)
110
+ & (chain.ins_code == res_at_position.insertion_code)
111
+ & (chain.hetero == res_at_position.hetflag)
112
+ ]
113
+ assert isinstance(res, bs.AtomArray)
114
+
115
+ # Atom level features
116
+ for atom in res:
117
+ atom_name = atom.atom_name
118
+ if atom_name == "SE" and atom.res_name == "MSE":
119
+ # Put the coords of the selenium atom in the sulphur column
120
+ atom_name = "SD"
121
+
122
+ if atom_name in residue_constants.atom_order:
123
+ atom_positions[res_index, residue_constants.atom_order[atom_name]] = (
124
+ atom.coord
125
+ )
126
+ atom_mask[res_index, residue_constants.atom_order[atom_name]] = True
127
+ if is_predicted and atom_name == "CA":
128
+ confidence[res_index] = atom.b_factor
129
+
130
+ assert all(sequence), "Some residue name was not specified correctly"
131
+ return (
132
+ sequence,
133
+ atom_positions,
134
+ atom_mask,
135
+ residue_index,
136
+ insertion_code,
137
+ confidence,
138
+ entity_id,
139
+ )
140
+
141
+
142
+ @dataclass(frozen=True)
143
+ class ProteinChain:
144
+ """Dataclass with atom37 representation of a single protein chain."""
145
+
146
+ id: str
147
+ sequence: str
148
+ chain_id: str # author chain id - mutable
149
+ entity_id: int | None
150
+ residue_index: np.ndarray
151
+ insertion_code: np.ndarray
152
+ atom37_positions: np.ndarray
153
+ atom37_mask: np.ndarray
154
+ confidence: np.ndarray
155
+ mmcif: MmcifWrapper | None = None
156
+ atom37_confidence: np.ndarray | None = None # [L, 37] per-atom pLDDT
157
+
158
+ def __post_init__(self):
159
+ assert self.atom37_mask.dtype == bool, self.atom37_mask.dtype
160
+ assert self.atom37_positions.shape[0] == len(self.sequence), (
161
+ self.atom37_positions.shape,
162
+ len(self.sequence),
163
+ )
164
+ assert self.atom37_mask.shape[0] == len(self.sequence), (
165
+ self.atom37_mask.shape,
166
+ len(self.sequence),
167
+ )
168
+ assert self.residue_index.shape[0] == len(self.sequence), (
169
+ self.residue_index.shape,
170
+ len(self.sequence),
171
+ )
172
+ assert self.insertion_code.shape[0] == len(self.sequence), (
173
+ self.insertion_code.shape,
174
+ len(self.sequence),
175
+ )
176
+ assert self.confidence.shape[0] == len(self.sequence), (
177
+ self.confidence.shape,
178
+ len(self.sequence),
179
+ )
180
+ if self.atom37_confidence is not None:
181
+ assert self.atom37_confidence.shape == self.atom37_mask.shape, (
182
+ self.atom37_confidence.shape,
183
+ self.atom37_mask.shape,
184
+ )
185
+
186
+ @cached_property
187
+ def atoms(self) -> AtomIndexer:
188
+ return AtomIndexer(self, property="atom37_positions", dim=-2)
189
+
190
+ @cached_property
191
+ def atom_mask(self) -> AtomIndexer:
192
+ return AtomIndexer(self, property="atom37_mask", dim=-1)
193
+
194
+ @cached_property
195
+ def atom_array(self) -> bs.AtomArray:
196
+ atoms = []
197
+ for res_idx_i, (
198
+ res_name,
199
+ res_idx,
200
+ ins_code,
201
+ positions,
202
+ mask,
203
+ conf,
204
+ ) in enumerate(
205
+ zip(
206
+ self.sequence,
207
+ self.residue_index,
208
+ self.insertion_code,
209
+ self.atom37_positions,
210
+ self.atom37_mask.astype(bool),
211
+ self.confidence,
212
+ )
213
+ ):
214
+ for i, pos in zip(np.where(mask)[0], positions[mask]):
215
+ b_factor = (
216
+ self.atom37_confidence[res_idx_i, i]
217
+ if self.atom37_confidence is not None
218
+ else conf
219
+ )
220
+ atom = bs.Atom(
221
+ coord=pos,
222
+ chain_id="A" if self.chain_id is None else self.chain_id,
223
+ res_id=res_idx,
224
+ ins_code=ins_code,
225
+ res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
226
+ hetero=False,
227
+ atom_name=residue_constants.atom_types[i],
228
+ element=residue_constants.atom_types[i][0],
229
+ b_factor=float(b_factor),
230
+ )
231
+ atoms.append(atom)
232
+ return bs.array(atoms)
233
+
234
+ @cached_property
235
+ def residue_index_no_insertions(self) -> np.ndarray:
236
+ return self.residue_index + np.cumsum(self.insertion_code != "")
237
+
238
+ @cached_property
239
+ def atom_array_no_insertions(self) -> bs.AtomArray:
240
+ atoms = []
241
+ for res_idx, (res_name, positions, mask, conf) in enumerate(
242
+ zip(
243
+ self.sequence,
244
+ self.atom37_positions,
245
+ self.atom37_mask.astype(bool),
246
+ self.confidence,
247
+ )
248
+ ):
249
+ for i, pos in zip(np.where(mask)[0], positions[mask]):
250
+ b_factor = (
251
+ self.atom37_confidence[res_idx, i]
252
+ if self.atom37_confidence is not None
253
+ else conf
254
+ )
255
+ atom = bs.Atom(
256
+ coord=pos,
257
+ # Hard-coded because we currently support only single-chain structures.
258
+ chain_id=CHAIN_ID_CONST,
259
+ res_id=res_idx + 1,
260
+ res_name=residue_constants.restype_1to3.get(res_name, "UNK"),
261
+ hetero=False,
262
+ atom_name=residue_constants.atom_types[i],
263
+ element=residue_constants.atom_types[i][0],
264
+ b_factor=float(b_factor),
265
+ )
266
+ atoms.append(atom)
267
+ return bs.array(atoms)
268
+
269
+ def __getitem__(self, idx: int | list[int] | slice | np.ndarray | torch.Tensor):
270
+ if isinstance(idx, int):
271
+ idx = [idx]
272
+ if isinstance(idx, torch.Tensor):
273
+ idx = idx.cpu().numpy()
274
+
275
+ sequence = slice_python_object_as_numpy(self.sequence, idx)
276
+ return replace(
277
+ self,
278
+ sequence=sequence,
279
+ residue_index=self.residue_index[..., idx],
280
+ insertion_code=self.insertion_code[..., idx],
281
+ atom37_positions=self.atom37_positions[..., idx, :, :],
282
+ atom37_mask=self.atom37_mask[..., idx, :],
283
+ confidence=self.confidence[..., idx],
284
+ atom37_confidence=self.atom37_confidence[..., idx, :]
285
+ if self.atom37_confidence is not None
286
+ else None,
287
+ )
288
+
289
+ def __len__(self):
290
+ return len(self.sequence)
291
+
292
+ def cbeta_contacts(self, distance_threshold: float = 8.0) -> np.ndarray:
293
+ distance = self.pdist_CB
294
+ contacts = (distance < distance_threshold).astype(np.int64)
295
+ contacts[np.isnan(distance)] = -1
296
+ np.fill_diagonal(contacts, -1)
297
+ return contacts
298
+
299
+ def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
300
+ """DSSP works better without insertions."""
301
+ f = PDBFile()
302
+ if not include_insertions:
303
+ f.set_structure(self.atom_array_no_insertions)
304
+ else:
305
+ f.set_structure(self.atom_array)
306
+ f.write(path)
307
+
308
+ def to_pdb_string(self, include_insertions: bool = True) -> str:
309
+ buf = io.StringIO()
310
+ self.to_pdb(buf, include_insertions=include_insertions)
311
+ buf.seek(0)
312
+ return buf.read()
313
+
314
+ def to_mmcif(self, path: PathOrBuffer):
315
+ f = CIFFile()
316
+ set_structure_pdbx(f, self.atom_array, data_block=self.id)
317
+
318
+ # incantations molstar needs to render pLDDT / confidence onto
319
+ # the structure with "alphafold-view"
320
+ f.block["ma_qa_metric"] = CIFCategory(
321
+ name="ma_qa_metric",
322
+ columns={
323
+ "id": CIFColumn(data=CIFData(array=np.array([1, 2]), dtype=np.int64)),
324
+ "mode": CIFColumn(
325
+ data=CIFData(array=np.array(["global", "local"]), dtype=np.str_)
326
+ ),
327
+ "name": CIFColumn(
328
+ data=CIFData(array=np.array(["pLDDT", "pLDDT"]), dtype=np.str_)
329
+ ),
330
+ },
331
+ )
332
+
333
+ # table is a duplicate of data already in the atom array, but
334
+ # needed by molstar to render pLDDT / confidence
335
+ resid_pldd_table = {
336
+ # Hard-coded because we currently support only single-chain structures.
337
+ "label_asym_id": CIFColumn(
338
+ data=CIFData(
339
+ array=[CHAIN_ID_CONST] * len(self.residue_index), dtype=np.str_
340
+ )
341
+ ),
342
+ "label_comp_id": CIFColumn(
343
+ data=CIFData(
344
+ array=[
345
+ residue_constants.restype_1to3.get(c, "UNK")
346
+ for c in self.sequence
347
+ ],
348
+ dtype=np.str_,
349
+ )
350
+ ),
351
+ "label_seq_id": CIFColumn(
352
+ data=CIFData(array=self.residue_index, dtype=np.int64)
353
+ ),
354
+ "ordinal_id": CIFColumn(
355
+ data=CIFData(array=self.residue_index, dtype=np.int64)
356
+ ),
357
+ # Hard-coded: these are all local pLDDT values.
358
+ "metric_id": CIFColumn(
359
+ data=CIFData(array=["2"] * len(self.residue_index), dtype=np.str_)
360
+ ),
361
+ "metric_value": CIFColumn(
362
+ data=CIFData(array=self.confidence, dtype=np.float32)
363
+ ),
364
+ # Hard-coded: this is the initial version; there are no revisions.
365
+ "model_id": CIFColumn(
366
+ data=CIFData(array=["1"] * len(self.residue_index), dtype=np.str_)
367
+ ),
368
+ }
369
+ f.block["ma_qa_metric_local"] = CIFCategory(
370
+ name="ma_qa_metric_local", columns=resid_pldd_table
371
+ )
372
+ f.write(path)
373
+
374
+ def to_mmcif_string(self) -> str:
375
+ buf = io.StringIO()
376
+ self.to_mmcif(buf)
377
+ buf.seek(0)
378
+ return buf.read()
379
+
380
+ def state_dict(self, backbone_only=False, json_serializable=False):
381
+ """This state dict is optimized for storage, so it turns things to fp16 whenever
382
+ possible. Note that we also only support int32 residue indices, I'm hoping we don't
383
+ need more than 2**31 residues..."""
384
+ dct = {k: v for k, v in asdict(self).items() if k not in ["mmcif"]}
385
+ if backbone_only:
386
+ dct["atom37_mask"][:, 3:] = False
387
+ dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
388
+ if dct.get("atom37_confidence") is not None:
389
+ dct["atom37_confidence"] = dct["atom37_confidence"][dct["atom37_mask"]]
390
+ else:
391
+ dct.pop("atom37_confidence", None)
392
+
393
+ for k, v in dct.items():
394
+ if isinstance(v, np.ndarray):
395
+ match v.dtype:
396
+ case np.int64:
397
+ dct[k] = v.astype(np.int32)
398
+ case np.float64 | np.float32:
399
+ dct[k] = v.astype(np.float16)
400
+ case _:
401
+ pass
402
+ if json_serializable:
403
+ dct[k] = v.tolist()
404
+ return dct
405
+
406
+ def to_blob(self, backbone_only=False) -> bytes:
407
+ return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
408
+
409
+ @classmethod
410
+ def from_open_source(cls, pc: ProteinChain):
411
+ return cls(**vars(pc))
412
+
413
+ @classmethod
414
+ def from_state_dict(cls, dct):
415
+ # Note: assembly_composition is *supposed* to have string keys.
416
+ dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
417
+
418
+ for k, v in dct.items():
419
+ if isinstance(v, list):
420
+ dct[k] = np.array(v)
421
+
422
+ atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan)
423
+ atom37[dct["atom37_mask"]] = dct["atom37_positions"]
424
+ dct["atom37_positions"] = atom37
425
+ if "atom37_confidence" in dct:
426
+ atom37_conf = np.full(dct["atom37_mask"].shape, np.nan, dtype=np.float32)
427
+ atom37_conf[dct["atom37_mask"]] = dct["atom37_confidence"]
428
+ dct["atom37_confidence"] = atom37_conf
429
+ dct = {
430
+ k: (
431
+ v.astype(np.float32)
432
+ if k in ["atom37_positions", "confidence", "atom37_confidence"]
433
+ else v
434
+ )
435
+ for k, v in dct.items()
436
+ if not (k == "atom37_confidence" and v is None)
437
+ }
438
+ return cls(**dct, mmcif=None)
439
+
440
+ @classmethod
441
+ def from_blob(cls, input: Path | str | io.BytesIO | bytes):
442
+ """NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
443
+ of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
444
+ shot at compressing and dumping chains to disk. I'm sure there's better ways."""
445
+ match input:
446
+ case Path() | str():
447
+ bytes = Path(input).read_bytes()
448
+ case io.BytesIO():
449
+ bytes = input.getvalue()
450
+ case _:
451
+ bytes = input
452
+ return cls.from_state_dict(msgpack.loads(brotli.decompress(bytes)))
453
+
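The blob round trip above (store only the masked coordinates, then serialize and compress) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the class's code: `json` and `zlib` stand in for `msgpack` and `brotli`, plain lists stand in for NumPy arrays, and the names `pack_blob` and `unpack_blob` are hypothetical.

```python
import json
import math
import zlib


def pack_blob(positions, mask):
    """Keep only the entries where mask is True, then serialize and compress."""
    packed = [p for p, m in zip(positions, mask) if m]
    payload = json.dumps({"mask": mask, "packed": packed})
    return zlib.compress(payload.encode())


def unpack_blob(blob):
    """Invert pack_blob: scatter packed values back, filling gaps with NaN."""
    state = json.loads(zlib.decompress(blob))
    values = iter(state["packed"])
    dense = [next(values) if m else [math.nan] * 3 for m in state["mask"]]
    return dense, state["mask"]


# Three atoms, the middle one unresolved (hypothetical data).
positions = [[1.0, 2.0, 3.0], [math.nan] * 3, [4.0, 5.0, 6.0]]
mask = [True, False, True]
dense, restored_mask = unpack_blob(pack_blob(positions, mask))
```

Because the unresolved entries never reach the serializer, the NaN padding costs nothing on disk and is rebuilt from the mask on load, which mirrors how `from_state_dict` refills `atom37_positions`.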
454
+ def sasa(self, by_residue: bool = True):
455
+ arr = self.atom_array_no_insertions
456
+ sasa_per_atom = bs.sasa(arr) # type: ignore
457
+ if by_residue:
458
+ # Sum per-atom SASA into residue "bins", with np.bincount.
459
+ assert arr.res_id is not None
460
+ # NOTE(rverkuil): arr.res_id is 1-indexed, but np.bincount returns a sum for bin 0, so we strip.
461
+ # NOTE(aderry): We compute only for residues with coordinates, return NaN otherwise.
462
+ num_trailing_residues = len(self) - arr.res_id.max()
463
+ sasa_per_residue = np.concatenate(
464
+ [
465
+ np.bincount(arr.res_id, weights=sasa_per_atom)[1:],
466
+ np.zeros(num_trailing_residues),
467
+ ]
468
+ )
469
+ sasa_per_residue[~self.atom37_mask.any(-1)] = np.nan
470
+ assert len(sasa_per_residue) == len(self)
471
+ return sasa_per_residue
472
+ return sasa_per_atom
473
+
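The per-residue aggregation in `sasa` can be read as summing per-atom values into 1-indexed residue bins and marking residues without atoms as NaN. The sketch below shows that idea in plain Python instead of `np.bincount`; the helper name `sum_by_residue` and the sample numbers are hypothetical, and "no atoms seen" is used here as a simplified stand-in for the `atom37_mask` check.

```python
import math


def sum_by_residue(res_ids, per_atom_values, num_residues):
    """Sum per-atom values into 1-indexed residue bins; empty bins become NaN."""
    totals = [0.0] * num_residues
    seen = [False] * num_residues
    for res_id, value in zip(res_ids, per_atom_values):
        totals[res_id - 1] += value  # res_id is 1-indexed
        seen[res_id - 1] = True
    return [t if s else math.nan for t, s in zip(totals, seen)]


# Residues 3 and 5 have no atoms, so they come back as NaN.
per_residue = sum_by_residue([1, 1, 2, 4], [10.0, 5.0, 2.5, 1.0], num_residues=5)
```

Trailing residues without coordinates (residue 5 here) are the case the original handles by appending `np.zeros(num_trailing_residues)` before masking.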
474
+ def sap_score(self, aggregation: str = "atom") -> np.ndarray:
475
+ """Computes per-atom SAP score.
476
+ Can optionally aggregate by residue (averaging over atoms; NOTE: this returns values only for residues that have coordinates!)
477
+ or full-protein (sum of SAP score for atoms with SAP > 0, as in Lauer et al. 2011)."""
478
+ sap_radius = 5.0
479
+ arr = self.atom_array_no_insertions
480
+
481
+ # asserts to avoid type errors
482
+ assert arr.res_id is not None
483
+ assert arr.res_name is not None
484
+ assert arr.atom_name is not None
485
+ assert arr.coord is not None
486
+
487
+ # compute SASA and residue-specific properties
488
+ sasa_per_atom = self.sasa(by_residue=False)
489
+ resid_to_resname = dict(zip(arr.res_id, arr.res_name))
490
+
491
+ max_side_chain_asa = np.full(len(self), np.nan)
492
+ res_hydrophobicity = np.full(len(self), np.nan)
493
+ resolved_res_mask = self.atom37_mask.any(-1)
494
+ num_trailing_residues = len(self) - arr.res_id.max()
495
+
496
+ max_side_chain_asa[resolved_res_mask] = np.array(
497
+ [
498
+ residue_constants.side_chain_asa[resid_to_resname[i]]
499
+ for i in np.unique(arr.res_id)
500
+ ]
501
+ )
502
+ res_hydrophobicity[resolved_res_mask] = np.array(
503
+ [
504
+ residue_constants.hydrophobicity[resid_to_resname[i]]
505
+ for i in np.unique(arr.res_id)
506
+ ]
507
+ )
508
+ assert len(max_side_chain_asa) == len(self)
509
+ assert len(res_hydrophobicity) == len(self)
510
+
511
+ # compute SAP score
512
+ is_side_chain = ~bs.filter_peptide_backbone(arr)
513
+ sasa_per_atom[is_side_chain] = 0
514
+ kdtree = KDTree(arr.coord)
515
+ neighbors = kdtree.query_ball_tree(kdtree, sap_radius, p=2.0)
516
+ sap_by_atom = np.zeros_like(sasa_per_atom)
517
+ for i, nn_list in enumerate(neighbors):
518
+ saa_nn = np.zeros_like(sasa_per_atom)
519
+ saa_nn[nn_list] = sasa_per_atom[nn_list]
520
+ sasa_within_r = np.concatenate(
521
+ [
522
+ np.bincount(arr.res_id, weights=saa_nn)[1:],
523
+ np.zeros(num_trailing_residues),
524
+ ]
525
+ )
526
+ sap = np.nansum((sasa_within_r / max_side_chain_asa) * res_hydrophobicity)
527
+ sap_by_atom[i] = sap
528
+
529
+ match aggregation:
530
+ case "atom":
531
+ return sap_by_atom
532
+ case "residue":
533
+ sap_by_residue = np.concatenate(
534
+ [
535
+ np.bincount(arr.res_id, weights=sap_by_atom)[1:],
536
+ np.zeros(num_trailing_residues),
537
+ ]
538
+ ) / (
539
+ np.concatenate(
540
+ [np.bincount(arr.res_id)[1:], np.zeros(num_trailing_residues)]
541
+ )
542
+ + 1e-8
543
+ )
544
+ sap_by_residue[~resolved_res_mask] = np.nan
545
+ assert len(sap_by_residue) == len(self)
546
+ return sap_by_residue
547
+ case "protein":
548
+ return sum(sap_by_atom[sap_by_atom > 0]) # pyright: ignore[reportReturnType]
549
+ case _:
550
+ raise ValueError(
551
+ f"Invalid aggregation method: {aggregation}. Must be one of 'atom', 'residue', or 'protein'"
552
+ )
553
+
554
+ def globularity(self) -> float:
555
+ # Computes globularity using total volumes divided by MVEE.
556
+ # We make the simplifying approximation that atoms never overlap.
557
+ # The globularity is only computed where structure exists.
558
+ # Besides the approximation above, this is inspired by:
559
+
560
+ # https://www.mdpi.com/2073-4352/11/12/1539
561
+ # NOTE(@zeming): due to the approximation we make here, that atoms never overlap, you might get >1 globularity
562
+ mask = self.atom37_mask.any(-1)
563
+ points = self.atom37_positions[self.atom37_mask]
564
+ sequence = [aa for aa, m in zip(self.sequence, mask) if m] # type: ignore
565
+ A, _ = self._mvee(points, tol=1e-3)
566
+ mvee_volume = (4 * np.pi) / (3 * np.sqrt(np.linalg.det(A)))
567
+ volume = sum(residue_constants.amino_acid_volumes[x] for x in sequence)
568
+ ratio = volume / mvee_volume
569
+
570
+ # The paper says you must compare the ellipsoidal profile with T, a measurement of
571
+ # how elongated the ellipsoid is. We want a single number, so we scale the ratio by
572
+ # 1 / max(T, 1), which only penalizes ellipsoids where one radius exceeds the sum of the other two.
573
+ eigenvalues = np.linalg.eigvals(A)
574
+ R = 1 / np.sqrt(eigenvalues)
575
+ # ellipsoid radii length triangle inequality coefficient
576
+ T = max(R[0] / (R[1] + R[2]), R[1] / (R[0] + R[2]), R[2] / (R[0] + R[1]))
577
+ elongation_metric = 1 / max(T, 1)
578
+ return ratio * elongation_metric
579
+
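The elongation correction in `globularity` depends only on the three ellipsoid radii. A minimal standalone sketch of that step follows; the function name `elongation_metric` and the sample radii are hypothetical, while the formula is the one used above.

```python
def elongation_metric(radii):
    """Return 1 / max(T, 1), where T compares each radius to the sum of the other two."""
    a, b, c = radii
    t = max(a / (b + c), b / (a + c), c / (a + b))
    return 1 / max(t, 1)


sphere = elongation_metric((1.0, 1.0, 1.0))   # T = 0.5, so no penalty is applied
needle = elongation_metric((10.0, 1.0, 1.0))  # T = 5.0, so the ratio is scaled down
```

A near-spherical ellipsoid keeps its volume ratio unchanged, while a needle-like one is scaled down in proportion to how far the longest radius exceeds the sum of the other two.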
580
+ @staticmethod
581
+ def _mvee(P: np.ndarray, tol, max_iter=10000):
582
+ # Finds minimum volume enclosing ellipsoid of a set of points.
583
+ # Returns A, c where the ellipse is defined as:
584
+ # (x-c).T @ A @ (x-c) = 1
585
+ hull = ConvexHull(P)
586
+ P = P[hull.vertices]
587
+ P = P.T
588
+
589
+ # Data points
590
+ d, N = P.shape
591
+ Q = np.zeros((d + 1, N))
592
+ Q[:d, :] = P[:d, :N]
593
+ Q[d, :] = np.ones((1, N))
594
+
595
+ # Initializations
596
+ count = 1
597
+ err = 1.0
598
+ u = np.full((N, 1), 1 / N) # 1st iteration
599
+
600
+ # Khachiyan Algorithm
601
+ for i in range(max_iter):
602
+ X = Q.dot(np.diag(u.squeeze())) @ Q.T
603
+ M = np.diag(Q.T @ np.linalg.inv(X) @ Q)
604
+ maximum, j = np.max(M), np.argmax(M)
605
+ step_size = (maximum - d - 1) / ((d + 1) * (maximum - 1))
606
+ new_u = (1 - step_size) * u
607
+ new_u[j] += step_size
608
+ count += 1
609
+ err = np.linalg.norm(new_u - u)
610
+ u = new_u
611
+ if err < tol:
612
+ break
613
+ else:
614
+ raise ValueError("MVEE did not converge")
615
+
616
+ d = P.shape[0] # Fixed: use P.shape[0] instead of P.shape
617
+ U = np.diag(u.squeeze())
618
+
619
+ # The A matrix for the ellipse
620
+ A = (1 / d) * np.linalg.inv(P @ U @ P.T - (P @ u) @ (P @ u).T)
621
+
622
+ # Center of the ellipse
623
+ c = P @ u
624
+
625
+ return A, c
626
+
627
+ def radius_of_gyration(self):
628
+ arr = self.atom_array_no_insertions
629
+ return bs.gyration_radius(arr)
630
+
631
+ def align(
632
+ self,
633
+ target: ProteinChain,
634
+ mobile_inds: list[int] | np.ndarray | None = None,
635
+ target_inds: list[int] | np.ndarray | None = None,
636
+ only_use_backbone: bool = False,
637
+ ):
638
+ """
639
+ Aligns the current protein to the provided target.
640
+
641
+ Args:
642
+ target (ProteinChain): The target protein to align to.
643
+ mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
644
+ target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
645
+ only_use_backbone (bool, optional): If True, only align the backbone atoms.
646
+ """
647
+ aligner = Aligner(
648
+ self if mobile_inds is None else self[mobile_inds],
649
+ target if target_inds is None else target[target_inds],
650
+ only_use_backbone,
651
+ )
652
+
653
+ return aligner.apply(self)
654
+
655
+ def rmsd(
656
+ self,
657
+ target: ProteinChain,
658
+ also_check_reflection: bool = False,
659
+ mobile_inds: list[int] | np.ndarray | None = None,
660
+ target_inds: list[int] | np.ndarray | None = None,
661
+ only_compute_backbone_rmsd: bool = False,
662
+ ):
663
+ """
664
+ Compute the RMSD between this protein chain and another.
665
+
666
+ Args:
667
+ target (ProteinChain): The target (other) protein chain to compare to.
668
+ also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
669
+ mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
670
+ target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
671
+ only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
672
+ """
673
+ if isinstance(target, bs.AtomArray):
674
+ raise ValueError(
675
+ "Support for bs.AtomArray removed, use "
676
+ "ProteinChain.from_atomarray for ProteinChain."
677
+ )
678
+ aligner = Aligner(
679
+ self if mobile_inds is None else self[mobile_inds],
680
+ target if target_inds is None else target[target_inds],
681
+ only_compute_backbone_rmsd,
682
+ )
683
+ avg_rmsd = aligner.rmsd
684
+
685
+ if not also_check_reflection:
686
+ return avg_rmsd
687
+
688
+ aligner = Aligner(
689
+ self if mobile_inds is None else self[mobile_inds],
690
+ target if target_inds is None else target[target_inds],
691
+ only_compute_backbone_rmsd,
692
+ use_reflection=True,
693
+ )
694
+ avg_rmsd_neg = aligner.rmsd
695
+
696
+ return min(avg_rmsd, avg_rmsd_neg)
697
+
698
+ def lddt_ca(
699
+ self,
700
+ native: ProteinChain,
701
+ mobile_inds: list[int] | np.ndarray | None = None,
702
+ target_inds: list[int] | np.ndarray | None = None,
703
+ **kwargs,
704
+ ) -> float | np.ndarray:
705
+ """Compute the LDDT between this protein chain and another. NOTE: LDDT IS NOT SYMMETRIC.
706
+ The call should always be prediction.lddt_ca(native).
707
+
708
+ Arguments:
709
+ native (ProteinChain): The ground truth protein chain
710
+ mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
711
+ target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
712
+
713
+ Returns:
714
+ float | np.ndarray: The LDDT score between the two protein chains, either
715
+ a single float or per-residue LDDT scores if `per_residue` is True.
716
+ """
717
+ lddt = compute_lddt_ca(
718
+ torch.tensor(self.atom37_positions[mobile_inds]).unsqueeze(0),
719
+ torch.tensor(native.atom37_positions[target_inds]).unsqueeze(0),
720
+ torch.tensor(native.atom37_mask[target_inds]).unsqueeze(0),
721
+ **kwargs,
722
+ )
723
+ return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
724
+
725
+ def gdt_ts(
726
+ self,
727
+ target: ProteinChain,
728
+ mobile_inds: list[int] | np.ndarray | None = None,
729
+ target_inds: list[int] | np.ndarray | None = None,
730
+ **kwargs,
731
+ ) -> float | np.ndarray:
732
+ """Compute the GDT_TS between this protein chain and another.
733
+
734
+ Arguments:
735
+ target (ProteinChain): The other protein chain to compare to.
736
+ mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
737
+ target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
738
+
739
+ Returns:
740
+ float: The GDT_TS score between the two protein chains.
741
+ """
742
+ gdt_ts = compute_gdt_ts(
743
+ mobile=torch.tensor(
744
+ index_by_atom_name(self.atom37_positions[mobile_inds], "CA"),
745
+ dtype=torch.float32,
746
+ ).unsqueeze(0),
747
+ target=torch.tensor(
748
+ index_by_atom_name(target.atom37_positions[target_inds], "CA"),
749
+ dtype=torch.float32,
750
+ ).unsqueeze(0),
751
+ atom_exists_mask=torch.tensor(
752
+ index_by_atom_name(self.atom37_mask[mobile_inds], "CA", dim=-1)
753
+ & index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
754
+ ).unsqueeze(0),
755
+ **kwargs,
756
+ )
757
+ return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
758
+
759
+ @classmethod
760
+ def chain_iterable_from_mmcif(
761
+ cls,
762
+ path: PathOrBuffer | MmcifWrapper,
763
+ id: str | None = None,
764
+ is_predicted: bool = False,
765
+ keep_source: bool = False,
766
+ ):
767
+ """Yield a ProteinChain for each protein chain in an mmcif file, i.e. an iterable
768
+ over all protein chains in the file.
769
+ """
770
+ if isinstance(path, MmcifWrapper):
771
+ mmcif = path
772
+ else:
773
+ mmcif = MmcifWrapper.read(path, id)
774
+ for chain in bs.chain_iter(mmcif.structure):
775
+ chain = chain[bs.filter_amino_acids(chain) & ~chain.hetero]
776
+ if len(chain) == 0:
777
+ continue
778
+ chain_id = chain.chain_id[0]
779
+ entity_id = None
780
+ for entity, chains in mmcif.entities.items():
781
+ if chain_id in chains:
782
+ entity_id = entity
783
+ assert entity_id is not None
784
+ (
785
+ sequence,
786
+ atom_positions,
787
+ atom_mask,
788
+ residue_index,
789
+ insertion_code,
790
+ confidence,
791
+ _,
792
+ ) = chain_to_ndarray(chain, mmcif, chain_id, is_predicted)
793
+ assert all(sequence), "Some residue name was not specified correctly"
794
+
795
+ yield cls(
796
+ id=mmcif.id,
797
+ sequence=sequence,
798
+ chain_id=chain_id,
799
+ entity_id=entity_id,
800
+ atom37_positions=atom_positions,
801
+ atom37_mask=atom_mask,
802
+ residue_index=residue_index,
803
+ insertion_code=insertion_code,
804
+ confidence=confidence,
805
+ mmcif=mmcif if keep_source else None,
806
+ )
807
+
808
+ @classmethod
809
+ def from_mmcif(
810
+ cls,
811
+ path: PathOrBuffer | MmcifWrapper,
812
+ chain_id: str | None = None,
813
+ entity_id: int | None = None,
814
+ id: str | None = None,
815
+ is_predicted: bool = False,
816
+ keep_source: bool = False,
817
+ ):
818
+ """Return a ProteinChain object from an mmcif file.
819
+
820
+ Args:
821
+ path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
822
+ id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
823
+ is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
824
+ chain_id (str, optional): Select a chain corresponding to (author) chain id.
825
+ entity_id (int, optional): Select a chain corresponding to a particular entity.
826
+
827
+ If neither `chain_id` nor `entity_id` is specified, defaults to the first entity.
828
+ """
829
+ if isinstance(path, MmcifWrapper):
830
+ mmcif = path
831
+ else:
832
+ mmcif = MmcifWrapper.read(path, id)
833
+
834
+ # If neither chain_id nor entity_id is specified, default to the first entity
835
+ if chain_id is None and entity_id is None:
836
+ if not mmcif.entities:
837
+ raise ValueError("Structure contains no entities")
838
+ entity_id = min(mmcif.entities.keys()) # Pick the first entity by ID
839
+
840
+ if entity_id is not None:
841
+ assert chain_id is None
842
+ if entity_id not in mmcif.entities:
843
+ raise ValueError(
844
+ f"Structure does not contain entity `{entity_id}`. Valid entities: {mmcif.entities.keys()}"
845
+ )
846
+ chains = mmcif.entities[entity_id]
847
+
848
+ # Select the chain id corresponding to the longest chain. If all are equal length, selects the first.
849
+ chain_id = max(
850
+ chains,
851
+ key=lambda chain: _num_non_null_residues(
852
+ mmcif.seqres_to_structure[chain]
853
+ ),
854
+ )
855
+ else:
856
+ assert chain_id is not None
857
+ for entity, chains in mmcif.entities.items():
858
+ if chain_id in chains:
859
+ entity_id = entity
860
+ if entity_id is None:
861
+ warnings.warn(
862
+ "Failed to detect entity_id from mmcif file, it may be malformed."
863
+ )
864
+
865
+ atom_array = mmcif.structure
866
+ (
867
+ sequence,
868
+ atom_positions,
869
+ atom_mask,
870
+ residue_index,
871
+ insertion_code,
872
+ confidence,
873
+ _,
874
+ ) = chain_to_ndarray(atom_array, mmcif, chain_id, is_predicted)
875
+ assert all(sequence), "Some residue name was not specified correctly"
876
+
877
+ return cls(
878
+ id=mmcif.id,
879
+ sequence=sequence,
880
+ chain_id=chain_id,
881
+ entity_id=entity_id,
882
+ atom37_positions=atom_positions,
883
+ atom37_mask=atom_mask.astype(bool),
884
+ residue_index=residue_index,
885
+ insertion_code=insertion_code,
886
+ confidence=confidence,
887
+ mmcif=mmcif if keep_source else None,
888
+ )
889
+
890
+ @classmethod
891
+ def from_atom37(
892
+ cls,
893
+ atom37_positions: np.ndarray | torch.Tensor,
894
+ *,
895
+ id: str | None = None,
896
+ sequence: str | None = None,
897
+ chain_id: str | None = None,
898
+ entity_id: int | None = None,
899
+ residue_index: np.ndarray | torch.Tensor | None = None,
900
+ insertion_code: np.ndarray | None = None,
901
+ confidence: np.ndarray | torch.Tensor | None = None,
902
+ ):
903
+ if isinstance(atom37_positions, torch.Tensor):
904
+ atom37_positions = atom37_positions.cpu().numpy()
905
+ if atom37_positions.ndim == 4:
906
+ if atom37_positions.shape[0] != 1:
907
+ raise ValueError(
908
+ f"Cannot handle batched inputs, atom37_positions has shape {atom37_positions.shape}"
909
+ )
910
+ atom37_positions = atom37_positions[0]
911
+
912
+ assert isinstance(atom37_positions, np.ndarray)
913
+ seqlen = atom37_positions.shape[0]
914
+
915
+ atom_mask = np.isfinite(atom37_positions).all(-1)
916
+
917
+ if id is None:
918
+ id = ""
919
+
920
+ if sequence is None:
921
+ sequence = "A" * seqlen
922
+
923
+ if chain_id is None:
924
+ chain_id = "A"
925
+
926
+ if residue_index is None:
927
+ residue_index = np.arange(1, seqlen + 1)
928
+ elif isinstance(residue_index, torch.Tensor):
929
+ residue_index = residue_index.cpu().numpy()
930
+ assert isinstance(residue_index, np.ndarray)
931
+ if residue_index.ndim == 2:
932
+ if residue_index.shape[0] != 1:
933
+ raise ValueError(
934
+ f"Cannot handle batched inputs, residue_index has shape {residue_index.shape}"
935
+ )
936
+ residue_index = residue_index[0]
937
+ assert isinstance(residue_index, np.ndarray)
938
+
939
+ if insertion_code is None:
940
+ insertion_code = np.array(["" for _ in range(seqlen)])
941
+
942
+ if confidence is None:
943
+ confidence = np.ones(seqlen, dtype=np.float32)
944
+ elif isinstance(confidence, torch.Tensor):
945
+ confidence = confidence.cpu().numpy()
946
+ assert isinstance(confidence, np.ndarray)
947
+ if confidence.ndim == 2:
948
+ if confidence.shape[0] != 1:
949
+ raise ValueError(
950
+ f"Cannot handle batched inputs, confidence has shape {confidence.shape}"
951
+ )
952
+ confidence = confidence[0]
953
+ assert isinstance(confidence, np.ndarray)
954
+
955
+ return cls(
956
+ id=id,
957
+ sequence=sequence, # type: ignore
958
+ chain_id=chain_id,
959
+ entity_id=entity_id,
960
+ atom37_positions=atom37_positions,
961
+ atom37_mask=atom_mask.astype(bool),
962
+ residue_index=residue_index,
963
+ insertion_code=insertion_code,
964
+ confidence=confidence,
965
+ )
966
+
967
+ @classmethod
968
+ def from_backbone_atom_coordinates(
969
+ cls, backbone_atom_coordinates: np.ndarray | torch.Tensor, **kwargs
970
+ ):
971
+ """Create a ProteinChain from a set of backbone atom coordinates.
972
+
973
+ This function simply expands the seqlen x 3 x 3 array of backbone atom
974
+ coordinates to a seqlen x 37 x 3 array of all atom coordinates, with the padded
975
+ positions set to infinity. This allows us to use from_atom37 to create the
976
+ appropriate ProteinChain object with the appropriate atom37_mask.
977
+
978
+ This function passes all kwargs to from_atom37.
979
+ """
980
+ if isinstance(backbone_atom_coordinates, torch.Tensor):
981
+ backbone_atom_coordinates = backbone_atom_coordinates.cpu().numpy()
982
+ if backbone_atom_coordinates.ndim == 4:
983
+ if backbone_atom_coordinates.shape[0] != 1:
984
+ raise ValueError(
985
+ f"Cannot handle batched inputs, backbone_atom_coordinates has "
986
+ f"shape {backbone_atom_coordinates.shape}"
987
+ )
988
+ backbone_atom_coordinates = backbone_atom_coordinates[0]
989
+
990
+ assert isinstance(backbone_atom_coordinates, np.ndarray)
991
+ assert backbone_atom_coordinates.ndim == 3
992
+ assert backbone_atom_coordinates.shape[-2] == 3
993
+ assert backbone_atom_coordinates.shape[-1] == 3
994
+
995
+ atom37_positions = np.full(
996
+ (backbone_atom_coordinates.shape[0], 37, 3),
997
+ np.inf,
998
+ dtype=backbone_atom_coordinates.dtype,
999
+ )
1000
+ atom37_positions[:, :3, :] = backbone_atom_coordinates
1001
+
1002
+ return cls.from_atom37(atom37_positions=atom37_positions, **kwargs)
1003
+
1004
+ @classmethod
1005
+ def from_pdb(
1006
+ cls,
1007
+ path: PathOrBuffer,
1008
+ chain_id: str = "detect",
1009
+ id: str | None = None,
1010
+ is_predicted: bool = False,
1011
+ ) -> "ProteinChain":
1012
+ """Return a ProteinChain object from a pdb file. NOTE: prefer mmcif for rcsb PDB files.
1013
+ This function is mostly to interface with old PDB files and predicted structures -
1014
+ it will not fill out the entity id correctly
1015
+
1016
+ Args:
1017
+ path (str | Path | io.TextIO): Path or buffer to read pdb file from. Should be uncompressed.
1018
+ id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
1019
+ is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
1020
+ chain_id (str, optional): Select a chain corresponding to (author) chain id. "detect" uses the
1021
+ first detected chain
1022
+ """
1023
+
1024
+ if id is not None:
1025
+ file_id = id
1026
+ else:
1027
+ match path:
1028
+ case Path() | str():
1029
+ file_id = Path(path).with_suffix("").name
1030
+ case _:
1031
+ file_id = "null"
1032
+
1033
+ atom_array = PDBFile.read(path).get_structure(
1034
+ model=1, extra_fields=["b_factor"]
1035
+ )
1036
+ if chain_id == "detect":
1037
+ chain_id = atom_array.chain_id[0]
1038
+ atom_array = atom_array[
1039
+ bs.filter_amino_acids(atom_array)
1040
+ & ~atom_array.hetero
1041
+ & (atom_array.chain_id == chain_id)
1042
+ ]
1043
+
1044
+ entity_id = 1 # Not supplied in PDBfiles
1045
+
1046
+ sequence = "".join(
1047
+ residue_constants.restype_3to1.get(monomer[0].res_name, "X")
1048
+ for monomer in bs.residue_iter(atom_array)
1049
+ )
1050
+ num_res = len(sequence)
1051
+
1052
+ atom_positions = np.full(
1053
+ [num_res, residue_constants.atom_type_num, 3], np.nan, dtype=np.float32
1054
+ )
1055
+ atom_mask = np.full(
1056
+ [num_res, residue_constants.atom_type_num], False, dtype=bool
1057
+ )
1058
+ residue_index = np.full([num_res], -1, dtype=np.int64)
1059
+ insertion_code = np.full([num_res], "", dtype="<U4")
1060
+
1061
+ confidence = np.ones([num_res], dtype=np.float32)
1062
+
1063
+ for i, res in enumerate(bs.residue_iter(atom_array)):
1064
+ chain = atom_array[atom_array.chain_id == chain_id]
1065
+ assert isinstance(chain, bs.AtomArray)
1066
+
1067
+ res_index = res[0].res_id
1068
+ residue_index[i] = res_index
1069
+ insertion_code[i] = res[0].ins_code
1070
+
1071
+ # Atom level features
1072
+ for atom in res:
1073
+ atom_name = atom.atom_name
1074
+ if atom_name == "SE" and atom.res_name == "MSE":
1075
+ # Put the coords of the selenium atom in the sulphur column
1076
+ atom_name = "SD"
1077
+
1078
+ if atom_name in residue_constants.atom_order:
1079
+ atom_positions[i, residue_constants.atom_order[atom_name]] = (
1080
+ atom.coord
1081
+ )
1082
+ atom_mask[i, residue_constants.atom_order[atom_name]] = True
1083
+ if is_predicted and atom_name == "CA":
1084
+ confidence[i] = atom.b_factor
1085
+
1086
+ assert all(sequence), "Some residue name was not specified correctly"
1087
+
1088
+ return cls(
1089
+ id=file_id,
1090
+ sequence=sequence,
1091
+ chain_id=chain_id,
1092
+ entity_id=entity_id,
1093
+ atom37_positions=atom_positions,
1094
+ atom37_mask=atom_mask.astype(bool),
1095
+ residue_index=residue_index,
1096
+ insertion_code=insertion_code,
1097
+ confidence=confidence,
1098
+ mmcif=None,
1099
+ )
1100
+
1101
+ @classmethod
1102
+ def from_mds(cls, data: dict[str, Any]) -> "ProteinChain":
1103
+ return cls(
1104
+ id=data["id"],
1105
+ chain_id=data["chain_id"],
1106
+ entity_id=data["entity_id"],
1107
+ sequence=data["sequence"],
1108
+ residue_index=data["residue_index"],
1109
+ insertion_code=np.asarray(data["insertion_code"]),
1110
+ atom37_positions=data["atom37_positions"],
1111
+ atom37_mask=data["atom37_mask"].astype(bool),
1112
+ confidence=data["confidence"],
1113
+ mmcif=None,
1114
+ )
1115
+
1116
+ @classmethod
1117
+ def from_rcsb(
1118
+ cls,
1119
+ pdb_id: str,
1120
+ chain_id: str | None = None,
1121
+ entity_id: int | None = None,
1122
+ keep_source: bool = False,
1123
+ ) -> ProteinChain:
1124
+ f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
1125
+ return cls.from_mmcif(
1126
+ f,
1127
+ id=pdb_id,
1128
+ chain_id=chain_id,
1129
+ entity_id=entity_id,
1130
+ keep_source=keep_source,
1131
+ is_predicted=False,
1132
+ )
1133
+
1134
+ @classmethod
1135
+ def from_atomarray(
1136
+ cls, atom_array: bs.AtomArray, id: str | None = None, is_predicted: bool = False
1137
+ ) -> "ProteinChain":
1138
+ """A simple converter from bs.AtomArray -> ProteinChain.
1139
+ Uses PDB file format as intermediate."""
1140
+ atom_array = atom_array.copy()
1141
+ atom_array.box = None # remove surrounding box, from_pdb won't handle this
1142
+ pdb_file = PDBFile() # pyright: ignore
1143
+ pdb_file.set_structure(atom_array)
1144
+
1145
+ buf = io.StringIO()
1146
+ pdb_file.write(buf)
1147
+ buf.seek(0)
1148
+ return cls.from_pdb(buf, id=id, is_predicted=is_predicted)
1149
+
1150
+ def get_normalization_frame(self) -> Affine3D:
1151
+ """Given a set of coordinates, compute a single frame.
1152
+ Specifically, we compute the average position of the N, CA, and C atoms and use those 3 points to construct a frame using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.
1153
+
1154
+ Returns:
1155
+ Affine3D: [] tensor of Affine3D frame
1156
+ """
1157
+ coords = torch.from_numpy(self.atom37_positions)
1158
+ frame = get_protein_normalization_frame(coords)
1159
+
1160
+ return frame
1161
+
1162
+ def apply_frame(self, frame: Affine3D) -> ProteinChain:
1163
+ """Given a frame, apply the frame to the protein's coordinates.
1164
+
1165
+ Args:
1166
+ frame (Affine3D): [] tensor of Affine3D frame
1167
+
1168
+ Returns:
1169
+ ProteinChain: Transformed protein chain
1170
+ """
1171
+ coords = torch.from_numpy(self.atom37_positions).to(frame.trans.dtype)
1172
+ coords = apply_frame_to_coords(coords, frame)
1173
+ atom37_positions = coords.numpy()
1174
+ return replace(self, atom37_positions=atom37_positions)
1175
+
1176
+ def normalize_coordinates(self) -> ProteinChain:
1177
+ """Normalize the coordinates of the protein chain."""
1178
+ return self.apply_frame(self.get_normalization_frame())
1179
+
1180
+ def infer_oxygen(self) -> ProteinChain:
1181
+ """Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
1182
+ O_missing_indices = np.argwhere(
1183
+ ~np.isfinite(self.atoms["O"]).all(axis=1)
1184
+ ).squeeze()
1185
+
1186
+ O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
1187
+ N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
1188
+ N = torch.roll(N, -3)
1189
+ N[..., -1, :] = torch.nan
1190
+
1191
+ # Get the frame defined by the CA-C-N atom
1192
+ frames = Affine3D.from_graham_schmidt(CA, C, N)
1193
+ O = frames.apply(O_vector)
1194
+ atom37_positions = self.atom37_positions.copy()
1195
+ atom37_mask = self.atom37_mask.copy()
1196
+
1197
+ atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
1198
+ O_missing_indices
1199
+ ].numpy()
1200
+ atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
1201
+ atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
1202
+ ).any(-1)
1203
+ new_chain = replace(
1204
+ self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
1205
+ )
1206
+ return new_chain
1207
+
1208
+ @cached_property
1209
+ def inferred_cbeta(self) -> np.ndarray:
1210
+ """Infer cbeta positions based on N, C, CA."""
1211
+ N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
1212
+ # See usage in trDesign codebase.
1213
+ # https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
1214
+ CB = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
1215
+ return CB
1216
+
1217
+ def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinChain:
1218
+ """Return a new chain with inferred CB atoms at all residues except GLY.
1219
+
1220
+ Args:
1221
+ infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
1222
+ residues, even though that residue doesn't have one. Default off.
1223
+
1224
+ NOTE(rverkuil): The reason for having this switch in the first place
1225
+ is that sometimes we want a (inferred) CB coordinate for every residue,
1226
+ for example for making a pairwise distance matrix, or doing an RMSD
1227
+ calculation between two designs for a given structural template, w/
1228
+ CB atoms.
1229
+ """
1230
+ atom37_positions = self.atom37_positions.copy()
1231
+ atom37_mask = self.atom37_mask.copy()
1232
+
1233
+ inferred_cbeta_positions = self.inferred_cbeta.copy()  # copy: the cached property must not be mutated below
1234
+ if not infer_cbeta_for_glycine:
1235
+ inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
1236
+
1237
+ atom37_positions[:, residue_constants.atom_order["CB"]] = (
1238
+ inferred_cbeta_positions
1239
+ )
1240
+ atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
1241
+ atom37_positions[:, residue_constants.atom_order["CB"]]
1242
+ ).any(-1)
1243
+ new_chain = replace(
1244
+ self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
1245
+ )
1246
+ return new_chain
1247
+
1248
+ @cached_property
1249
+ def pdist_CA(self) -> np.ndarray:
1250
+ CA = self.atoms["CA"]
1251
+ pdist_CA = squareform(pdist(CA))
1252
+ return pdist_CA
1253
+
1254
+ @cached_property
1255
+ def pdist_CB(self) -> np.ndarray:
1256
+ pdist_CB = squareform(pdist(self.inferred_cbeta))
1257
+ return pdist_CB
1258
+
1259
+ @classmethod
1260
+ def as_complex(cls, chains: Sequence[ProteinChain]):
1261
+ raise RuntimeError(
1262
+ ".as_complex() has been deprecated in favor of .concat(). "
1263
+ ".concat() will eventually be deprecated in favor of ProteinComplex..."
1264
+ )
1265
+
1266
+ @classmethod
1267
+ def concat(cls, chains: Sequence[ProteinChain], use_chainbreak: bool = True):
1268
+ sep_tokens = {
1269
+ "residue_index": np.array([-1]),
1270
+ "insertion_code": np.array([""]),
1271
+ "atom37_positions": np.full([1, 37, 3], np.inf),
1272
+ "atom37_mask": np.zeros([1, 37], dtype=bool),
1273
+ "confidence": np.array([0]),
1274
+ }
1275
+
1276
+ def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
1277
+ if use_chainbreak:
1278
+ full_array = []
1279
+ for array in arrays:
1280
+ full_array.append(array)
1281
+ full_array.append(sep)
1282
+ full_array = full_array[:-1]
1283
+ return np.concatenate(full_array, 0)
1284
+ else:
1285
+ return np.concatenate(arrays, 0)
1286
+
1287
+ array_args: dict[str, np.ndarray] = {
1288
+ name: join_arrays([getattr(chain, name) for chain in chains], sep)
1289
+ for name, sep in sep_tokens.items()
1290
+ }
1291
+
1292
+ chain_break = residue_constants.CHAIN_BREAK_TOKEN if use_chainbreak else ""
1293
+ return cls(
1294
+ id=chains[0].id,
1295
+ sequence=chain_break.join(chain.sequence for chain in chains),
1296
+ chain_id="A",
1297
+ entity_id=None,
1298
+ mmcif=None,
1299
+ **array_args,
1300
+ )
1301
+
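The joining logic in `concat` places one separator row between consecutive chains and none after the last. A plain-Python sketch of that behavior follows; the helper name `join_with_separator` and the `"|"` chain-break string are assumptions for illustration (the real token comes from `residue_constants.CHAIN_BREAK_TOKEN`), and lists stand in for NumPy arrays.

```python
def join_with_separator(arrays, sep, use_chainbreak=True):
    """Concatenate sequences, inserting sep between (not after) them if requested."""
    joined = []
    for i, array in enumerate(arrays):
        if use_chainbreak and i > 0:
            joined.extend(sep)
        joined.extend(array)
    return joined


residue_index = join_with_separator([[1, 2], [1, 2, 3]], sep=[-1])
no_break = join_with_separator([[1, 2], [1, 2, 3]], sep=[-1], use_chainbreak=False)

# The sequence is joined the same way, with a hypothetical "|" chain-break token.
sequence = "|".join(["MK", "GAV"])
```

With the separator enabled, every per-residue array and the sequence string grow by the same single position per chain boundary, so they stay aligned.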
1302
+ def find_nonpolymer_contacts(self):
1303
+ assert self.mmcif is not None
1304
+ nonpolymer_and_chain_id_to_array = self.mmcif.non_polymer_coords
1305
+
1306
+ results = []
1307
+ for (
1308
+ nonpolymer,
1309
+ _,
1310
+ ), nonpolymer_array in nonpolymer_and_chain_id_to_array.items():
1311
+ assert nonpolymer_array.coord is not None
1312
+ chain_coords = self.atom37_positions[self.atom37_mask]
1313
+ distance = cdist(nonpolymer_array.coord, chain_coords)
1314
+
1315
+ is_contact = distance < 5
1316
+ if not is_contact.any():
1317
+ continue
1318
+ contacting_atoms = np.where(is_contact.any(0))[0]
1319
+ chain_index = np.where(self.atom37_mask)[0]
1320
+ contacting_residues = np.unique(chain_index[contacting_atoms])
1321
+
1322
+ result = {
1323
+ "ligand": nonpolymer.name,
1324
+ "ligand_id": nonpolymer.comp_id,
1325
+ "contacting_residues": contacting_residues.tolist(),
1326
+ }
1327
+ results.append(result)
1328
+ return results
1329
+
1330
+ def select_residue_indices(
1331
+ self, indices: list[int | str], ignore_x_mismatch: bool = False
1332
+ ) -> ProteinChain:
1333
+ numeric_indices = [
1334
+ idx if isinstance(idx, int) else int(idx[1:]) for idx in indices
1335
+ ]
1336
+ mask = np.isin(self.residue_index, numeric_indices)
1337
+ new = self[mask]
1338
+ mismatches = []
1339
+ for aa, idx in zip(new.sequence, indices):
1340
+ if isinstance(idx, int):
1341
+ continue
1342
+ if aa == "X" and ignore_x_mismatch:
1343
+ continue
1344
+ if aa != idx[0]:
1345
+ mismatches.append((aa, idx))
1346
+ if mismatches:
1347
+ mismatch_str = "; ".join(
1348
+ f"Position {idx[1:]}, Expected: {idx[0]}, Received: {aa}"
1349
+ for aa, idx in mismatches
1350
+ )
1351
+ raise RuntimeError(mismatch_str)
1352
+
1353
+ return new
1354
+
1355
+ def to_structure_encoder_inputs(
1356
+ self,
1357
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1358
+ """Convert protein chain to structure encoder inputs.
1359
+
1360
+ Returns:
1361
+ tuple: (coordinates, plddt, residue_index) where:
1362
+ - coordinates: (1, L, 37, 3) tensor of atom positions
1363
+ - plddt: (1, L) tensor of confidence scores
1364
+ - residue_index: (1, L) tensor of residue indices
1365
+ """
1366
+ # Convert to tensors and add batch dimension
1367
+ coordinates = (
1368
+ torch.from_numpy(self.atom37_positions).float().unsqueeze(0)
1369
+ ) # (1, L, 37, 3)
1370
+ plddt = torch.from_numpy(self.confidence).float().unsqueeze(0) # (1, L)
1371
+ residue_index = (
1372
+ torch.from_numpy(self.residue_index).long().unsqueeze(0)
1373
+ ) # (1, L)
1374
+
1375
+ return coordinates, plddt, residue_index
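The `concat` method above joins per-chain arrays with a one-row separator placed between consecutive chains. A minimal standalone sketch of that joining step follows. The helper name `join_with_separator` and the toy arrays are illustrative assumptions and are not part of the committed file.

```python
import numpy as np


def join_with_separator(arrays, sep):
    """Concatenate arrays along axis 0, placing `sep` between consecutive arrays.

    Hypothetical helper mirroring the nested `join_arrays` in the diff.
    """
    pieces = []
    for array in arrays:
        pieces.append(array)
        pieces.append(sep)
    # Drop the trailing separator so it only appears *between* chains.
    return np.concatenate(pieces[:-1], axis=0)


# Toy residue indices for two chains; -1 marks the chain break position.
chain_a = np.array([1, 2, 3])
chain_b = np.array([10, 11])
joined = join_with_separator([chain_a, chain_b], sep=np.array([-1]))
print(joined.tolist())  # [1, 2, 3, -1, 10, 11]
```

Because every per-residue array receives exactly one separator row per chain break, the joined arrays stay aligned with a sequence string that uses a single chain-break character between chains.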
1376
+
esmfold2_protein_complex.py ADDED
@@ -0,0 +1,1241 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ import itertools
5
+ import random
6
+ import re
7
+ import warnings
8
+ from dataclasses import asdict, dataclass, replace
9
+ from functools import cached_property
10
+ from pathlib import Path
11
+ from subprocess import check_output
12
+ from tempfile import TemporaryDirectory
13
+ from typing import Any, Iterable, Sequence
14
+
15
+ import biotite.structure as bs
16
+ import brotli
17
+ import msgpack
18
+ import msgpack_numpy
19
+ import numpy as np
20
+ import torch
21
+ from biotite.database import rcsb
22
+ from biotite.file import InvalidFileError
23
+ from biotite.structure.io.pdb import PDBFile
24
+ from biotite.structure.io.pdbx import CIFCategory, CIFColumn, CIFData, CIFFile
25
+ from biotite.structure.io.pdbx import set_structure as set_structure_pdbx
26
+ from biotite.structure.io.pdbx.convert import _get_transformations, get_structure
27
+ from biotite.structure.util import matrix_rotate
28
+ from scipy.spatial import KDTree
29
+
30
+ from . import esmfold2_residue_constants as residue_constants
31
+ from .esmfold2_misc import slice_python_object_as_numpy
32
+ from .esmfold2_affine3d import Affine3D
33
+ from .esmfold2_aligner import Aligner
34
+ from .esmfold2_atom_indexer import AtomIndexer
35
+ from .esmfold2_metrics import compute_gdt_ts, compute_lddt_ca
36
+ from .esmfold2_mmcif_parsing import MmcifWrapper, NoProteinError
37
+ from .esmfold2_protein_chain import (
38
+ ProteinChain,
39
+ _str_key_to_int_key,
40
+ chain_to_ndarray,
41
+ index_by_atom_name,
42
+ infer_CB,
43
+ )
44
+ from .esmfold2_utils_types import PathOrBuffer
45
+
46
+ msgpack_numpy.patch()
47
+
48
+ SINGLE_LETTER_CHAIN_IDS = (
49
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
50
+ )
51
+
52
+
53
+ def _parse_operation_expression(expression):
54
+ """
55
+ Get successive operation steps (IDs) for the given
56
+ ``oper_expression``.
57
+ Form the cartesian product, if necessary.
58
+ Copied from biotite and fixed a bug
59
+ """
60
+ # Split groups by parentheses:
61
+ # use the opening parenthesis as delimiter
62
+ # and just remove the closing parenthesis
63
+ expressions_per_step = expression.replace(")", "").split("(")
64
+ expressions_per_step = [e for e in expressions_per_step if len(e) > 0]
65
+ # Important: Operations are applied from right to left
66
+ expressions_per_step.reverse()
67
+
68
+ operations = []
69
+ for expr in expressions_per_step:
70
+ cur_expr = expr.split(",")
71
+ cur_op = []
72
+ # Deal with e='1-10,20-30,40-50' type expressions
73
+ for e in cur_expr:
74
+ if "-" in e:
75
+ first, last = e.split("-")
76
+ cur_op.extend(str(id) for id in range(int(first), int(last) + 1))
77
+ else:
78
+ cur_op.append(e)
79
+ operations.append(cur_op)
80
+
81
+ # Cartesian product of operations
82
+ return list(itertools.product(*operations))
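To make the behaviour of `_parse_operation_expression` concrete, here is a standalone sketch of the same parsing logic applied to a small expression. The function name without the leading underscore and the sample input are illustrative assumptions. The logic follows the diff: groups are reversed because operations apply right to left, ranges such as `1-3` are expanded, and the Cartesian product of the groups is returned.

```python
import itertools


def parse_operation_expression(expression: str) -> list[tuple[str, ...]]:
    """Expand an mmCIF-style operation expression into tuples of operation IDs."""
    # Split on "(" after removing ")" and drop empty fragments.
    steps = [s for s in expression.replace(")", "").split("(") if s]
    # Operations are applied from right to left.
    steps.reverse()

    operations = []
    for step in steps:
        ids = []
        for part in step.split(","):
            if "-" in part:
                first, last = part.split("-")
                ids.extend(str(i) for i in range(int(first), int(last) + 1))
            else:
                ids.append(part)
        operations.append(ids)
    return list(itertools.product(*operations))


result = parse_operation_expression("(1-3)(4,5)")
print(result)
# [('4', '1'), ('4', '2'), ('4', '3'), ('5', '1'), ('5', '2'), ('5', '3')]
```

Each tuple lists the operation IDs in the order they should be applied to the coordinates.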
83
+
84
+
85
+ def _apply_transformations_fast(chains, transformation_dict, operations):
86
+ """
87
+ Get subassembly by applying the given operations to the input
88
+ structure containing affected asym IDs.
89
+ """
90
+ # Additional first dimension for 'structure.repeat()'
91
+ results = []
92
+
93
+ # Apply corresponding transformation for each copy in the assembly
94
+ for c in chains:
95
+ for operation in operations:
96
+ coord = c.atom37_positions.copy()
97
+ # Execute for each transformation step
98
+ # in the operation expression
99
+ for op_step in operation:
100
+ T = transformation_dict[op_step]
101
+ # Rotate
102
+ coord = matrix_rotate(coord, T.rotation)
103
+ # Translate
104
+ coord += T.target_translation
105
+ new_chain = replace(c, atom37_positions=coord)
106
+ results.append(new_chain)
107
+
108
+ return results
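The rotate-then-translate step inside `_apply_transformations_fast` can be illustrated with plain NumPy. This is a sketch under the assumption that the rotation is a 3x3 matrix applied to each coordinate vector, followed by adding a translation vector. It does not call biotite's `matrix_rotate`, and the helper name `apply_rigid_transform` is hypothetical.

```python
import numpy as np


def apply_rigid_transform(coords, rotation, translation):
    """Rotate coordinates of shape (..., 3) by `rotation` (3x3), then translate.

    For row-vector coordinates, `coords @ rotation.T` equals applying
    `rotation` to each coordinate vector.
    """
    return coords @ rotation.T + translation


# 90 degree rotation about the z axis, followed by a shift along x.
rotation = np.array(
    [
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)
translation = np.array([1.0, 0.0, 0.0])

coords = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
moved = apply_rigid_transform(coords, rotation, translation)
print(moved.tolist())  # [[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```

The same function works unchanged on an `(L, 37, 3)` atom37 array because the matrix product broadcasts over the leading dimensions.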
109
+
110
+
111
+ @dataclass
112
+ class ProteinComplexMetadata:
113
+ entity_lookup: dict[int, int]
114
+ chain_lookup: dict[int, str]
115
+ mmcif: MmcifWrapper | None = None
116
+ # This is a dictionary that maps assembly ids to the list of unique chains
117
+ # in that assembly. Allows for usage of `switch_assembly`.
118
+ assembly_composition: dict[str, list[str]] | None = None
119
+
120
+
121
+ @dataclass
122
+ class DockQSingleScore:
123
+ native_chains: tuple[str, str]
124
+ DockQ: float
125
+ interface_rms: float
126
+ ligand_rms: float
127
+ fnat: float
128
+ fnonnat: float
129
+ clashes: float
130
+ F1: float
131
+ DockQ_F1: float
132
+
133
+
134
+ @dataclass
135
+ class DockQResult:
136
+ total_dockq: float
137
+ native_interfaces: int
138
+ chain_mapping: dict[str, str]
139
+ interfaces: dict[tuple[str, str], DockQSingleScore]
140
+ # zip(aligned.chain_iter(), native.chain_iter()) gives you the pairing
141
+ # aligned.rmsd(native) should give you a low rmsd irrespective of shuffling
142
+ aligned: ProteinComplex
143
+ aligned_rmsd: float
144
+
145
+
146
+ @dataclass(frozen=True)
147
+ class ProteinComplex:
148
+ """Dataclass with atom37 representation of an entire protein complex."""
149
+
150
+ id: str
151
+ sequence: str
152
+ entity_id: np.ndarray # entities map to unique sequences
153
+ chain_id: np.ndarray # multiple chains might share an entity id
154
+ sym_id: np.ndarray # complexes might be copies of the same chain
155
+ residue_index: np.ndarray
156
+ insertion_code: np.ndarray
157
+ atom37_positions: np.ndarray
158
+ atom37_mask: np.ndarray
159
+ confidence: np.ndarray
160
+ # This metadata is parsed from the MMCIF file. For synthetic data, we do a best effort.
161
+ metadata: ProteinComplexMetadata
162
+ atom37_confidence: np.ndarray | None = None # [L, 37] per-atom pLDDT
163
+
164
+ def __post_init__(self):
165
+ l = len(self.sequence)
166
+ assert self.atom37_positions.shape[0] == l, (self.atom37_positions.shape, l)
167
+ assert self.atom37_mask.shape[0] == l, (self.atom37_mask.shape, l)
168
+ assert self.residue_index.shape[0] == l, (self.residue_index.shape, l)
169
+ assert self.insertion_code.shape[0] == l, (self.insertion_code.shape, l)
170
+ assert self.confidence.shape[0] == l, (self.confidence.shape, l)
171
+ assert self.entity_id.shape[0] == l, (self.entity_id.shape, l)
172
+ assert self.chain_id.shape[0] == l, (self.chain_id.shape, l)
173
+ assert self.sym_id.shape[0] == l, (self.sym_id.shape, l)
174
+ if self.atom37_confidence is not None:
175
+ assert self.atom37_confidence.shape == self.atom37_mask.shape, (
176
+ self.atom37_confidence.shape,
177
+ self.atom37_mask.shape,
178
+ )
179
+
180
+ def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
181
+ """This function slices protein complexes without consideration of chain breaks
182
+ NOTE: When slicing with a boolean mask, it's possible that the output array won't
183
+ be the expected length. This is because we do our best to preserve chainbreak tokens.
184
+ """
185
+
186
+ if isinstance(idx, int):
187
+ idx = [idx]
188
+ if isinstance(idx, list):
189
+ raise ValueError(
1190
+ "ProteinComplex doesn't support indexing with lists of indices"
191
+ )
192
+
193
+ if isinstance(idx, np.ndarray):
194
+ is_chainbreak = np.asarray([s == "|" for s in self.sequence])
195
+ idx = idx.astype(bool) | is_chainbreak
196
+
197
+ complex = self._unsafe_slice(idx)
198
+ if len(complex) == 0:
199
+ return complex
200
+
201
+ # detect runs of chainbreaks by searching for instances of '||' in complex.sequence
202
+ chainbreak_runs = np.asarray(
203
+ [
204
+ complex.sequence[i : i + 2] == "||"
205
+ for i in range(len(complex.sequence) - 1)
206
+ ]
207
+ + [complex.sequence[-1] == "|"]
208
+ )
209
+ # We should remove as many chainbreaks as possible from the start of the sequence
210
+ for i in range(len(chainbreak_runs)):
211
+ if complex.sequence[i] == "|":
212
+ chainbreak_runs[i] = True
213
+ else:
214
+ break
215
+ complex = complex._unsafe_slice(~chainbreak_runs)
216
+ return complex
217
+
218
+ def _unsafe_slice(self, idx: int | list[int] | slice | np.ndarray):
219
+ sequence = slice_python_object_as_numpy(self.sequence, idx)
220
+ return replace(
221
+ self,
222
+ sequence=sequence,
223
+ entity_id=self.entity_id[..., idx],
224
+ chain_id=self.chain_id[..., idx],
225
+ sym_id=self.sym_id[..., idx],
226
+ residue_index=self.residue_index[..., idx],
227
+ insertion_code=self.insertion_code[..., idx],
228
+ atom37_positions=self.atom37_positions[..., idx, :, :],
229
+ atom37_mask=self.atom37_mask[..., idx, :],
230
+ confidence=self.confidence[..., idx],
231
+ atom37_confidence=self.atom37_confidence[..., idx, :]
232
+ if self.atom37_confidence is not None
233
+ else None,
234
+ )
235
+
236
+ def __len__(self):
237
+ return len(self.sequence)
238
+
239
+ @property
240
+ def num_chains(self):
241
+ return len(self.chain_boundaries)
242
+
243
+ @cached_property
244
+ def atoms(self) -> AtomIndexer:
245
+ return AtomIndexer(self, property="atom37_positions", dim=-2)
246
+
247
+ @cached_property
248
+ def atom_mask(self) -> AtomIndexer:
249
+ return AtomIndexer(self, property="atom37_mask", dim=-1)
250
+
251
+ @cached_property
252
+ def chain_lengths(self) -> np.ndarray:
253
+ return np.diff(self.chain_boundaries, axis=1).flatten()
254
+
255
+ @cached_property
256
+ def chain_boundaries(self) -> list[tuple[int, int]]:
257
+ cb = [-1]
258
+ for i, s in enumerate(self.sequence):
259
+ if s == "|":
260
+ cb.append(i)
261
+ cb.append(len(self))
262
+ return [(cb[i] + 1, cb[i + 1]) for i in range(len(cb) - 1)]
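The `chain_boundaries` property derives half-open `(start, end)` ranges from the positions of the `|` chain-break character. A standalone sketch of that computation on a toy sequence follows. The free function and the sample sequence are illustrative assumptions.

```python
def chain_boundaries(sequence: str, chain_break: str = "|") -> list[tuple[int, int]]:
    """Return half-open (start, end) index ranges for each chain in `sequence`."""
    # A virtual break before the first residue and one after the last residue
    # let every chain be described by two consecutive break positions.
    breaks = [-1]
    breaks.extend(i for i, s in enumerate(sequence) if s == chain_break)
    breaks.append(len(sequence))
    return [(breaks[i] + 1, breaks[i + 1]) for i in range(len(breaks) - 1)]


sequence = "AB|CDE|F"
bounds = chain_boundaries(sequence)
chains = [sequence[start:end] for start, end in bounds]
print(bounds)  # [(0, 2), (3, 6), (7, 8)]
print(chains)  # ['AB', 'CDE', 'F']
```

Slicing with these ranges excludes the break characters themselves, which is why `chain_iter` can pass each slice straight to `as_chain`.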
263
+
264
+ def get_chain_by_index(self, index: int) -> ProteinChain:
265
+ try:
266
+ start, end = self.chain_boundaries[index]
267
+ return self[start:end].as_chain()
268
+ except IndexError:
269
+ raise IndexError(f"Chain index {index} out of bounds")
270
+
271
+ def get_chain_by_id(
272
+ self, chain_id: str, sample_chain_if_duplicate: bool = True
273
+ ) -> ProteinChain:
274
+ valid_indices = [
275
+ index
276
+ for index, id_of_index in self.metadata.chain_lookup.items()
277
+ if id_of_index == chain_id
278
+ ]
279
+ if not valid_indices:
280
+ raise KeyError(f"Chain ID {chain_id} not found")
281
+ if sample_chain_if_duplicate:
282
+ index_to_return = random.choice(valid_indices)
283
+ return self.get_chain_by_index(index_to_return)
284
+ else:
285
+ if len(valid_indices) > 1:
286
+ raise ValueError(f"Multiple chains with chain ID {chain_id} found")
287
+ return self.get_chain_by_index(valid_indices[0])
288
+
289
+ def chain_iter(self) -> Iterable[ProteinChain]:
290
+ for start, end in self.chain_boundaries:
291
+ c = self[start:end]
292
+ yield c.as_chain()
293
+
294
+ def as_chain(self, force_conversion: bool = False) -> ProteinChain:
295
+ """Convert the ProteinComplex to a ProteinChain.
296
+
297
+ Args:
298
+ force_conversion (bool): Forces the conversion into a protein chain even if the complex has multiple chains.
299
+ The purpose of this is to use ProteinChain specific functions (like cbeta_contacts).
300
+
301
+ """
302
+ if not force_conversion:
303
+ assert len(np.unique(self.chain_id)) == 1, f"{self.id}"
304
+ assert len(np.unique(self.entity_id)) == 1, f"{self.id}"
305
+ if self.chain_id[0] not in self.metadata.chain_lookup:
306
+ warnings.warn("Chain ID not found in metadata, using 'A' as default")
307
+ if self.entity_id[0] not in self.metadata.entity_lookup:
308
+ warnings.warn("Entity ID not found in metadata, using None as default")
309
+ chain_id = self.metadata.chain_lookup.get(self.chain_id[0], "A")
310
+ entity_id = self.metadata.entity_lookup.get(self.entity_id[0], None)
311
+ else:
312
+ chain_id = "A"
313
+ entity_id = None
314
+
315
+ return ProteinChain(
316
+ id=self.id,
317
+ sequence=self.sequence,
318
+ chain_id=chain_id,
319
+ entity_id=entity_id,
320
+ atom37_positions=self.atom37_positions,
321
+ atom37_mask=self.atom37_mask,
322
+ residue_index=self.residue_index,
323
+ insertion_code=self.insertion_code,
324
+ confidence=self.confidence,
325
+ mmcif=self.metadata.mmcif,
326
+ atom37_confidence=self.atom37_confidence,
327
+ )
328
+
329
+ @classmethod
330
+ def from_pdb(
331
+ cls, path: PathOrBuffer, id: str | None = None, is_predicted: bool = False
332
+ ) -> "ProteinComplex":
333
+ atom_array = PDBFile.read(path).get_structure(
334
+ model=1, extra_fields=["b_factor"]
335
+ )
336
+
337
+ chains = []
338
+ for chain in bs.chain_iter(atom_array):
339
+ chain = chain[~chain.hetero]
340
+ if len(chain) == 0:
341
+ continue
342
+ chains.append(ProteinChain.from_atomarray(chain, id, is_predicted))
343
+ return ProteinComplex.from_chains(chains)
344
+
345
+ def to_pdb(self, path: PathOrBuffer, include_insertions: bool = True):
346
+ atom_array = None
347
+ for chain in self.chain_iter():
348
+ carr = (
349
+ chain.atom_array
350
+ if include_insertions
351
+ else chain.atom_array_no_insertions
352
+ )
353
+ atom_array = carr if atom_array is None else atom_array + carr
354
+ f = PDBFile()
355
+ f.set_structure(atom_array)
356
+ f.write(path)
357
+
358
+ def to_pdb_string(self, include_insertions: bool = True) -> str:
359
+ buf = io.StringIO()
360
+ self.to_pdb(buf, include_insertions=include_insertions)
361
+ buf.seek(0)
362
+ return buf.read()
363
+
364
+ def normalize_chain_ids_for_pdb(self):
365
+ # Since PDB files have 1-letter chain IDs and don't support the idea of a symmetric index,
366
+ # we can normalize it instead which might be necessary for DockQ and to_pdb.
367
+ ids = SINGLE_LETTER_CHAIN_IDS
368
+ chains = []
369
+ for i, chain in enumerate(self.chain_iter()):
370
+ if i >= len(ids):
1371
+ raise RuntimeError("Too many chains to write to PDB file")
1372
+ chain = replace(chain, chain_id=ids[i])
373
+ chains.append(chain)
374
+
375
+ return ProteinComplex.from_chains(chains)
376
+
377
+ def find_assembly_ids_with_chain(self, id: str) -> list[str]:
378
+ good_chains = []
379
+ if (comp := self.metadata.assembly_composition) is not None:
380
+ for assembly_id, chain_ids in comp.items():
381
+ if id in chain_ids:
382
+ good_chains.append(assembly_id)
383
+ else:
384
+ raise ValueError(
385
+ "Cannot switch assemblies on this ProteinComplex, you must create the assembly from mmcif to support this"
386
+ )
387
+ return good_chains
388
+
389
+ def switch_assembly(self, id: str):
390
+ assert self.metadata.mmcif is not None
391
+ return get_assembly_fast(self.metadata.mmcif, assembly_id=id)
392
+
393
+ def state_dict(self, backbone_only=False, json_serializable=False):
394
+ """This state dict is optimized for storage, so it turns things to fp16 whenever
395
+ possible. Note that we also only support int32 residue indices, so indices must stay
1396
+ below 2**31."""
397
+ dct = {k: v for k, v in vars(self).items()}
398
+ if backbone_only:
399
+ dct["atom37_mask"] = dct["atom37_mask"] & (np.arange(37) < 3)  # new array; leaves self.atom37_mask untouched
400
+ dct["atom37_positions"] = dct["atom37_positions"][dct["atom37_mask"]]
401
+ if dct.get("atom37_confidence") is not None:
402
+ dct["atom37_confidence"] = dct["atom37_confidence"][dct["atom37_mask"]]
403
+ else:
404
+ dct.pop("atom37_confidence", None)
405
+ for k, v in dct.items():
406
+ if isinstance(v, np.ndarray):
407
+ match v.dtype:
408
+ case np.int64:
409
+ dct[k] = v.astype(np.int32)
410
+ case np.float64 | np.float32:
411
+ dct[k] = v.astype(np.float16)
412
+ case _:
413
+ pass
414
+ if json_serializable:
415
+ dct[k] = v.tolist()
416
+ elif isinstance(v, ProteinComplexMetadata):
417
+ dct[k] = asdict(v)
418
+ dct["metadata"]["mmcif"] = None
419
+ # These can be populated with non-serializable objects and are not needed for reconstruction
420
+ dct.pop("atoms", None)
421
+ dct.pop("atom_mask", None)
422
+ dct.pop("per_chain_kd_trees", None)
423
+ return dct
424
+
425
+ def to_blob(self, backbone_only=False) -> bytes:
426
+ return brotli.compress(msgpack.dumps(self.state_dict(backbone_only)), quality=5)
427
+
428
+ @classmethod
429
+ def from_state_dict(cls, dct):
430
+ # Note: assembly_composition is *supposed* to have string keys.
431
+ dct = _str_key_to_int_key(dct, ignore_keys=["assembly_composition"])
432
+
433
+ for k, v in dct.items():
434
+ if isinstance(v, list):
435
+ dct[k] = np.array(v)
436
+
437
+ atom37 = np.full((*dct["atom37_mask"].shape, 3), np.nan)
438
+ atom37[dct["atom37_mask"]] = dct["atom37_positions"]
439
+ dct["atom37_positions"] = atom37
440
+ if "atom37_confidence" in dct:
441
+ atom37_conf = np.full(dct["atom37_mask"].shape, np.nan, dtype=np.float32)
442
+ atom37_conf[dct["atom37_mask"]] = dct["atom37_confidence"]
443
+ dct["atom37_confidence"] = atom37_conf
444
+ dct = {
445
+ k: (
446
+ v.astype(np.float32)
447
+ if k in ["atom37_positions", "confidence", "atom37_confidence"]
448
+ else v
449
+ )
450
+ for k, v in dct.items()
451
+ }
452
+ if "chain_boundaries" in dct:
453
+ del dct["chain_boundaries"]
454
+ if "chain_boundaries" in dct["metadata"]:
455
+ del dct["metadata"]["chain_boundaries"]
456
+ dct["metadata"] = ProteinComplexMetadata(**dct["metadata"])
457
+ return cls(**dct)
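`state_dict` and `from_state_dict` together implement a simple sparse encoding: only the coordinates of atoms present in `atom37_mask` are stored, and the dense array is rebuilt by scattering them back into a NaN-filled array. The sketch below reproduces that round trip with NumPy only. The helper names and the toy data are illustrative assumptions, and the fp16 cast is lossless here only because the sample values are small integers.

```python
import numpy as np


def pack_positions(positions, mask):
    """Keep only coordinates of present atoms, stored as fp16: shape (num_present, 3)."""
    return positions[mask].astype(np.float16)


def unpack_positions(packed, mask):
    """Rebuild the dense (L, 37, 3) array, filling missing atoms with NaN."""
    full = np.full((*mask.shape, 3), np.nan, dtype=np.float32)
    full[mask] = packed
    return full


# Two residues; the first has three atoms present, the second has one.
mask = np.zeros((2, 37), dtype=bool)
mask[0, :3] = True
mask[1, 1] = True

positions = np.full((2, 37, 3), np.nan, dtype=np.float32)
positions[mask] = np.arange(12, dtype=np.float32).reshape(4, 3)

packed = pack_positions(positions, mask)
restored = unpack_positions(packed, mask)
print(packed.shape)  # (4, 3)
```

Real coordinates generally lose some precision in fp16, so the restored array should be treated as an approximation of the original rather than an exact copy.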
458
+
459
+ @classmethod
460
+ def from_blob(cls, input: Path | str | io.BytesIO | bytes):
461
+ """NOTE(@zlin): blob + sparse coding + brotli + fp16 reduces memory
462
+ of chains from 52G/1M chains to 20G/1M chains, I think this is a good first
463
+ shot at compressing and dumping chains to disk. I'm sure there's better ways."""
464
+ match input:
465
+ case Path() | str():
466
+ data = Path(input).read_bytes()
1467
+ case io.BytesIO():
1468
+ data = input.getvalue()
1469
+ case _:
1470
+ data = input
1471
+ return cls.from_state_dict(
1472
+ msgpack.loads(brotli.decompress(data), strict_map_key=False)
473
+ )
474
+
475
+ @classmethod
476
+ def from_rcsb(cls, pdb_id: str, keep_source: bool = False) -> ProteinComplex:
477
+ f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore
478
+ return cls.from_mmcif(f, id=pdb_id, keep_source=keep_source, is_predicted=False)
479
+
480
+ @classmethod
481
+ def from_mmcif(
482
+ cls,
483
+ path: PathOrBuffer,
484
+ id: str | None = None,
485
+ assembly_id: str | None = None,
486
+ is_predicted: bool = False,
487
+ keep_source: bool = False,
488
+ ):
489
+ """Return a ProteinComplex object from an mmcif file.
490
+ TODO(@zeming): there are actually multiple complexes per file, but for ease of implementation,
491
+ we only consider the first defined complex!
492
+
493
+ Args:
494
+ path (str | Path | io.TextIO): Path or buffer to read mmcif file from. Should be uncompressed.
495
+ id (str, optional): String identifier to assign to structure. Will attempt to infer otherwise.
496
+ is_predicted (bool): If True, reads b factor as the confidence readout. Default: False.
497
+ assembly_id (str, optional): Identifier of the assembly to build; forwarded to `get_assembly_fast`.
498
+ """
499
+ mmcif = MmcifWrapper.read(path, id)
500
+ return get_assembly_fast(mmcif, assembly_id=assembly_id)
501
+
502
+ @classmethod
503
+ def from_chains(
504
+ cls,
505
+ chains: Sequence[ProteinChain],
506
+ mmcif: MmcifWrapper | None = None,
507
+ all_assembly_metadata_dictionary: dict[str, list[str]] | None = None,
508
+ ):
509
+ if not chains:
510
+ raise ValueError(
511
+ "Cannot create a ProteinComplex from an empty list of chains"
512
+ )
513
+
514
+ # TODO(roshan): Make a proper protein complex class
515
+ def join_arrays(arrays: Sequence[np.ndarray], sep: np.ndarray):
516
+ full_array = []
517
+ for array in arrays:
518
+ full_array.append(array)
519
+ full_array.append(sep)
520
+ full_array = full_array[:-1]
521
+ return np.concatenate(full_array, 0)
522
+
523
+ sep_tokens = {
524
+ "residue_index": np.array([-1]),
525
+ "insertion_code": np.array([""]),
526
+ "atom37_positions": np.full([1, 37, 3], np.nan),
527
+ "atom37_mask": np.zeros([1, 37], dtype=bool),
528
+ "confidence": np.array([0]),
529
+ }
530
+
531
+ any_has_atom37_conf = any(c.atom37_confidence is not None for c in chains)
532
+ if any_has_atom37_conf:
533
+ sep_tokens["atom37_confidence"] = np.full([1, 37], np.nan, dtype=np.float32)
534
+
535
+ def _get_chain_attr(chain: ProteinChain, name: str) -> np.ndarray:
536
+ val = getattr(chain, name)
537
+ if val is None and name == "atom37_confidence":
538
+ return np.full([len(chain), 37], np.nan, dtype=np.float32)
539
+ return val
540
+
541
+ array_args: dict[str, np.ndarray] = {
542
+ name: join_arrays([_get_chain_attr(chain, name) for chain in chains], sep)
543
+ for name, sep in sep_tokens.items()
544
+ }
545
+
546
+ multimer_arrays = []
547
+ chain2num_max = -1
548
+ chain2num = {}
549
+ ent2num_max = -1
550
+ ent2num = {}
551
+ total_index = 0
552
+ for i, c in enumerate(chains):
553
+ num_res = c.residue_index.shape[0]
554
+ if c.chain_id not in chain2num:
555
+ chain2num[c.chain_id] = (chain2num_max := chain2num_max + 1)
556
+ chain_id_array = np.full([num_res], chain2num[c.chain_id], dtype=np.int64)
557
+
558
+ if c.entity_id is None:
559
+ entity_num = (ent2num_max := ent2num_max + 1)
560
+ else:
561
+ if c.entity_id not in ent2num:
562
+ ent2num[c.entity_id] = (ent2num_max := ent2num_max + 1)
563
+ entity_num = ent2num[c.entity_id]
564
+ entity_id_array = np.full([num_res], entity_num, dtype=np.int64)
565
+
566
+ sym_id_array = np.full([num_res], i, dtype=np.int64)
567
+
568
+ multimer_arrays.append(
569
+ {
570
+ "chain_id": chain_id_array,
571
+ "entity_id": entity_id_array,
572
+ "sym_id": sym_id_array,
573
+ }
574
+ )
575
+
576
+ total_index += num_res + 1
577
+
578
+ sep = np.array([-1])
579
+ update = {
580
+ name: join_arrays([dct[name] for dct in multimer_arrays], sep=sep)
581
+ for name in ["chain_id", "entity_id", "sym_id"]
582
+ }
583
+ array_args.update(update)
584
+
585
+ metadata = ProteinComplexMetadata(
586
+ mmcif=mmcif,
587
+ chain_lookup={v: k for k, v in chain2num.items()},
588
+ entity_lookup={v: k for k, v in ent2num.items()},
589
+ assembly_composition=all_assembly_metadata_dictionary,
590
+ )
591
+
592
+ return cls(
593
+ id=chains[0].id,
594
+ sequence=residue_constants.CHAIN_BREAK_TOKEN.join(
595
+ chain.sequence for chain in chains
596
+ ),
597
+ metadata=metadata,
598
+ **array_args,
599
+ )
600
+
601
+ def infer_oxygen(self) -> ProteinComplex:
602
+ """Oxygen position is fixed given N, CA, C atoms. Infer it if not provided."""
603
+ O_missing_indices = np.argwhere(
604
+ ~np.isfinite(self.atoms["O"]).all(axis=1)
605
+ ).squeeze()
606
+
607
+ O_vector = torch.tensor([0.6240, -1.0613, 0.0103], dtype=torch.float32)
608
+ N, CA, C = torch.from_numpy(self.atoms[["N", "CA", "C"]]).float().unbind(dim=1)
609
+ N = torch.roll(N, -3)
610
+ N[..., -1, :] = torch.nan
611
+
612
+ # Get the frame defined by the CA, C, and N atoms
613
+ frames = Affine3D.from_graham_schmidt(CA, C, N)
614
+ O = frames.apply(O_vector)
615
+ atom37_positions = self.atom37_positions.copy()
616
+ atom37_mask = self.atom37_mask.copy()
617
+
618
+ atom37_positions[O_missing_indices, residue_constants.atom_order["O"]] = O[
619
+ O_missing_indices
620
+ ].numpy()
621
+ atom37_mask[O_missing_indices, residue_constants.atom_order["O"]] = ~np.isnan(
622
+ atom37_positions[O_missing_indices, residue_constants.atom_order["O"]]
623
+ ).any(-1)
624
+ new_chain = replace(
625
+ self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
626
+ )
627
+ return new_chain
628
+
629
+ def infer_cbeta(self, infer_cbeta_for_glycine: bool = False) -> ProteinComplex:
1630
+ """Return a new complex with inferred CB atoms at all residues except GLY.
631
+
632
+ Args:
633
+ infer_cbeta_for_glycine (bool): If True, infers a beta carbon for glycine
634
+ residues, even though that residue doesn't have one. Default off.
635
+
636
+ NOTE(rverkuil): The reason for having this switch in the first place
637
+ is that sometimes we want a (inferred) CB coordinate for every residue,
638
+ for example for making a pairwise distance matrix, or doing an RMSD
639
+ calculation between two designs for a given structural template, w/
640
+ CB atoms.
641
+ """
642
+ atom37_positions = self.atom37_positions.copy()
643
+ atom37_mask = self.atom37_mask.copy()
644
+
645
+ N, CA, C = np.moveaxis(self.atoms[["N", "CA", "C"]], 1, 0)
646
+ # See usage in trDesign codebase.
647
+ # https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L140
648
+ inferred_cbeta_positions = infer_CB(C, N, CA, 1.522, 1.927, -2.143)
649
+ if not infer_cbeta_for_glycine:
650
+ inferred_cbeta_positions[np.array(list(self.sequence)) == "G", :] = np.nan
651
+
652
+ atom37_positions[:, residue_constants.atom_order["CB"]] = (
653
+ inferred_cbeta_positions
654
+ )
655
+ atom37_mask[:, residue_constants.atom_order["CB"]] = ~np.isnan(
656
+ atom37_positions[:, residue_constants.atom_order["CB"]]
657
+ ).any(-1)
658
+ new_chain = replace(
659
+ self, atom37_positions=atom37_positions, atom37_mask=atom37_mask
660
+ )
661
+ return new_chain
662
+
663
+ @classmethod
664
+ def from_open_source(cls, pc: ProteinComplex):
665
+ # TODO(@zeming): deprecated, should delete
666
+ return pc
667
+
668
+ @classmethod
669
+ def concat(cls, objs: list[ProteinComplex]) -> ProteinComplex:
670
+ pdb_ids = [obj.id for obj in objs]
671
+ if len(set(pdb_ids)) > 1:
672
+ raise RuntimeError(
1673
+ "Concatenation of protein complexes across different PDB ids is unsupported"
674
+ )
675
+ return ProteinComplex.from_chains(
676
+ list(itertools.chain.from_iterable(obj.chain_iter() for obj in objs))
677
+ )
678
+
679
+ def _sanity_check_complexes_are_comparable(self, other: ProteinComplex):
680
+ assert len(self) == len(other), "Protein complexes must have the same length"
681
+ assert len(list(self.chain_iter())) == len(
682
+ list(other.chain_iter())
683
+ ), "Protein complexes must have the same number of chains"
684
+
685
+ def rmsd(
686
+ self,
687
+ target: ProteinComplex,
688
+ also_check_reflection: bool = False,
689
+ mobile_inds: list[int] | np.ndarray | None = None,
690
+ target_inds: list[int] | np.ndarray | None = None,
691
+ only_compute_backbone_rmsd: bool = False,
692
+ compute_chain_assignment: bool = True,
693
+ ):
694
+ """
695
+ Compute the RMSD between this protein complex and another.
696
+
697
+ Args:
698
+ target (ProteinComplex): The target (other) protein complex to compare to.
699
+ also_check_reflection (bool, optional): If True, also check if the reflection of the mobile atoms has a lower RMSD.
700
+ mobile_inds (list[int], optional): The indices of the mobile atoms to align. These are NOT residue indices
701
+ target_inds (list[int], optional): The indices of the target atoms to align. These are NOT residue indices
702
+ only_compute_backbone_rmsd (bool, optional): If True, only compute the RMSD of the backbone atoms.
703
+ """
704
+ if compute_chain_assignment:
705
+ aligned = self.dockq(target).aligned
706
+ else:
707
+ aligned = self
708
+
709
+ aligner = Aligner(
710
+ aligned if mobile_inds is None else aligned[mobile_inds],
711
+ target if target_inds is None else target[target_inds],
712
+ only_compute_backbone_rmsd,
713
+ )
714
+ avg_rmsd = aligner.rmsd
715
+
716
+ if not also_check_reflection:
717
+ return avg_rmsd
718
+
719
+ aligner = Aligner(
720
+ aligned if mobile_inds is None else aligned[mobile_inds],
721
+ target if target_inds is None else target[target_inds],
722
+ only_compute_backbone_rmsd,
723
+ use_reflection=True,
724
+ )
725
+ avg_rmsd_neg = aligner.rmsd
726
+
727
+ return min(avg_rmsd, avg_rmsd_neg)
728
+
729
+ def lddt_ca(
730
+ self,
731
+ target: ProteinComplex,
732
+ mobile_inds: list[int] | np.ndarray | None = None,
733
+ target_inds: list[int] | np.ndarray | None = None,
734
+ compute_chain_assignment: bool = True,
735
+ **kwargs,
736
+ ) -> float | np.ndarray:
737
+ """Compute the LDDT between this protein complex and another.
738
+
739
+ Arguments:
740
+ target (ProteinComplex): The other protein complex to compare to.
741
+ mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
742
+ target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
743
+
744
+ Returns:
745
+ float | np.ndarray: The LDDT score between the two protein complexes, either
746
+ a single float or per-residue LDDT scores if `per_residue` is True.
747
+ """
748
+ if compute_chain_assignment:
749
+ aligned = self.dockq(target).aligned
750
+ else:
751
+ aligned = self
752
+ lddt = compute_lddt_ca(
753
+ torch.tensor(aligned.atom37_positions[mobile_inds]).unsqueeze(0),
754
+ torch.tensor(target.atom37_positions[target_inds]).unsqueeze(0),
755
+ torch.tensor(aligned.atom37_mask[mobile_inds]).unsqueeze(0),
756
+ **kwargs,
757
+ )
758
+ return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()
759
+
760
+ def gdt_ts(
761
+ self,
762
+ target: ProteinComplex,
763
+ mobile_inds: list[int] | np.ndarray | None = None,
764
+ target_inds: list[int] | np.ndarray | None = None,
765
+ compute_chain_assignment: bool = True,
766
+ **kwargs,
767
+ ) -> float | np.ndarray:
768
+ """Compute the GDT_TS between this protein complex and another.
769
+
770
+ Arguments:
771
+ target (ProteinComplex): The other protein complex to compare to.
772
+ mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
773
+ target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices
774
+
775
+ Returns:
776
+ float | np.ndarray: The GDT_TS score between the two protein complexes (a float when the result has a single element).
777
+ """
778
+ if compute_chain_assignment:
779
+ aligned = self.dockq(target).aligned
780
+ else:
781
+ aligned = self
782
+ gdt_ts = compute_gdt_ts(
783
+ mobile=torch.tensor(
784
+ index_by_atom_name(aligned.atom37_positions[mobile_inds], "CA"),
785
+ dtype=torch.float32,
786
+ ).unsqueeze(0),
787
+ target=torch.tensor(
788
+ index_by_atom_name(target.atom37_positions[target_inds], "CA"),
789
+ dtype=torch.float32,
790
+ ).unsqueeze(0),
791
+ atom_exists_mask=torch.tensor(
792
+ index_by_atom_name(aligned.atom37_mask[mobile_inds], "CA", dim=-1)
793
+ & index_by_atom_name(target.atom37_mask[target_inds], "CA", dim=-1)
794
+ ).unsqueeze(0),
795
+ **kwargs,
796
+ )
797
+ return float(gdt_ts) if gdt_ts.numel() == 1 else gdt_ts.numpy().flatten()
798
+
799
+ def dockq(self, native: ProteinComplex):
800
+ # This function uses dockqv2 to compute the DockQ score. Because it does a mapping
801
+ # over all possible chains, it's quite slow. Be careful not to use this in an inference loop
802
+ # or something that requires fast scoring. It defaults to 8 CPUs.
803
+ #
804
+ # TODO(@zeming): Because we haven't properly implemented protein complexes for mmcif,
805
+ # if your protein has multi-letter or repeated chain IDs, this will fail. Please call
806
+ # pc = pc.normalize_chain_ids_for_pdb() before calling this function in that case (limit is 62 chains)
807
+
808
+ import shutil  # local import: used to check that the DockQ executable is on PATH
809
+
810
+ if shutil.which("DockQ") is None:
811
+ raise RuntimeError(
812
+ "DockQ is not installed. Please update your environment."
813
+ )
814
+ self._sanity_check_complexes_are_comparable(native)
815
+
816
+ def sanity_check_chain_ids(pc: ProteinComplex):
817
+ ids = []
818
+ for i, chain in enumerate(pc.chain_iter()):
819
+ if i >= len(SINGLE_LETTER_CHAIN_IDS):
820
+ raise ValueError("Too many chains to write to PDB file")
821
+ if len(chain.chain_id) > 1:
822
+ raise ValueError(
823
+ "We only support single-letter chain IDs for DockQ"
824
+ )
825
+ ids.append(chain.chain_id)
826
+ if len(set(ids)) != len(ids):
827
+ raise ValueError(f"Duplicate chain IDs in protein complex: {ids}")
828
+ return ids
829
+
830
+ sanity_check_chain_ids(self)
831
+ sanity_check_chain_ids(native)
832
+
833
+ with TemporaryDirectory() as tdir:
834
+ dir = Path(tdir)
835
+ self.to_pdb(dir / "self.pdb")
836
+ native.to_pdb(dir / "native.pdb")
837
+
838
+ output = check_output(["DockQ", dir / "self.pdb", dir / "native.pdb"])
839
+ lines = output.decode().split("\n")
840
+
841
+ # Remove the header comments
842
+ start_index = next(
843
+ i for i, line in enumerate(lines) if line.startswith("Model")
844
+ )
845
+ lines = lines[start_index:]
846
+
847
+ result = {}
848
+ interfaces = []
849
+ current_interface: dict = {}
850
+
851
+ for line in lines:
852
+ line = line.strip()
853
+ if not line:
854
+ continue
855
+
856
+ if line.startswith("Model :"):
857
+ pass # Tmp pdb file location, it's useless...
858
+ elif line.startswith("Native :"):
859
+ pass # Tmp pdb file location, it's useless...
860
+ elif line.startswith("Total DockQ"):
861
+ total_dockq_match = re.search(
862
+ r"Total DockQ over (\d+) native interfaces: ([\d.]+) with (.*) model:native mapping",
863
+ line,
864
+ )
865
+ if total_dockq_match:
866
+ result["value"] = float(total_dockq_match.group(2))
867
+ result["native interfaces"] = int(total_dockq_match.group(1))
868
+ # DockQ reports the mapping as model:native, so the first part lists this complex's chains.
868
+ self_chains, native_chains = total_dockq_match.group(3).split(":")
869
+ result["mapping"] = dict(zip(self_chains, native_chains))
870
+ else:
871
+ raise RuntimeError(
872
+ "Failed to parse DockQ output, maybe your DockQ version is wrong?"
873
+ )
874
+ elif line.startswith("Native chains:"):
875
+ if current_interface:
876
+ interfaces.append(current_interface)
877
+ current_interface = {
878
+ "Native chains": line.split(":")[1].strip().split(", ")
879
+ }
880
+ elif line.startswith("Model chains:"):
881
+ current_interface["Model chains"] = (
882
+ line.split(":")[1].strip().split(", ")
883
+ )
884
+ elif ":" in line:
885
+ key, value = line.split(":", 1)
886
+ current_interface[key.strip()] = float(value.strip())
887
+
888
+ if current_interface:
889
+ interfaces.append(current_interface)
890
+
891
+ def parse_dict(d: dict[str, Any]) -> DockQSingleScore:
892
+ return DockQSingleScore(
893
+ native_chains=tuple(d["Native chains"]), # type: ignore
894
+ DockQ=float(d["DockQ"]),
895
+ interface_rms=float(d["irms"]),
896
+ ligand_rms=float(d["Lrms"]), # Note the capitalization difference
897
+ fnat=float(d["fnat"]),
898
+ fnonnat=float(d["fnonnat"]),
899
+ clashes=float(d["clashes"]),
900
+ F1=float(d["F1"]),
901
+ DockQ_F1=float(d["DockQ_F1"]),
902
+ )
903
+
904
+ inv_mapping = {v: k for k, v in result["mapping"].items()}
905
+
906
+ self_chain_map = {c.chain_id: c for c in self.chain_iter()}
907
+ realigned = []
908
+ for chain in native.chain_iter():
909
+ realigned.append(self_chain_map[inv_mapping[chain.chain_id]])
910
+
911
+ realigned = ProteinComplex.from_chains(realigned)
912
+ aligner = Aligner(realigned, native)
913
+ realigned = aligner.apply(realigned)
914
+
915
+ result = DockQResult(
916
+ total_dockq=result["value"],
917
+ native_interfaces=result["native interfaces"],
918
+ chain_mapping=result["mapping"],
919
+ interfaces={
920
+ (i["Model chains"][0], i["Model chains"][1]): parse_dict(i)
921
+ for i in interfaces
922
+ },
923
+ aligned=realigned,
924
+ aligned_rmsd=aligner.rmsd,
925
+ )
926
+
927
+ return result
928
+
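The summary-line parsing in `dockq` can be tried in isolation. The following is a minimal sketch that applies the same regular expression to a hand-written sample line. The helper name `parse_total_dockq` and the sample text are assumptions made for illustration; the sample is shaped to match the pattern and is not captured DockQ output.

```python
import re

# Same pattern as the one used in `dockq`.
TOTAL_DOCKQ_RE = re.compile(
    r"Total DockQ over (\d+) native interfaces: ([\d.]+) with (.*) model:native mapping"
)


def parse_total_dockq(line: str) -> dict:
    """Parse one 'Total DockQ' summary line (hypothetical helper)."""
    match = TOTAL_DOCKQ_RE.search(line)
    if match is None:
        raise RuntimeError("Failed to parse DockQ output")
    # Text before the colon lists model chains, text after it lists native chains.
    model_chains, native_chains = match.group(3).split(":")
    return {
        "value": float(match.group(2)),
        "native interfaces": int(match.group(1)),
        "mapping": dict(zip(model_chains, native_chains)),
    }


# Hand-written sample line (an assumption, not real tool output).
sample = "Total DockQ over 1 native interfaces: 0.75 with BA:AB model:native mapping"
parsed = parse_total_dockq(sample)
print(parsed)
```

With this sample, model chain `B` maps to native chain `A` and vice versa, which is the situation where the chain reordering done at the end of `dockq` matters.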
929
+ @cached_property
930
+ def per_chain_kd_trees(self):
931
+ # Iterate over chains, build KDTree for each chain
932
+ kdtrees = []
933
+
934
+ CA = self.atoms["CA"]
935
+
936
+ for start, end in self.chain_boundaries:
937
+ chain_CA = CA[start:end]
938
+ chain_CA = chain_CA[np.isfinite(chain_CA).all(axis=-1)]
939
+ kdtrees.append(KDTree(chain_CA))
940
+
941
+ return kdtrees
942
+
943
+ def chain_adjacency(self, cutoff: float = 8.0) -> np.ndarray:
944
+ # Compute adjacency matrix for protein complex
945
+ num_chains = self.num_chains
946
+ adjacency = np.zeros((num_chains, num_chains), dtype=bool)
947
+ for (i, kdtree), (j, kdtree2) in itertools.combinations(
948
+ enumerate(self.per_chain_kd_trees), 2
949
+ ):
950
+ adj = kdtree.query_ball_tree(kdtree2, cutoff)
951
+ any_is_adjacent = any(len(a) > 0 for a in adj)
952
+ adjacency[i, j] = any_is_adjacent
953
+ adjacency[j, i] = any_is_adjacent
954
+ return adjacency
955
+
956
+ def chain_adjacency_by_index(self, index: int, cutoff: float = 8.0) -> np.ndarray:
957
+ num_chains = len(self.chain_boundaries)
958
+ adjacency = np.zeros(num_chains, dtype=bool)
959
+ for i, kdtree in enumerate(self.per_chain_kd_trees):
960
+ if i == index:
961
+ continue
962
+ adj = kdtree.query_ball_tree(self.per_chain_kd_trees[index], cutoff)
963
+ adjacency[i] = any(len(a) > 0 for a in adj)
964
+ return adjacency
965
+
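`chain_adjacency` marks two chains as adjacent when any pair of their CA atoms lies within `cutoff` angstroms. A brute-force sketch of the same rule, without a KD-tree, may make that definition easier to check on small inputs. The function name and the toy coordinates are assumptions for illustration only.

```python
import itertools
import math


def chain_adjacency_bruteforce(chains, cutoff=8.0):
    """chains: list of chains, each a list of (x, y, z) CA coordinates."""
    n = len(chains)
    adjacency = [[False] * n for _ in range(n)]
    for (i, a), (j, b) in itertools.combinations(enumerate(chains), 2):
        # Adjacent if any CA-CA pair is within the cutoff distance.
        close = any(math.dist(p, q) <= cutoff for p in a for q in b)
        adjacency[i][j] = close
        adjacency[j][i] = close
    return adjacency


toy_chains = [
    [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0)],  # chain 0
    [(0.0, 5.0, 0.0)],                   # chain 1, 5 A from chain 0
    [(100.0, 0.0, 0.0)],                 # chain 2, far from both
]
adj = chain_adjacency_bruteforce(toy_chains)
print(adj)
```

The brute-force version is quadratic in the number of atoms, which is presumably why the class caches one KD-tree per chain instead.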
966
+ def add_prefix_to_chain_ids(self, prefix: str) -> ProteinComplex:
967
+ """Rename all chains in the complex with a given prefix.
968
+
969
+ Args:
970
+ prefix (str): The prefix to use for the new chain IDs. Each chain will be
971
+ named as "{prefix}_{chain_id}".
972
+
973
+ Returns:
974
+ ProteinComplex: A new protein complex with renamed chains.
975
+ """
976
+ new_chains = []
977
+ for chain in self.chain_iter():
978
+ # Create new chain with updated chain_id
979
+ new_chain = replace(chain, chain_id=f"{prefix}_{chain.chain_id}")
980
+ new_chains.append(new_chain)
981
+ return ProteinComplex.from_chains(new_chains)
982
+
983
+ def sasa(self, by_residue: bool = True):
984
+ chain = self.as_chain(force_conversion=True)
985
+ return chain.sasa(by_residue=by_residue)
986
+
987
+ def to_mmcif_string(self) -> str:
988
+ """Convert the ProteinComplex to mmCIF format.
989
+
990
+ Returns:
991
+ str: The mmCIF content as a string.
992
+ """
993
+ # Convert the ProteinComplex to a biotite AtomArray
994
+ # Collect all atoms from all chains
995
+ all_atoms = []
996
+ for chain in self.chain_iter():
997
+ chain_atom_array = chain.atom_array
998
+ # Convert AtomArray to list of atoms and add to collection
999
+ all_atoms.extend(chain_atom_array)
1000
+
1001
+ # Create combined AtomArray from all atoms
1002
+ if not all_atoms:
1003
+ raise ValueError("No atoms found in protein complex")
1004
+
1005
+ atom_array = bs.array(all_atoms)
1006
+
1007
+ # Create CIF file
1008
+ f = CIFFile()
1009
+ set_structure_pdbx(f, atom_array, data_block=self.id)
1010
+
1011
+ # Add entity information for proper mmCIF structure
1012
+ self._add_entity_information(f)
1013
+
1014
+ # Write to string
1015
+ output = io.StringIO()
1016
+ f.write(output)
1017
+ return output.getvalue()
1018
+
1019
+ def _add_entity_information(self, cif_file: CIFFile) -> None:
1020
+ """Add entity, entity_poly, and struct_asym sections to CIF file."""
1021
+
1022
+ # Group chains by sequence to create unique entities
1023
+ entity_map = {} # sequence -> entity_id
1024
+ chain_to_entity = {} # chain_id -> entity_id
1025
+ entity_sequences = {} # entity_id -> sequence
1026
+ entity_id_counter = 1
1027
+
1028
+ for chain in self.chain_iter():
1029
+ sequence = chain.sequence
1030
+ if sequence not in entity_map:
1031
+ entity_map[sequence] = entity_id_counter
1032
+ entity_sequences[entity_id_counter] = sequence
1033
+ entity_id_counter += 1
1034
+ chain_to_entity[chain.chain_id] = entity_map[sequence]
1035
+
1036
+ # Create _entity section
1037
+ entity_ids = []
1038
+ entity_types = []
1039
+ entity_descriptions = []
1040
+
1041
+ for entity_id in sorted(entity_sequences.keys()):
1042
+ entity_ids.append(str(entity_id))
1043
+ entity_types.append("polymer")
1044
+ entity_descriptions.append(f"Protein chain (entity {entity_id})")
1045
+
1046
+ cif_file.block["entity"] = CIFCategory(
1047
+ name="entity",
1048
+ columns={
1049
+ "id": CIFColumn(
1050
+ data=CIFData(array=np.array(entity_ids), dtype=np.str_)
1051
+ ),
1052
+ "type": CIFColumn(
1053
+ data=CIFData(array=np.array(entity_types), dtype=np.str_)
1054
+ ),
1055
+ "pdbx_description": CIFColumn(
1056
+ data=CIFData(array=np.array(entity_descriptions), dtype=np.str_)
1057
+ ),
1058
+ },
1059
+ )
1060
+
1061
+ # Create _entity_poly section
1062
+ poly_entity_ids = []
1063
+ poly_types = []
1064
+ poly_nstd_linkages = []
1065
+ poly_sequences = []
1066
+
1067
+ for entity_id in sorted(entity_sequences.keys()):
1068
+ poly_entity_ids.append(str(entity_id))
1069
+ poly_types.append("polypeptide(L)")
1070
+ poly_nstd_linkages.append("no")
1071
+ poly_sequences.append(entity_sequences[entity_id])
1072
+
1073
+ cif_file.block["entity_poly"] = CIFCategory(
1074
+ name="entity_poly",
1075
+ columns={
1076
+ "entity_id": CIFColumn(
1077
+ data=CIFData(array=np.array(poly_entity_ids), dtype=np.str_)
1078
+ ),
1079
+ "type": CIFColumn(
1080
+ data=CIFData(array=np.array(poly_types), dtype=np.str_)
1081
+ ),
1082
+ "nstd_linkage": CIFColumn(
1083
+ data=CIFData(array=np.array(poly_nstd_linkages), dtype=np.str_)
1084
+ ),
1085
+ "pdbx_seq_one_letter_code": CIFColumn(
1086
+ data=CIFData(array=np.array(poly_sequences), dtype=np.str_)
1087
+ ),
1088
+ },
1089
+ )
1090
+
1091
+ # Create _struct_asym section
1092
+ asym_ids = []
1093
+ asym_entity_ids = []
1094
+ asym_details = []
1095
+
1096
+ for chain in self.chain_iter():
1097
+ asym_ids.append(chain.chain_id)
1098
+ asym_entity_ids.append(str(chain_to_entity[chain.chain_id]))
1099
+ asym_details.append("")
1100
+
1101
+ cif_file.block["struct_asym"] = CIFCategory(
1102
+ name="struct_asym",
1103
+ columns={
1104
+ "id": CIFColumn(data=CIFData(array=np.array(asym_ids), dtype=np.str_)),
1105
+ "entity_id": CIFColumn(
1106
+ data=CIFData(array=np.array(asym_entity_ids), dtype=np.str_)
1107
+ ),
1108
+ "details": CIFColumn(
1109
+ data=CIFData(array=np.array(asym_details), dtype=np.str_)
1110
+ ),
1111
+ },
1112
+ )
1113
+
1114
+
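The entity bookkeeping in `_add_entity_information` groups chains with identical sequences under one entity ID. A minimal standalone sketch of that grouping is shown below; it takes plain `(chain_id, sequence)` pairs rather than `ProteinChain` objects, and the function name is an assumption.

```python
def assign_entities(chains):
    """chains: list of (chain_id, sequence) pairs, in complex order."""
    entity_map = {}       # sequence -> entity_id
    chain_to_entity = {}  # chain_id -> entity_id
    for chain_id, sequence in chains:
        if sequence not in entity_map:
            entity_map[sequence] = len(entity_map) + 1  # entity IDs start at 1
        chain_to_entity[chain_id] = entity_map[sequence]
    return entity_map, chain_to_entity


entity_map, chain_to_entity = assign_entities(
    [("A", "MKT"), ("B", "GSH"), ("C", "MKT")]
)
print(chain_to_entity)
```

Chains `A` and `C` share a sequence, so they end up in the same entity, which is what the `_struct_asym` category then records per chain.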
1115
+ def get_assembly_fast(
1116
+ mmcif: MmcifWrapper,
1117
+ assembly_id=None,
1118
+ model=None,
1119
+ data_block=None,
1120
+ altloc="first",
1121
+ use_author_fields=True,
1122
+ ):
1123
+ pdbx_file = mmcif.raw
1124
+ if pdbx_file is None:
1125
+ raise InvalidFileError("No mmCIF data loaded")
1126
+ assembly_gen_category = pdbx_file.block["pdbx_struct_assembly_gen"]
1127
+ if assembly_gen_category is None:
1128
+ raise InvalidFileError("File has no 'pdbx_struct_assembly_gen' category")
1129
+
1130
+ struct_oper_category = pdbx_file.block["pdbx_struct_oper_list"]
1131
+ if struct_oper_category is None:
1132
+ raise InvalidFileError("File has no 'pdbx_struct_oper_list' category")
1133
+
1134
+ if assembly_id is None:
1135
+ assembly_id = assembly_gen_category["assembly_id"].data.array[0]
1136
+ elif assembly_id not in assembly_gen_category["assembly_id"].data.array:
1137
+ raise KeyError(f"File has no Assembly ID '{assembly_id}'")
1138
+
1139
+ ### Calculate all possible transformations
1140
+ transformations = _get_transformations(struct_oper_category)
1141
+
1142
+ ### Get structure according to additional parameters
1143
+ structure = get_structure(
1144
+ pdbx_file, model, data_block, altloc, ["label_asym_id"], use_author_fields
1145
+ )[0] # type: ignore
1146
+ # TODO(@zeming) This line will remove all non-protein structural elements,
1147
+ # we should remove this when we want to parse these too.
1148
+ structure: bs.AtomArray = structure[
1149
+ bs.filter_amino_acids(structure) & ~structure.hetero # type: ignore
1150
+ ]
1151
+ if len(structure) == 0:
1152
+ raise NoProteinError
1153
+ unique_asym_ids = np.unique(structure.label_asym_id) # type: ignore
1154
+ asym2chain = {}
1155
+ asym2auth = {}
1156
+ for asym_id in unique_asym_ids:
1157
+ sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] # type: ignore
1158
+ chain_id: str = sub_structure[0].chain_id # type: ignore
1159
+ (
1160
+ sequence,
1161
+ atom_positions,
1162
+ atom_mask,
1163
+ residue_index,
1164
+ insertion_code,
1165
+ confidence,
1166
+ entity_id,
1167
+ ) = chain_to_ndarray(sub_structure, mmcif, chain_id, False)
1168
+
1169
+ asym2chain[asym_id] = ProteinChain(
1170
+ id=mmcif.id or "unknown",
1171
+ sequence=sequence,
1172
+ chain_id=chain_id,
1173
+ entity_id=entity_id,
1174
+ atom37_positions=atom_positions,
1175
+ atom37_mask=atom_mask,
1176
+ residue_index=residue_index,
1177
+ insertion_code=insertion_code,
1178
+ confidence=confidence,
1179
+ mmcif=None,
1180
+ )
1181
+ asym2auth[asym_id] = chain_id
1182
+
1183
+ ### Get transformations and apply them to the affected asym IDs
1184
+ assembly = []
1185
+ assembly_id_dict: dict[str, list[str]] = {}
1186
+
1187
+ # Process the target assembly ID
1188
+ for aid, op_expr, asym_id_expr in zip(
1189
+ assembly_gen_category["assembly_id"].data.array,
1190
+ assembly_gen_category["oper_expression"].data.array,
1191
+ assembly_gen_category["asym_id_list"].data.array,
1192
+ ):
1193
+ if aid == assembly_id:
1194
+ # Parse operations and asym IDs for this specific entry
1195
+ operations = _parse_operation_expression(op_expr)
1196
+ asym_ids = asym_id_expr.split(",")
1197
+
1198
+ # Filter affected asym IDs to only protein chains, preserving order
1199
+ sub_structures = [
1200
+ asym2chain[asym_id] for asym_id in asym_ids if asym_id in asym2chain
1201
+ ]
1202
+
1203
+ # Apply transformations
1204
+ sub_assembly = _apply_transformations_fast(
1205
+ sub_structures, transformations, operations
1206
+ )
1207
+ assembly.extend(sub_assembly)
1208
+
1209
+ # Build assembly_id_dict for this entry
1210
+ assembly_id_dict[aid] = assembly_id_dict.get(aid, []) + [
1211
+ asym2auth[id_] for id_ in asym_ids if id_ in asym2auth
1212
+ ]
1213
+
1214
+ if len(assembly) == 0:
1215
+ raise NoProteinError
1216
+ return ProteinComplex.from_chains(assembly, mmcif, assembly_id_dict)
1217
+
1218
+
1219
+ def protein_chain_to_protein_complex(chain: ProteinChain) -> ProteinComplex:
1220
+ if "|" not in chain.sequence:
1221
+ return ProteinComplex.from_chains([chain])
1222
+ chain_breaks = np.array(list(chain.sequence)) == "|"
1223
+ chain_break_inds = np.where(chain_breaks)[0]
1224
+ chain_break_inds = np.concatenate([[0], chain_break_inds, [len(chain)]])
1225
+ chain_break_inds = np.array(list(zip(chain_break_inds[:-1], chain_break_inds[1:])))
1226
+ complex_chains = []
1227
+ for start, end in chain_break_inds:
1228
+ if start != 0:
1229
+ start += 1
1230
+ complex_chains.append(chain[start:end])
1231
+ complex_chains = [
1232
+ ProteinChain.from_atom37(
1233
+ chain.atom37_positions,
1234
+ sequence=chain.sequence,
1235
+ chain_id=SINGLE_LETTER_CHAIN_IDS[i],
1236
+ entity_id=i,
1237
+ )
1238
+ for i, chain in enumerate(complex_chains)
1239
+ ]
1240
+ return ProteinComplex.from_chains(complex_chains)
1241
+
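The slicing logic in `protein_chain_to_protein_complex` can be checked on a plain string. This sketch reproduces only the boundary arithmetic (split at each `|`, then skip the separator); the function name and the toy sequence are assumptions.

```python
def split_on_chain_breaks(sequence: str):
    """Return (start, end) slices for each chain in a '|'-separated sequence."""
    breaks = [i for i, residue in enumerate(sequence) if residue == "|"]
    bounds = [0] + breaks + [len(sequence)]
    segments = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        if start != 0:
            start += 1  # skip the '|' separator itself
        segments.append((start, end))
    return segments


seq = "MKT|GSHM|AA"
segments = split_on_chain_breaks(seq)
pieces = [seq[s:e] for s, e in segments]
print(segments, pieces)
```

As in the original function, a sequence that begins with `|` is not handled specially, so the first segment would be empty in that case.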
esmfold2_protein_structure.py ADDED
@@ -0,0 +1,307 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Tuple, TypeVar
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.amp import autocast # type: ignore
10
+
11
+ from . import esmfold2_residue_constants as residue_constants
12
+ from .esmfold2_misc import unbinpack
13
+ from .esmfold2_affine3d import Affine3D
14
+
15
+ ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
16
+
17
+
18
+ def index_by_atom_name(
19
+ atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
20
+ ) -> ArrayOrTensor:
21
+ squeeze = False
22
+ if isinstance(atom_names, str):
23
+ atom_names = [atom_names]
24
+ squeeze = True
25
+ indices = [residue_constants.atom_order[atom_name] for atom_name in atom_names]
26
+ dim = dim % atom37.ndim
27
+ index = tuple(slice(None) if dim != i else indices for i in range(atom37.ndim))
28
+ result = atom37[index] # type: ignore
29
+ if squeeze:
30
+ result = result.squeeze(dim)
31
+ return result
32
+
33
+
34
+ def infer_cbeta_from_atom37(
35
+ atom37: ArrayOrTensor, L: float = 1.522, A: float = 1.927, D: float = -2.143
36
+ ):
37
+ """
38
+ Inspired by a util in trDesign:
39
+ https://github.com/gjoni/trDesign/blob/f2d5930b472e77bfacc2f437b3966e7a708a8d37/02-GD/utils.py#L92
40
+
41
+ input: atom37, (L)ength, (A)ngle, and (D)ihedral
42
+ output: 4th coord
43
+ """
44
+ N = index_by_atom_name(atom37, "N", dim=-2)
45
+ CA = index_by_atom_name(atom37, "CA", dim=-2)
46
+ C = index_by_atom_name(atom37, "C", dim=-2)
47
+
48
+ if isinstance(atom37, np.ndarray):
49
+
50
+ def normalize(x: ArrayOrTensor):
51
+ return x / np.linalg.norm(x, axis=-1, keepdims=True)
52
+
53
+ cross = np.cross
54
+ else:
55
+ normalize = F.normalize # type: ignore
56
+ cross = torch.cross
57
+
58
+ with np.errstate(invalid="ignore"): # inf - inf = nan is ok here
59
+ vec_nca = N - CA
60
+ vec_nc = N - C
61
+ nca = normalize(vec_nca)
62
+ n = normalize(cross(vec_nc, nca)) # type: ignore
63
+ m = [nca, cross(n, nca), n]
64
+ d = [L * np.cos(A), L * np.sin(A) * np.cos(D), -L * np.sin(A) * np.sin(D)]
65
+ return CA + sum([m * d for m, d in zip(m, d)])
66
+
67
+
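The construction in `infer_cbeta_from_atom37` builds an orthonormal frame from N, CA, and C and places the virtual CB at a fixed bond length, angle, and dihedral in that frame. A dependency-free sketch for a single residue follows. The helper names and the toy backbone coordinates are assumptions; the default constants are the ones from the function signature.

```python
import math


def sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def cross(a, b):
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return tuple(x / norm for x in v)


def infer_cbeta(n, ca, c, length=1.522, angle=1.927, dihedral=-2.143):
    """Place a virtual CB from backbone N, CA, C for one residue."""
    nca = normalize(sub(n, ca))
    normal = normalize(cross(sub(n, c), nca))
    basis = [nca, cross(normal, nca), normal]  # orthonormal frame
    coeffs = [
        length * math.cos(angle),
        length * math.sin(angle) * math.cos(dihedral),
        -length * math.sin(angle) * math.sin(dihedral),
    ]
    return tuple(
        ca[k] + sum(b[k] * w for b, w in zip(basis, coeffs)) for k in range(3)
    )


# Toy backbone coordinates (made up for illustration).
N = (-0.5, 1.4, 0.0)
CA = (0.0, 0.0, 0.0)
C = (1.5, 0.0, 0.0)
CB = infer_cbeta(N, CA, C)
print(CB, math.dist(CA, CB))
```

Because the frame is orthonormal, the CA-CB distance equals `length` regardless of the input geometry, which gives a quick sanity check.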
68
+ @torch.no_grad()
69
+ @autocast("cuda", enabled=False)
70
+ def compute_alignment_tensors(
71
+ mobile: torch.Tensor,
72
+ target: torch.Tensor,
73
+ atom_exists_mask: torch.Tensor | None = None,
74
+ sequence_id: torch.Tensor | None = None,
75
+ ):
76
+ """
77
+ Align two batches of structures with support for masking invalid atoms using PyTorch.
78
+
79
+ Args:
80
+ - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
81
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
82
+ - atom_exists_mask (torch.Tensor, optional): Mask of shape (B, N) marking whether each atom exists
83
+ - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
84
+
85
+ Returns:
86
+ - centered_mobile (torch.Tensor): Batch of coordinates of structure centered mobile (B, N, 3)
87
+ - centroid_mobile (torch.Tensor): Batch of coordinates of mobile centroid (B, 3)
88
+ - centered_target (torch.Tensor): Batch of coordinates of structure centered target (B, N, 3)
89
+ - centroid_target (torch.Tensor): Batch of coordinates of target centroid (B, 3)
90
+ - rotation_matrix (torch.Tensor): Batch of coordinates of rotation matrix (B, 3, 3)
91
+ - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,)
92
+ """
93
+
94
+ # Unpack bin-packed inputs (if any), then check that both batches have the same shape
95
+ if sequence_id is not None:
96
+ mobile = unbinpack(mobile, sequence_id, pad_value=torch.nan)
97
+ target = unbinpack(target, sequence_id, pad_value=torch.nan)
98
+ if atom_exists_mask is not None:
99
+ atom_exists_mask = unbinpack(atom_exists_mask, sequence_id, pad_value=0)
100
+ else:
101
+ atom_exists_mask = torch.isfinite(target).all(-1)
102
+
103
+ assert mobile.shape == target.shape, "Batch structure shapes do not match!"
104
+
105
+ # Number of structures in the batch
106
+ batch_size = mobile.shape[0]
107
+
108
+ # if [B, Nres, Natom, 3], resize
109
+ if mobile.dim() == 4:
110
+ mobile = mobile.view(batch_size, -1, 3)
111
+ if target.dim() == 4:
112
+ target = target.view(batch_size, -1, 3)
113
+ if atom_exists_mask is not None and atom_exists_mask.dim() == 3:
114
+ atom_exists_mask = atom_exists_mask.view(batch_size, -1)
115
+
116
+ # Number of atoms
117
+ num_atoms = mobile.shape[1]
118
+
119
+ # Apply masks if provided
120
+ if atom_exists_mask is not None:
121
+ mobile = mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
122
+ target = target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
123
+ else:
124
+ atom_exists_mask = torch.ones(
125
+ batch_size, num_atoms, dtype=torch.bool, device=mobile.device
126
+ )
127
+
128
+ num_valid_atoms = atom_exists_mask.sum(dim=-1, keepdim=True)
129
+ # Compute centroids for each batch
130
+ centroid_mobile = mobile.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
131
+ centroid_target = target.sum(dim=-2, keepdim=True) / num_valid_atoms.unsqueeze(-1)
132
+
133
+ # Handle potential division by zero if all atoms are invalid in a structure
134
+ centroid_mobile[num_valid_atoms == 0] = 0
135
+ centroid_target[num_valid_atoms == 0] = 0
136
+
137
+ # Center structures by subtracting centroids
138
+ centered_mobile = mobile - centroid_mobile
139
+ centered_target = target - centroid_target
140
+
141
+ centered_mobile = centered_mobile.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
142
+ centered_target = centered_target.masked_fill(~atom_exists_mask.unsqueeze(-1), 0)
143
+
144
+ # Compute covariance matrix for each batch
145
+ covariance_matrix = torch.matmul(centered_mobile.transpose(1, 2), centered_target)
146
+
147
+ # Singular Value Decomposition for each batch
148
+ u, _, v = torch.svd(covariance_matrix)
149
+
150
+ # Calculate rotation matrices for each batch
151
+ rotation_matrix = torch.matmul(u, v.transpose(1, 2))
152
+
153
+ return (
154
+ centered_mobile,
155
+ centroid_mobile,
156
+ centered_target,
157
+ centroid_target,
158
+ rotation_matrix,
159
+ num_valid_atoms,
160
+ )
161
+
162
+
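The centroid step in `compute_alignment_tensors` zeroes out missing atoms and divides by the number of valid atoms, falling back to zero when no atom is valid. A plain-Python sketch of that step for a single structure is given below; the function name and toy data are assumptions.

```python
def masked_centroid(coords, mask):
    """coords: list of (x, y, z); mask: list of bools marking atoms that exist."""
    num_valid = sum(mask)
    if num_valid == 0:
        return (0.0, 0.0, 0.0)  # mirrors the zero fallback for empty structures
    # Masked-out atoms contribute nothing to the sum, as with masked_fill(..., 0).
    totals = [0.0, 0.0, 0.0]
    for point, exists in zip(coords, mask):
        if exists:
            for k in range(3):
                totals[k] += point[k]
    return tuple(t / num_valid for t in totals)


coords = [(1.0, 0.0, 0.0), (3.0, 0.0, 0.0), (float("nan"),) * 3]
mask = [True, True, False]
centroid = masked_centroid(coords, mask)
print(centroid)
```

The NaN entry stands in for the padding that `unbinpack` writes; it never reaches the sum because its mask entry is False.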
163
+ @torch.no_grad()
164
+ @autocast("cuda", enabled=False)
165
+ def compute_rmsd_no_alignment(
166
+ aligned: torch.Tensor,
167
+ target: torch.Tensor,
168
+ num_valid_atoms: torch.Tensor,
169
+ reduction: str = "batch",
170
+ ) -> torch.Tensor:
171
+ """
172
+ Compute RMSD between two batches of structures without alignment.
173
+
174
+ Args:
175
+ - aligned (torch.Tensor): Batch of coordinates of the already-superimposed structure in shape (B, N, 3)
176
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
177
+ - num_valid_atoms (torch.Tensor): Batch of number of valid atoms for alignment (B,)
178
+ - reduction (str): One of "batch", "per_sample", "per_residue".
179
+
180
+ Returns:
181
+
182
+ If reduction == "batch":
183
+ (torch.Tensor): 0-dim, Average Root Mean Square Deviation between the structures for each batch
184
+ If reduction == "per_sample":
185
+ (torch.Tensor): (B,)-dim, Root Mean Square Deviation between the structures for each batch
186
+ If reduction == "per_residue":
187
+ (torch.Tensor): (B, N)-dim, Root Mean Square Deviation between the structures for each residue in the batch
188
+ """
189
+ if reduction not in ("per_residue", "per_sample", "batch"):
190
+ raise ValueError(f"Unrecognized reduction: '{reduction}'")
191
+ # Compute RMSD for each batch
192
+ diff = aligned - target
193
+ if reduction == "per_residue":
194
+ mean_squared_error = diff.square().view(diff.size(0), -1, 9).mean(dim=-1)
195
+ else:
196
+ mean_squared_error = diff.square().sum(dim=(1, 2)) / (
197
+ num_valid_atoms.squeeze(-1)
198
+ )
199
+
200
+ rmsd = torch.sqrt(mean_squared_error)
201
+ if reduction in ("per_sample", "per_residue"):
202
+ return rmsd
203
+ elif reduction == "batch":
204
+ avg_rmsd = rmsd.masked_fill(num_valid_atoms.squeeze(-1) == 0, 0).sum() / (
205
+ (num_valid_atoms > 0).sum() + 1e-8
206
+ )
207
+ return avg_rmsd
208
+ else:
209
+ raise ValueError(reduction)
210
+
211
+
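The `"per_sample"` and `"batch"` reductions of `compute_rmsd_no_alignment` differ only in the final averaging step. The sketch below restates both on nested Python lists, assuming masked atoms were already zeroed in both inputs as the alignment code does. Function names and data are illustrative assumptions.

```python
import math


def rmsd_per_sample(aligned, target, num_valid_atoms):
    """aligned/target: [B][N][3] nested lists; masked atoms are already zeroed."""
    out = []
    for a_s, t_s, n in zip(aligned, target, num_valid_atoms):
        sq = sum((x - y) ** 2 for p, q in zip(a_s, t_s) for x, y in zip(p, q))
        out.append(math.sqrt(sq / n) if n > 0 else float("nan"))
    return out


def rmsd_batch(aligned, target, num_valid_atoms):
    """Average per-sample RMSDs over samples with at least one valid atom."""
    per_sample = rmsd_per_sample(aligned, target, num_valid_atoms)
    valid = [r for r, n in zip(per_sample, num_valid_atoms) if n > 0]
    return sum(valid) / (len(valid) + 1e-8)  # same epsilon as the tensor version


aligned = [
    [(0.0, 0.0, 0.0), (0.0, 0.0, 0.0)],
    [(0.0, 0.0, 0.0), (0.0, 0.0, 0.0)],
]
target = [
    [(1.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
    [(0.0, 3.0, 0.0), (0.0, 0.0, 0.0)],  # second atom is masked (zeroed)
]
num_valid = [2, 1]
per_sample = rmsd_per_sample(aligned, target, num_valid)
batch = rmsd_batch(aligned, target, num_valid)
print(per_sample, batch)
```

Note that the batch value is a mean of per-sample RMSDs, not an RMSD over all atoms pooled together, so short and long structures carry equal weight.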
212
+ @torch.no_grad()
213
+ @autocast("cuda", enabled=False)
214
+ def compute_affine_and_rmsd(
215
+ mobile: torch.Tensor,
216
+ target: torch.Tensor,
217
+ atom_exists_mask: torch.Tensor | None = None,
218
+ sequence_id: torch.Tensor | None = None,
219
+ ) -> Tuple[Affine3D, torch.Tensor]:
220
+ """
221
+ Compute RMSD between two batches of structures with support for masking invalid atoms using PyTorch.
222
+
223
+ Args:
224
+ - mobile (torch.Tensor): Batch of coordinates of structure to be superimposed in shape (B, N, 3)
225
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
226
+ - atom_exists_mask (torch.Tensor, optional): Mask of shape (B, N) marking whether each atom exists
227
+ - sequence_id (torch.Tensor, optional): Sequence id tensor for binpacking.
228
+
229
+ Returns:
230
+ - affine (Affine3D): Transformation between mobile and target structure
231
+ - avg_rmsd (torch.Tensor): Average Root Mean Square Deviation between the structures for each batch
232
+ """
233
+
234
+ (
235
+ centered_mobile,
236
+ centroid_mobile,
237
+ centered_target,
238
+ centroid_target,
239
+ rotation_matrix,
240
+ num_valid_atoms,
241
+ ) = compute_alignment_tensors(
242
+ mobile=mobile,
243
+ target=target,
244
+ atom_exists_mask=atom_exists_mask,
245
+ sequence_id=sequence_id,
246
+ )
247
+
248
+ # Apply rotation to mobile centroid
249
+ translation = torch.matmul(-centroid_mobile, rotation_matrix) + centroid_target
250
+ affine = Affine3D.from_tensor_pair(
251
+ translation, rotation_matrix.unsqueeze(dim=-3).transpose(-2, -1)
252
+ )
253
+
254
+ # Apply transformation to centered structure to compute rmsd
255
+ rotated_mobile = torch.matmul(centered_mobile, rotation_matrix)
256
+ avg_rmsd = compute_rmsd_no_alignment(
257
+ rotated_mobile, centered_target, num_valid_atoms, reduction="batch"
258
+ )
259
+
260
+ return affine, avg_rmsd
261
+
262
+
263
+ def compute_gdt_ts_no_alignment(
264
+ aligned: torch.Tensor,
265
+ target: torch.Tensor,
266
+ atom_exists_mask: torch.Tensor,
267
+ reduction: str = "batch",
268
+ ) -> torch.Tensor:
269
+ """
270
+ Compute GDT_TS between two batches of structures without alignment.
271
+
272
+ Args:
273
+ - aligned (torch.Tensor): Batch of coordinates of the already-superimposed structure in shape (B, N, 3)
274
+ - target (torch.Tensor): Batch of coordinates of structure that is fixed in shape (B, N, 3)
275
+ - atom_exists_mask (torch.Tensor): Mask of shape (B, N) marking whether each atom exists.
276
+ - reduction (str): One of "batch", "per_sample".
277
+
278
+ Returns:
279
+ If reduction == "batch":
280
+ (torch.Tensor): 0-dim, GDT_TS averaged over the samples in the batch
281
+ If reduction == "per_sample":
282
+ (torch.Tensor): (B,)-dim, GDT_TS between the structures for each sample in the batch
283
+ """
284
+ if reduction not in ("per_sample", "batch"):
285
+ raise ValueError(f"Unrecognized reduction: '{reduction}'")
286
+
287
+ if atom_exists_mask is None:
288
+ atom_exists_mask = torch.isfinite(target).all(dim=-1)
289
+
290
+ deviation = torch.linalg.vector_norm(aligned - target, dim=-1)
291
+ num_valid_atoms = atom_exists_mask.sum(dim=-1)
292
+
293
+ # Compute GDT_TS
294
+ score = (
295
+ ((deviation < 1) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
296
+ + ((deviation < 2) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
297
+ + ((deviation < 4) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
298
+ + ((deviation < 8) * atom_exists_mask).sum(dim=-1) / num_valid_atoms
299
+ ) * 0.25
300
+
301
+ if reduction == "batch":
302
+ return score.mean()
303
+ elif reduction == "per_sample":
304
+ return score
305
+ else:
306
+ raise ValueError(f"Unrecognized reduction: '{reduction}'")
307
+
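The score in `compute_gdt_ts_no_alignment` is the mean, over the cutoffs 1, 2, 4, and 8 angstroms, of the fraction of valid atoms whose deviation is below the cutoff. A small single-structure sketch with hand-picked deviations is shown below; the function name and the data are assumptions, and the structures are taken to be superimposed already.

```python
import math


def gdt_ts(aligned, target, mask):
    """aligned/target: lists of (x, y, z) for one structure; mask: list of bools."""
    num_valid = sum(mask)
    deviations = [
        math.dist(p, q) for p, q, exists in zip(aligned, target, mask) if exists
    ]
    # Fraction of valid atoms under each cutoff, averaged over the four cutoffs.
    fractions = [
        sum(d < cutoff for d in deviations) / num_valid for cutoff in (1, 2, 4, 8)
    ]
    return 0.25 * sum(fractions)


target = [(0.0, 0.0, 0.0)] * 4
aligned = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
mask = [True, True, True, True]
score = gdt_ts(aligned, target, mask)
print(score)
```

Like the tensor version, this sketch divides by the number of valid atoms without a guard, so a structure with no valid atoms needs to be filtered out beforehand.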
esmfold2_residue_constants.py ADDED
@@ -0,0 +1,1224 @@
1
+ # Copyright 2025 EvolutionaryScale
2
+ # Copyright 2021 AlQuraishi Laboratory
3
+ # Copyright 2021 DeepMind Technologies Limited
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Constants used in AlphaFold."""
18
+
19
+ import collections
20
+ import functools
21
+ from pathlib import Path
22
+ from typing import List, Mapping, Tuple
23
+
24
+ import numpy as np
25
+
26
+ # import tree
27
+
28
+ # Internal import (35fd).
29
+
30
+
31
+ # Distance from one CA to next CA [trans configuration: omega = 180].
32
+ ca_ca = 3.80209737096
33
+
34
+ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
35
+ # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
36
+ # chi angles so their chi angle lists are empty.
37
+ chi_angles_atoms = {
38
+ "ALA": [],
39
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
40
+ "ARG": [
41
+ ["N", "CA", "CB", "CG"],
42
+ ["CA", "CB", "CG", "CD"],
43
+ ["CB", "CG", "CD", "NE"],
44
+ ["CG", "CD", "NE", "CZ"],
45
+ ],
46
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
47
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
48
+ "CYS": [["N", "CA", "CB", "SG"]],
49
+ "GLN": [
50
+ ["N", "CA", "CB", "CG"],
51
+ ["CA", "CB", "CG", "CD"],
52
+ ["CB", "CG", "CD", "OE1"],
53
+ ],
54
+ "GLU": [
55
+ ["N", "CA", "CB", "CG"],
56
+ ["CA", "CB", "CG", "CD"],
57
+ ["CB", "CG", "CD", "OE1"],
58
+ ],
59
+ "GLY": [],
60
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
61
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
62
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
63
+ "LYS": [
64
+ ["N", "CA", "CB", "CG"],
65
+ ["CA", "CB", "CG", "CD"],
66
+ ["CB", "CG", "CD", "CE"],
67
+ ["CG", "CD", "CE", "NZ"],
68
+ ],
69
+ "MET": [
70
+ ["N", "CA", "CB", "CG"],
71
+ ["CA", "CB", "CG", "SD"],
72
+ ["CB", "CG", "SD", "CE"],
73
+ ],
74
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
75
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
76
+ "SER": [["N", "CA", "CB", "OG"]],
77
+ "THR": [["N", "CA", "CB", "OG1"]],
78
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
79
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
80
+ "VAL": [["N", "CA", "CB", "CG1"]],
81
+ "UNK": [],
82
+ }
83
+
84
+ # If chi angles given in fixed-length array, this matrix determines how to mask
85
+ # them for each AA type. The order is as per restype_order (see below).
86
+ chi_angles_mask = [
87
+ [0.0, 0.0, 0.0, 0.0], # ALA
88
+ [1.0, 1.0, 1.0, 1.0], # ARG
89
+ [1.0, 1.0, 0.0, 0.0], # ASN
90
+ [1.0, 1.0, 0.0, 0.0], # ASP
91
+ [1.0, 0.0, 0.0, 0.0], # CYS
92
+ [1.0, 1.0, 1.0, 0.0], # GLN
93
+ [1.0, 1.0, 1.0, 0.0], # GLU
94
+ [0.0, 0.0, 0.0, 0.0], # GLY
95
+ [1.0, 1.0, 0.0, 0.0], # HIS
96
+ [1.0, 1.0, 0.0, 0.0], # ILE
97
+ [1.0, 1.0, 0.0, 0.0], # LEU
98
+ [1.0, 1.0, 1.0, 1.0], # LYS
99
+ [1.0, 1.0, 1.0, 0.0], # MET
100
+ [1.0, 1.0, 0.0, 0.0], # PHE
101
+ [1.0, 1.0, 0.0, 0.0], # PRO
102
+ [1.0, 0.0, 0.0, 0.0], # SER
103
+ [1.0, 0.0, 0.0, 0.0], # THR
104
+ [1.0, 1.0, 0.0, 0.0], # TRP
105
+ [1.0, 1.0, 0.0, 0.0], # TYR
106
+ [1.0, 0.0, 0.0, 0.0], # VAL
107
+ [0.0, 0.0, 0.0, 0.0], # UNK
108
+ ]
109
+
110
+ # The following chi angles are pi periodic: they can be rotated by a multiple
111
+ # of pi without affecting the structure.
112
+ chi_pi_periodic = [
113
+ [0.0, 0.0, 0.0, 0.0], # ALA
114
+ [0.0, 0.0, 0.0, 0.0], # ARG
115
+ [0.0, 0.0, 0.0, 0.0], # ASN
116
+ [0.0, 1.0, 0.0, 0.0], # ASP
117
+ [0.0, 0.0, 0.0, 0.0], # CYS
118
+ [0.0, 0.0, 0.0, 0.0], # GLN
119
+ [0.0, 0.0, 1.0, 0.0], # GLU
120
+ [0.0, 0.0, 0.0, 0.0], # GLY
121
+ [0.0, 0.0, 0.0, 0.0], # HIS
122
+ [0.0, 0.0, 0.0, 0.0], # ILE
123
+ [0.0, 0.0, 0.0, 0.0], # LEU
124
+ [0.0, 0.0, 0.0, 0.0], # LYS
125
+ [0.0, 0.0, 0.0, 0.0], # MET
126
+ [0.0, 1.0, 0.0, 0.0], # PHE
127
+ [0.0, 0.0, 0.0, 0.0], # PRO
128
+ [0.0, 0.0, 0.0, 0.0], # SER
129
+ [0.0, 0.0, 0.0, 0.0], # THR
130
+ [0.0, 0.0, 0.0, 0.0], # TRP
131
+ [0.0, 1.0, 0.0, 0.0], # TYR
132
+ [0.0, 0.0, 0.0, 0.0], # VAL
133
+ [0.0, 0.0, 0.0, 0.0], # UNK
134
+ ]
135
+
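A minimal sketch of how a pi-periodic mask such as `chi_pi_periodic` could be used when comparing chi angles. The helper name `chi_angle_error` is hypothetical and not part of this file; the ASP row is copied from the table above.

```python
import numpy as np


def chi_angle_error(pred, target, pi_periodic):
    """Smallest absolute angular difference in radians (hypothetical helper).

    Where pi_periodic is 1, angles that differ by pi count as identical;
    elsewhere the usual 2*pi period applies.
    """
    period = np.where(pi_periodic > 0, np.pi, 2 * np.pi)
    diff = np.abs(pred - target) % period
    return np.minimum(diff, period - diff)


# ASP row of chi_pi_periodic: only chi2 is pi-periodic.
asp_periodic = np.array([0.0, 1.0, 0.0, 0.0])
pred = np.array([0.1, 0.2 + np.pi, 0.0, 0.0])
target = np.array([0.1, 0.2, 0.0, 0.0])

err_periodic = chi_angle_error(pred, target, asp_periodic)
err_plain = chi_angle_error(pred, target, np.zeros(4))
```

With the mask applied, the chi2 flip by pi contributes no error; without it, the same flip is scored as a difference of pi.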
136
+ # Atom positions relative to the 8 rigid groups, defined by the pre-omega, phi,
137
+ # psi and chi angles:
138
+ # 0: 'backbone group',
139
+ # 1: 'pre-omega-group', (empty)
140
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
141
+ # 3: 'psi-group',
142
+ # 4,5,6,7: 'chi1,2,3,4-group'
143
+ # The atom positions are relative to the axis-end-atom of the corresponding
144
+ # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
145
+ # is defined such that the dihedral-angle-defining atom (the last entry in
146
+ # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
147
+ # format: [atomname, group_idx, rel_position]
148
+ rigid_group_atom_positions = {
149
+ "ALA": [
150
+ ["N", 0, (-0.525, 1.363, 0.000)],
151
+ ["CA", 0, (0.000, 0.000, 0.000)],
152
+ ["C", 0, (1.526, -0.000, -0.000)],
153
+ ["CB", 0, (-0.529, -0.774, -1.205)],
154
+ ["O", 3, (0.627, 1.062, 0.000)],
155
+ ],
156
+ "ARG": [
157
+ ["N", 0, (-0.524, 1.362, -0.000)],
158
+ ["CA", 0, (0.000, 0.000, 0.000)],
159
+ ["C", 0, (1.525, -0.000, -0.000)],
160
+ ["CB", 0, (-0.524, -0.778, -1.209)],
161
+ ["O", 3, (0.626, 1.062, 0.000)],
162
+ ["CG", 4, (0.616, 1.390, -0.000)],
163
+ ["CD", 5, (0.564, 1.414, 0.000)],
164
+ ["NE", 6, (0.539, 1.357, -0.000)],
165
+ ["NH1", 7, (0.206, 2.301, 0.000)],
166
+ ["NH2", 7, (2.078, 0.978, -0.000)],
167
+ ["CZ", 7, (0.758, 1.093, -0.000)],
168
+ ],
169
+ "ASN": [
170
+ ["N", 0, (-0.536, 1.357, 0.000)],
171
+ ["CA", 0, (0.000, 0.000, 0.000)],
172
+ ["C", 0, (1.526, -0.000, -0.000)],
173
+ ["CB", 0, (-0.531, -0.787, -1.200)],
174
+ ["O", 3, (0.625, 1.062, 0.000)],
175
+ ["CG", 4, (0.584, 1.399, 0.000)],
176
+ ["ND2", 5, (0.593, -1.188, 0.001)],
177
+ ["OD1", 5, (0.633, 1.059, 0.000)],
178
+ ],
179
+ "ASP": [
180
+ ["N", 0, (-0.525, 1.362, -0.000)],
181
+ ["CA", 0, (0.000, 0.000, 0.000)],
182
+ ["C", 0, (1.527, 0.000, -0.000)],
183
+ ["CB", 0, (-0.526, -0.778, -1.208)],
184
+ ["O", 3, (0.626, 1.062, -0.000)],
185
+ ["CG", 4, (0.593, 1.398, -0.000)],
186
+ ["OD1", 5, (0.610, 1.091, 0.000)],
187
+ ["OD2", 5, (0.592, -1.101, -0.003)],
188
+ ],
189
+ "CYS": [
190
+ ["N", 0, (-0.522, 1.362, -0.000)],
191
+ ["CA", 0, (0.000, 0.000, 0.000)],
192
+ ["C", 0, (1.524, 0.000, 0.000)],
193
+ ["CB", 0, (-0.519, -0.773, -1.212)],
194
+ ["O", 3, (0.625, 1.062, -0.000)],
195
+ ["SG", 4, (0.728, 1.653, 0.000)],
196
+ ],
197
+ "GLN": [
198
+ ["N", 0, (-0.526, 1.361, -0.000)],
199
+ ["CA", 0, (0.000, 0.000, 0.000)],
200
+ ["C", 0, (1.526, 0.000, 0.000)],
201
+ ["CB", 0, (-0.525, -0.779, -1.207)],
202
+ ["O", 3, (0.626, 1.062, -0.000)],
203
+ ["CG", 4, (0.615, 1.393, 0.000)],
204
+ ["CD", 5, (0.587, 1.399, -0.000)],
205
+ ["NE2", 6, (0.593, -1.189, -0.001)],
206
+ ["OE1", 6, (0.634, 1.060, 0.000)],
207
+ ],
208
+ "GLU": [
209
+ ["N", 0, (-0.528, 1.361, 0.000)],
210
+ ["CA", 0, (0.000, 0.000, 0.000)],
211
+ ["C", 0, (1.526, -0.000, -0.000)],
212
+ ["CB", 0, (-0.526, -0.781, -1.207)],
213
+ ["O", 3, (0.626, 1.062, 0.000)],
214
+ ["CG", 4, (0.615, 1.392, 0.000)],
215
+ ["CD", 5, (0.600, 1.397, 0.000)],
216
+ ["OE1", 6, (0.607, 1.095, -0.000)],
217
+ ["OE2", 6, (0.589, -1.104, -0.001)],
218
+ ],
219
+ "GLY": [
220
+ ["N", 0, (-0.572, 1.337, 0.000)],
221
+ ["CA", 0, (0.000, 0.000, 0.000)],
222
+ ["C", 0, (1.517, -0.000, -0.000)],
223
+ ["O", 3, (0.626, 1.062, -0.000)],
224
+ ],
225
+ "HIS": [
226
+ ["N", 0, (-0.527, 1.360, 0.000)],
227
+ ["CA", 0, (0.000, 0.000, 0.000)],
228
+ ["C", 0, (1.525, 0.000, 0.000)],
229
+ ["CB", 0, (-0.525, -0.778, -1.208)],
230
+ ["O", 3, (0.625, 1.063, 0.000)],
231
+ ["CG", 4, (0.600, 1.370, -0.000)],
232
+ ["CD2", 5, (0.889, -1.021, 0.003)],
233
+ ["ND1", 5, (0.744, 1.160, -0.000)],
234
+ ["CE1", 5, (2.030, 0.851, 0.002)],
235
+ ["NE2", 5, (2.145, -0.466, 0.004)],
236
+ ],
237
+ "ILE": [
238
+ ["N", 0, (-0.493, 1.373, -0.000)],
239
+ ["CA", 0, (0.000, 0.000, 0.000)],
240
+ ["C", 0, (1.527, -0.000, -0.000)],
241
+ ["CB", 0, (-0.536, -0.793, -1.213)],
242
+ ["O", 3, (0.627, 1.062, -0.000)],
243
+ ["CG1", 4, (0.534, 1.437, -0.000)],
244
+ ["CG2", 4, (0.540, -0.785, -1.199)],
245
+ ["CD1", 5, (0.619, 1.391, 0.000)],
246
+ ],
247
+ "LEU": [
248
+ ["N", 0, (-0.520, 1.363, 0.000)],
249
+ ["CA", 0, (0.000, 0.000, 0.000)],
250
+ ["C", 0, (1.525, -0.000, -0.000)],
251
+ ["CB", 0, (-0.522, -0.773, -1.214)],
252
+ ["O", 3, (0.625, 1.063, -0.000)],
253
+ ["CG", 4, (0.678, 1.371, 0.000)],
254
+ ["CD1", 5, (0.530, 1.430, -0.000)],
255
+ ["CD2", 5, (0.535, -0.774, 1.200)],
256
+ ],
257
+ "LYS": [
258
+ ["N", 0, (-0.526, 1.362, -0.000)],
259
+ ["CA", 0, (0.000, 0.000, 0.000)],
260
+ ["C", 0, (1.526, 0.000, 0.000)],
261
+ ["CB", 0, (-0.524, -0.778, -1.208)],
262
+ ["O", 3, (0.626, 1.062, -0.000)],
263
+ ["CG", 4, (0.619, 1.390, 0.000)],
264
+ ["CD", 5, (0.559, 1.417, 0.000)],
265
+ ["CE", 6, (0.560, 1.416, 0.000)],
266
+ ["NZ", 7, (0.554, 1.387, 0.000)],
267
+ ],
268
+ "MET": [
269
+ ["N", 0, (-0.521, 1.364, -0.000)],
270
+ ["CA", 0, (0.000, 0.000, 0.000)],
271
+ ["C", 0, (1.525, 0.000, 0.000)],
272
+ ["CB", 0, (-0.523, -0.776, -1.210)],
273
+ ["O", 3, (0.625, 1.062, -0.000)],
274
+ ["CG", 4, (0.613, 1.391, -0.000)],
275
+ ["SD", 5, (0.703, 1.695, 0.000)],
276
+ ["CE", 6, (0.320, 1.786, -0.000)],
277
+ ],
278
+ "PHE": [
279
+ ["N", 0, (-0.518, 1.363, 0.000)],
280
+ ["CA", 0, (0.000, 0.000, 0.000)],
281
+ ["C", 0, (1.524, 0.000, -0.000)],
282
+ ["CB", 0, (-0.525, -0.776, -1.212)],
283
+ ["O", 3, (0.626, 1.062, -0.000)],
284
+ ["CG", 4, (0.607, 1.377, 0.000)],
285
+ ["CD1", 5, (0.709, 1.195, -0.000)],
286
+ ["CD2", 5, (0.706, -1.196, 0.000)],
287
+ ["CE1", 5, (2.102, 1.198, -0.000)],
288
+ ["CE2", 5, (2.098, -1.201, -0.000)],
289
+ ["CZ", 5, (2.794, -0.003, -0.001)],
290
+ ],
291
+ "PRO": [
292
+ ["N", 0, (-0.566, 1.351, -0.000)],
293
+ ["CA", 0, (0.000, 0.000, 0.000)],
294
+ ["C", 0, (1.527, -0.000, 0.000)],
295
+ ["CB", 0, (-0.546, -0.611, -1.293)],
296
+ ["O", 3, (0.621, 1.066, 0.000)],
297
+ ["CG", 4, (0.382, 1.445, 0.0)],
298
+ # ['CD', 5, (0.427, 1.440, 0.0)],
299
+ ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
300
+ ],
301
+ "SER": [
302
+ ["N", 0, (-0.529, 1.360, -0.000)],
303
+ ["CA", 0, (0.000, 0.000, 0.000)],
304
+ ["C", 0, (1.525, -0.000, -0.000)],
305
+ ["CB", 0, (-0.518, -0.777, -1.211)],
306
+ ["O", 3, (0.626, 1.062, -0.000)],
307
+ ["OG", 4, (0.503, 1.325, 0.000)],
308
+ ],
309
+ "THR": [
310
+ ["N", 0, (-0.517, 1.364, 0.000)],
311
+ ["CA", 0, (0.000, 0.000, 0.000)],
312
+ ["C", 0, (1.526, 0.000, -0.000)],
313
+ ["CB", 0, (-0.516, -0.793, -1.215)],
314
+ ["O", 3, (0.626, 1.062, 0.000)],
315
+ ["CG2", 4, (0.550, -0.718, -1.228)],
316
+ ["OG1", 4, (0.472, 1.353, 0.000)],
317
+ ],
318
+ "TRP": [
319
+ ["N", 0, (-0.521, 1.363, 0.000)],
320
+ ["CA", 0, (0.000, 0.000, 0.000)],
321
+ ["C", 0, (1.525, -0.000, 0.000)],
322
+ ["CB", 0, (-0.523, -0.776, -1.212)],
323
+ ["O", 3, (0.627, 1.062, 0.000)],
324
+ ["CG", 4, (0.609, 1.370, -0.000)],
325
+ ["CD1", 5, (0.824, 1.091, 0.000)],
326
+ ["CD2", 5, (0.854, -1.148, -0.005)],
327
+ ["CE2", 5, (2.186, -0.678, -0.007)],
328
+ ["CE3", 5, (0.622, -2.530, -0.007)],
329
+ ["NE1", 5, (2.140, 0.690, -0.004)],
330
+ ["CH2", 5, (3.028, -2.890, -0.013)],
331
+ ["CZ2", 5, (3.283, -1.543, -0.011)],
332
+ ["CZ3", 5, (1.715, -3.389, -0.011)],
333
+ ],
334
+ "TYR": [
335
+ ["N", 0, (-0.522, 1.362, 0.000)],
336
+ ["CA", 0, (0.000, 0.000, 0.000)],
337
+ ["C", 0, (1.524, -0.000, -0.000)],
338
+ ["CB", 0, (-0.522, -0.776, -1.213)],
339
+ ["O", 3, (0.627, 1.062, -0.000)],
340
+ ["CG", 4, (0.607, 1.382, -0.000)],
341
+ ["CD1", 5, (0.716, 1.195, -0.000)],
342
+ ["CD2", 5, (0.713, -1.194, -0.001)],
343
+ ["CE1", 5, (2.107, 1.200, -0.002)],
344
+ ["CE2", 5, (2.104, -1.201, -0.003)],
345
+ ["OH", 5, (4.168, -0.002, -0.005)],
346
+ ["CZ", 5, (2.791, -0.001, -0.003)],
347
+ ],
348
+ "VAL": [
349
+ ["N", 0, (-0.494, 1.373, -0.000)],
350
+ ["CA", 0, (0.000, 0.000, 0.000)],
351
+ ["C", 0, (1.527, -0.000, -0.000)],
352
+ ["CB", 0, (-0.533, -0.795, -1.213)],
353
+ ["O", 3, (0.627, 1.062, -0.000)],
354
+ ["CG1", 4, (0.540, 1.429, -0.000)],
355
+ ["CG2", 4, (0.533, -0.776, 1.203)],
356
+ ],
357
+ # Assume alanine positions for unknown AA
358
+ "UNK": [
359
+ ["N", 0, (-0.525, 1.363, 0.000)],
360
+ ["CA", 0, (0.000, 0.000, 0.000)],
361
+ ["C", 0, (1.526, -0.000, -0.000)],
362
+ ],
363
+ }
364
+
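As a sanity check on the backbone frame (group 0), the sketch below reads the N-CA and CA-C bond lengths and the N-CA-C angle directly from the local coordinates. It only uses the ALA entries copied from the table above; the variable names are illustrative.

```python
import numpy as np

# Backbone-frame (group 0) entries for ALA, copied from the table above.
ala_backbone = {
    "N": (-0.525, 1.363, 0.000),
    "CA": (0.000, 0.000, 0.000),
    "C": (1.526, -0.000, -0.000),
}

n, ca, c = (np.array(ala_backbone[name]) for name in ("N", "CA", "C"))

# CA sits at the origin and C lies on the +x axis, so bond lengths are norms.
n_ca = np.linalg.norm(n - ca)
ca_c = np.linalg.norm(c - ca)

# N-CA-C bond angle from the dot product of the two bond vectors.
cos_angle = np.dot(n - ca, c - ca) / (n_ca * ca_c)
angle_deg = np.degrees(np.arccos(cos_angle))
```

The resulting values (about 1.46 Å, 1.526 Å and roughly 111 degrees) are in the range expected for ideal peptide geometry.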
365
+ # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
366
+ residue_atoms = {
367
+ "ALA": ["C", "CA", "CB", "N", "O"],
368
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
369
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
370
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
371
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
372
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
373
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
374
+ "GLY": ["C", "CA", "N", "O"],
375
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
376
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
377
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
378
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
379
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
380
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
381
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
382
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
383
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
384
+ "TRP": [
385
+ "C",
386
+ "CA",
387
+ "CB",
388
+ "CG",
389
+ "CD1",
390
+ "CD2",
391
+ "CE2",
392
+ "CE3",
393
+ "CZ2",
394
+ "CZ3",
395
+ "CH2",
396
+ "N",
397
+ "NE1",
398
+ "O",
399
+ ],
400
+ "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
401
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
402
+ "UNK": ["C", "CA", "N"],
403
+ }
404
+
405
+ # Naming swaps for ambiguous atom names.
406
+ # Due to symmetries in the amino acids the naming of atoms is ambiguous in
407
+ # 4 of the 20 amino acids.
408
+ # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
409
+ # in LEU, VAL and ARG can be resolved by using the 3d constellations of
410
+ # the 'ambiguous' atoms and their neighbours)
411
+ # TODO: ^ interpret this
412
+ residue_atom_renaming_swaps = {
413
+ "ASP": {"OD1": "OD2"},
414
+ "GLU": {"OE1": "OE2"},
415
+ "PHE": {"CD1": "CD2", "CE1": "CE2"},
416
+ "TYR": {"CD1": "CD2", "CE1": "CE2"},
417
+ }
418
+
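A small sketch of how the swap table might be applied to produce the alternative naming of a residue's atoms. The helper `swapped_atom_names` is hypothetical; it treats each swap as symmetric, which is an assumption about how the table is meant to be read.

```python
# Copied from the table above.
residue_atom_renaming_swaps = {
    "ASP": {"OD1": "OD2"},
    "GLU": {"OE1": "OE2"},
    "PHE": {"CD1": "CD2", "CE1": "CE2"},
    "TYR": {"CD1": "CD2", "CE1": "CE2"},
}


def swapped_atom_names(resname, atom_names):
    """Return atom names under the alternative (swapped) naming."""
    swaps = residue_atom_renaming_swaps.get(resname, {})
    # Make the mapping symmetric so OD1 -> OD2 and OD2 -> OD1.
    both_ways = {**swaps, **{v: k for k, v in swaps.items()}}
    return [both_ways.get(name, name) for name in atom_names]


asp_alt = swapped_atom_names("ASP", ["CG", "OD1", "OD2"])
ala_alt = swapped_atom_names("ALA", ["CA", "CB"])
```

Residues without an entry in the table are returned unchanged.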
419
+ # Van der Waals radii [Angstrom] of the atoms (from Wikipedia)
420
+ van_der_waals_radius = {"C": 1.7, "N": 1.55, "O": 1.52, "S": 1.8}
421
+
422
+ Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
423
+ BondAngle = collections.namedtuple(
424
+ "BondAngle", ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"]
425
+ )
426
+
427
+
428
+ @functools.lru_cache(maxsize=None)
429
+ def load_stereo_chemical_props() -> (
430
+ Tuple[
431
+ Mapping[str, List[Bond]],
432
+ Mapping[str, List[Bond]],
433
+ Mapping[str, List[BondAngle]],
434
+ ]
435
+ ):
436
+ """Load stereo_chemical_props.txt into a nice structure.
437
+
438
+ Load literature values for bond lengths and bond angles and translate
439
+ bond angles into the length of the opposite edge of the triangle
440
+ ("residue_virtual_bonds").
441
+
442
+ Returns:
443
+ residue_bonds: dict that maps resname --> list of Bond tuples
444
+ residue_virtual_bonds: dict that maps resname --> list of Bond tuples
445
+ residue_bond_angles: dict that maps resname --> list of BondAngle tuples
446
+ """
447
+ stereo_chemical_props = Path(
448
+ "evolutionaryscale/structure/stereo_chemical_props.txt"
449
+ ).read_text()
450
+
451
+ lines_iter = iter(stereo_chemical_props.splitlines())
452
+ # Load bond lengths.
453
+ residue_bonds = {}
454
+ next(lines_iter) # Skip header line.
455
+ for line in lines_iter:
456
+ if line.strip() == "-":
457
+ break
458
+ bond, resname, length, stddev = line.split()
459
+ atom1, atom2 = bond.split("-")
460
+ if resname not in residue_bonds:
461
+ residue_bonds[resname] = []
462
+ residue_bonds[resname].append(Bond(atom1, atom2, float(length), float(stddev)))
463
+ residue_bonds["UNK"] = []
464
+
465
+ # Load bond angles.
466
+ residue_bond_angles = {}
467
+ next(lines_iter) # Skip empty line.
468
+ next(lines_iter) # Skip header line.
469
+ for line in lines_iter:
470
+ if line.strip() == "-":
471
+ break
472
+ bond, resname, angle_degree, stddev_degree = line.split()
473
+ atom1, atom2, atom3 = bond.split("-")
474
+ if resname not in residue_bond_angles:
475
+ residue_bond_angles[resname] = []
476
+ residue_bond_angles[resname].append(
477
+ BondAngle(
478
+ atom1,
479
+ atom2,
480
+ atom3,
481
+ float(angle_degree) / 180.0 * np.pi,
482
+ float(stddev_degree) / 180.0 * np.pi,
483
+ )
484
+ )
485
+ residue_bond_angles["UNK"] = []
486
+
487
+ def make_bond_key(atom1_name, atom2_name):
488
+ """Unique key to lookup bonds."""
489
+ return "-".join(sorted([atom1_name, atom2_name]))
490
+
491
+ # Translate bond angles into distances ("virtual bonds").
492
+ residue_virtual_bonds = {}
493
+ for resname, bond_angles in residue_bond_angles.items():
494
+ # Create a fast lookup dict for bond lengths.
495
+ bond_cache = {}
496
+ for b in residue_bonds[resname]:
497
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
498
+ residue_virtual_bonds[resname] = []
499
+ for ba in bond_angles:
500
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
501
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
502
+
503
+ # Compute distance between atom1 and atom3 using the law of cosines
504
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
505
+ gamma = ba.angle_rad
506
+ length = np.sqrt(
507
+ bond1.length**2
508
+ + bond2.length**2
509
+ - 2 * bond1.length * bond2.length * np.cos(gamma)
510
+ )
511
+
512
+ # Propagation of uncertainty assuming uncorrelated errors.
513
+ dl_outer = 0.5 / length
514
+ dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
515
+ dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
516
+ dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
517
+ stddev = np.sqrt(
518
+ (dl_dgamma * ba.stddev) ** 2
519
+ + (dl_db1 * bond1.stddev) ** 2
520
+ + (dl_db2 * bond2.stddev) ** 2
521
+ )
522
+ residue_virtual_bonds[resname].append(
523
+ Bond(ba.atom1_name, ba.atom3name, length, stddev)
524
+ )
525
+
526
+ return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
527
+
528
+
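The virtual-bond computation above can be isolated into a small function for checking by hand. This is a sketch under stated assumptions: `virtual_bond` is a hypothetical name, and the 3-4-5 triangle below is an illustrative input, not a value from `stereo_chemical_props.txt`.

```python
import numpy as np


def virtual_bond(len1, sd1, len2, sd2, angle_rad, sd_angle):
    """Distance between atom1 and atom3 and its propagated stddev.

    Uses the law of cosines, c^2 = a^2 + b^2 - 2ab*cos(gamma), and
    first-order error propagation assuming uncorrelated errors.
    """
    length = np.sqrt(len1**2 + len2**2 - 2 * len1 * len2 * np.cos(angle_rad))
    # Partial derivatives of the length with respect to each input.
    d_angle = len1 * len2 * np.sin(angle_rad) / length
    d_len1 = (len1 - len2 * np.cos(angle_rad)) / length
    d_len2 = (len2 - len1 * np.cos(angle_rad)) / length
    stddev = np.sqrt(
        (d_angle * sd_angle) ** 2 + (d_len1 * sd1) ** 2 + (d_len2 * sd2) ** 2
    )
    return length, stddev


# Right angle with sides 3 and 4: the opposite edge must be 5.
length, stddev = virtual_bond(3.0, 0.1, 4.0, 0.0, np.pi / 2, 0.0)
```

With only the first bond carrying uncertainty, the propagated stddev reduces to `(3 / 5) * 0.1`, which makes the derivative terms easy to verify.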
529
+ # Between-residue bond lengths for general bonds (first element) and for Proline
530
+ # (second element).
531
+ between_res_bond_length_c_n = [1.329, 1.341]
532
+ between_res_bond_length_stddev_c_n = [0.014, 0.016]
533
+
534
+ # Between-residue cos_angles.
535
+ between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
536
+ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
537
+
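The cosine means above can be cross-checked against the angles quoted in the trailing comments. The sketch below only verifies the first element of each pair; the variable names are illustrative.

```python
import math

# Angles in degrees, taken from the comments next to the constants above.
cos_c_n_ca = math.cos(math.radians(121.352))
cos_ca_c_n = math.cos(math.radians(116.568))

# Compare against the stored means (first element of each list).
matches_c_n_ca = abs(cos_c_n_ca - (-0.5203)) < 1e-3
matches_ca_c_n = abs(cos_ca_c_n - (-0.4473)) < 1e-3
```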
538
+ # This mapping is used when we need to store atom data in a format that requires
539
+ # fixed atom data size for every residue (e.g. a numpy array).
540
+ atom_types = [
541
+ "N",
542
+ "CA",
543
+ "C",
544
+ "CB",
545
+ "O",
546
+ "CG",
547
+ "CG1",
548
+ "CG2",
549
+ "OG",
550
+ "OG1",
551
+ "SG",
552
+ "CD",
553
+ "CD1",
554
+ "CD2",
555
+ "ND1",
556
+ "ND2",
557
+ "OD1",
558
+ "OD2",
559
+ "SD",
560
+ "CE",
561
+ "CE1",
562
+ "CE2",
563
+ "CE3",
564
+ "NE",
565
+ "NE1",
566
+ "NE2",
567
+ "OE1",
568
+ "OE2",
569
+ "CH2",
570
+ "NH1",
571
+ "NH2",
572
+ "OH",
573
+ "CZ",
574
+ "CZ2",
575
+ "CZ3",
576
+ "NZ",
577
+ "OXT",
578
+ ]
579
+ atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
580
+ atom_type_num = len(atom_types) # := 37.
581
+
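A minimal sketch of how `atom_order` supports a fixed-size per-residue layout: each atom name selects one of 37 slots, and a mask records which slots are filled. The function `pack_atom37` is hypothetical; `atom_types` is copied from the list above, and the glycine coordinates are the backbone-frame values from `rigid_group_atom_positions`.

```python
import numpy as np

# Copied from the atom_types list above.
atom_types = [
    "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD",
    "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1", "CE2", "CE3",
    "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH", "CZ", "CZ2",
    "CZ3", "NZ", "OXT",
]
atom_order = {name: i for i, name in enumerate(atom_types)}


def pack_atom37(residue_coords):
    """Place named atom coordinates into a [37, 3] array plus a [37] mask."""
    positions = np.zeros((len(atom_types), 3), dtype=np.float32)
    mask = np.zeros(len(atom_types), dtype=np.float32)
    for name, xyz in residue_coords.items():
        idx = atom_order[name]
        positions[idx] = xyz
        mask[idx] = 1.0
    return positions, mask


gly_backbone = {
    "N": (-0.572, 1.337, 0.000),
    "CA": (0.000, 0.000, 0.000),
    "C": (1.517, 0.000, 0.000),
}
positions, mask = pack_atom37(gly_backbone)
```

Slots for atoms the residue does not have stay at zero, so the mask is needed to tell a missing atom apart from one located at the origin.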
582
+ # A compact atom encoding with 14 columns
583
+ # pylint: disable=line-too-long
584
+ # pylint: disable=bad-whitespace
585
+ restype_name_to_atom14_names = {
586
+ "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
587
+ "ARG": [
588
+ "N",
589
+ "CA",
590
+ "C",
591
+ "O",
592
+ "CB",
593
+ "CG",
594
+ "CD",
595
+ "NE",
596
+ "CZ",
597
+ "NH1",
598
+ "NH2",
599
+ "",
600
+ "",
601
+ "",
602
+ ],
603
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
604
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
605
+ "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
606
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
607
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
608
+ "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
609
+ "HIS": [
610
+ "N",
611
+ "CA",
612
+ "C",
613
+ "O",
614
+ "CB",
615
+ "CG",
616
+ "ND1",
617
+ "CD2",
618
+ "CE1",
619
+ "NE2",
620
+ "",
621
+ "",
622
+ "",
623
+ "",
624
+ ],
625
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
626
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
627
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
628
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
629
+ "PHE": [
630
+ "N",
631
+ "CA",
632
+ "C",
633
+ "O",
634
+ "CB",
635
+ "CG",
636
+ "CD1",
637
+ "CD2",
638
+ "CE1",
639
+ "CE2",
640
+ "CZ",
641
+ "",
642
+ "",
643
+ "",
644
+ ],
645
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
646
+ "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
647
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
648
+ "TRP": [
649
+ "N",
650
+ "CA",
651
+ "C",
652
+ "O",
653
+ "CB",
654
+ "CG",
655
+ "CD1",
656
+ "CD2",
657
+ "NE1",
658
+ "CE2",
659
+ "CE3",
660
+ "CZ2",
661
+ "CZ3",
662
+ "CH2",
663
+ ],
664
+ "TYR": [
665
+ "N",
666
+ "CA",
667
+ "C",
668
+ "O",
669
+ "CB",
670
+ "CG",
671
+ "CD1",
672
+ "CD2",
673
+ "CE1",
674
+ "CE2",
675
+ "CZ",
676
+ "OH",
677
+ "",
678
+ "",
679
+ ],
680
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
681
+ "UNK": ["N", "CA", "C", "", "", "", "", "", "", "", "", "", "", ""],
682
+ }
683
+ # pylint: enable=line-too-long
684
+ # pylint: enable=bad-whitespace
685
+
686
+
687
+ # This is the standard residue order when coding AA type as a number.
688
+ # Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
689
+ restypes = [
690
+ "A",
691
+ "R",
692
+ "N",
693
+ "D",
694
+ "C",
695
+ "Q",
696
+ "E",
697
+ "G",
698
+ "H",
699
+ "I",
700
+ "L",
701
+ "K",
702
+ "M",
703
+ "F",
704
+ "P",
705
+ "S",
706
+ "T",
707
+ "W",
708
+ "Y",
709
+ "V",
710
+ ]
711
+ restype_order = {restype: i for i, restype in enumerate(restypes)}
712
+ restype_num = len(restypes) # := 20.
713
+ unk_restype_index = restype_num # Catch-all index for unknown restypes.
714
+
715
+ restypes_with_x = restypes + ["X"]
716
+ restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
717
+
718
+ bb_atoms = ["N", "CA", "C", "O"]
719
+
720
+ # Hydrophobicity by residue (positive values are hydrophobic). Derived from Black & Mould (1991), normalized by subtracting 0.5.
721
+ hydrophobicity = {
722
+ "ALA": 0.116,
723
+ "ARG": -0.5,
724
+ "ASN": -0.264,
725
+ "ASP": -0.472,
726
+ "CYS": 0.18,
727
+ "GLN": -0.249,
728
+ "GLU": -0.457,
729
+ "GLY": 0.001,
730
+ "HIS": -0.335,
731
+ "ILE": 0.443,
732
+ "LEU": 0.443,
733
+ "LYS": -0.217,
734
+ "MET": 0.238,
735
+ "PHE": 0.5,
736
+ "PRO": 0.211,
737
+ "SER": -0.141,
738
+ "THR": -0.05,
739
+ "TRP": 0.378,
740
+ "TYR": 0.38,
741
+ "VAL": 0.325,
742
+ }
743
+
744
+ # Side chain max accessible surface area in Ala-X-Ala tripeptide (from Chennamsetty et al. 2010).
745
+ side_chain_asa = {
746
+ "ALA": 64.7809,
747
+ "ARG": 210.02,
748
+ "ASN": 113.187,
749
+ "ASP": 110.209,
750
+ "CYS": 95.2439,
751
+ "GLN": 147.855,
752
+ "GLU": 143.924,
753
+ "GLY": 23.1338,
754
+ "HIS": 146.449,
755
+ "ILE": 151.242,
756
+ "LEU": 139.524,
757
+ "LYS": 177.366,
758
+ "MET": 164.674,
759
+ "PHE": 186.7,
760
+ "PRO": 111.533,
761
+ "SER": 81.2159,
762
+ "THR": 111.597,
763
+ "TRP": 229.619,
764
+ "TYR": 200.306,
765
+ "VAL": 124.237,
766
+ }
767
+
768
+ # Approximate volumes of amino acids in cubic angstroms.
769
+ # https://www.imgt.org/IMGTeducation/Aide-memoire/_UK/aminoacids/abbreviation.html
770
+ amino_acid_volumes = {
771
+ "A": 88.6, # Alanine
772
+ "R": 173.4, # Arginine
773
+ "N": 114.1, # Asparagine
774
+ "D": 111.1, # Aspartic acid
775
+ "C": 108.5, # Cysteine
776
+ "Q": 143.8, # Glutamine
777
+ "E": 138.4, # Glutamic acid
778
+ "G": 60.1, # Glycine
779
+ "H": 153.2, # Histidine
780
+ "I": 166.7, # Isoleucine
781
+ "L": 166.7, # Leucine
782
+ "K": 168.6, # Lysine
783
+ "M": 162.9, # Methionine
784
+ "F": 189.9, # Phenylalanine
785
+ "P": 112.7, # Proline
786
+ "S": 89.0, # Serine
787
+ "T": 116.1, # Threonine
788
+ "W": 227.8, # Tryptophan
789
+ "Y": 193.6, # Tyrosine
790
+ "V": 140.0, # Valine
791
+ "X": 88.6, # Unknown, use Alanine as approximation
792
+ }
793
+
794
+
795
+ def sequence_to_onehot(
796
+ sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
797
+ ) -> np.ndarray:
798
+ """Maps the given sequence into a one-hot encoded matrix.
799
+
800
+ Args:
801
+ sequence: An amino acid sequence.
802
+ mapping: A dictionary mapping amino acids to integers.
803
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
804
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
805
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
806
+ the mapping will throw an error.
807
+
808
+ Returns:
809
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
810
+ the sequence.
811
+
812
+ Raises:
813
+ ValueError: If the mapping doesn't contain values from 0 to
814
+ num_unique_aas - 1 without any gaps.
815
+ """
816
+ num_entries = max(mapping.values()) + 1
817
+
818
+ if sorted(set(mapping.values())) != list(range(num_entries)):
819
+ raise ValueError(
820
+ "The mapping must have values from 0 to num_unique_aas-1 "
821
+ "without any gaps. Got: %s" % sorted(mapping.values())
822
+ )
823
+
824
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)
825
+
826
+ for aa_index, aa_type in enumerate(sequence):
827
+ if map_unknown_to_x:
828
+ if aa_type.isalpha() and aa_type.isupper():
829
+ aa_id = mapping.get(aa_type, mapping["X"])
830
+ else:
831
+ raise ValueError(f"Invalid character in the sequence: {aa_type}")
832
+ else:
833
+ aa_id = mapping[aa_type]
834
+ one_hot_arr[aa_index, aa_id] = 1
835
+
836
+ return one_hot_arr
837
+
838
+
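For illustration, a compact one-hot encoding can be sketched with `np.eye` row indexing. This is not the function above: `onehot_sketch` is a hypothetical stand-in that skips the validation steps and always maps unknown letters to `X`. The residue order is copied from `restypes_with_x`.

```python
import numpy as np

# Same order as restypes_with_x above.
restypes_with_x = list("ARNDCQEGHILKMFPSTWYV") + ["X"]
restype_order_with_x = {aa: i for i, aa in enumerate(restypes_with_x)}


def onehot_sketch(sequence, mapping):
    """One-hot encode a sequence, sending unmapped letters to 'X'."""
    ids = [mapping.get(aa, mapping["X"]) for aa in sequence]
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(len(mapping), dtype=np.int32)[ids]


encoded = onehot_sketch("ACB", restype_order_with_x)
```

Here `B` is not one of the 20 standard letters, so it lands in the `X` column (index 20).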
839
+ restype_1to3 = {
840
+ "A": "ALA",
841
+ "R": "ARG",
842
+ "N": "ASN",
843
+ "D": "ASP",
844
+ "C": "CYS",
845
+ "Q": "GLN",
846
+ "E": "GLU",
847
+ "G": "GLY",
848
+ "H": "HIS",
849
+ "I": "ILE",
850
+ "L": "LEU",
851
+ "K": "LYS",
852
+ "M": "MET",
853
+ "F": "PHE",
854
+ "P": "PRO",
855
+ "S": "SER",
856
+ "T": "THR",
857
+ "W": "TRP",
858
+ "Y": "TYR",
859
+ "V": "VAL",
860
+ "X": "UNK",
861
+ }
862
+
863
+
864
+ # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
865
+ # 1-to-1 mapping of 3 letter names to one letter names. The latter contains
866
+ # many more, and less common, three letter names as keys and maps many of these
867
+ # to the same one letter name (including 'X' and 'U' which we don't use here).
868
+ restype_3to1 = {v: k for k, v in restype_1to3.items()}
869
+
870
+ # Define a restype name for all unknown residues.
871
+ unk_restype = "UNK"
872
+
873
+ resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
874
+ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
875
+
876
+ hydrophobic_resnames = {"VAL", "ILE", "LEU", "PHE", "MET", "TRP"}
877
+
878
+ # The mapping here uses hhblits convention, so that B is mapped to D, J and O
879
+ # are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
880
+ # remaining 20 amino acids are kept in alphabetical order.
881
+ # There are 2 non-amino acid codes, X (representing any amino acid) and
882
+ # "-" representing a missing amino acid in an alignment. The id for these
883
+ # codes is put at the end (20 and 21) so that they can easily be ignored if
884
+ # desired.
885
+ HHBLITS_AA_TO_ID = {
886
+ "A": 0,
887
+ "B": 2,
888
+ "C": 1,
889
+ "D": 2,
890
+ "E": 3,
891
+ "F": 4,
892
+ "G": 5,
893
+ "H": 6,
894
+ "I": 7,
895
+ "J": 20,
896
+ "K": 8,
897
+ "L": 9,
898
+ "M": 10,
899
+ "N": 11,
900
+ "O": 20,
901
+ "P": 12,
902
+ "Q": 13,
903
+ "R": 14,
904
+ "S": 15,
905
+ "T": 16,
906
+ "U": 1,
907
+ "V": 17,
908
+ "W": 18,
909
+ "X": 20,
910
+ "Y": 19,
911
+ "Z": 3,
912
+ "-": 21,
913
+ }
914
+
915
+ # Partial inversion of HHBLITS_AA_TO_ID.
916
+ ID_TO_HHBLITS_AA = {
917
+ 0: "A",
918
+ 1: "C", # Also U.
919
+ 2: "D", # Also B.
920
+ 3: "E", # Also Z.
921
+ 4: "F",
922
+ 5: "G",
923
+ 6: "H",
924
+ 7: "I",
925
+ 8: "K",
926
+ 9: "L",
927
+ 10: "M",
928
+ 11: "N",
929
+ 12: "P",
930
+ 13: "Q",
931
+ 14: "R",
932
+ 15: "S",
933
+ 16: "T",
934
+ 17: "V",
935
+ 18: "W",
936
+ 19: "Y",
937
+ 20: "X", # Includes J and O.
938
+ 21: "-",
939
+ }
940
+
941
+ restypes_with_x_and_gap = restypes + ["X", "-"]
942
+ MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
943
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
944
+ for i in range(len(restypes_with_x_and_gap))
945
+ )
946
+
947
+
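The permutation built above can be used as a lookup table to convert HHblits ids into this file's residue ids. The sketch below rebuilds it from scratch so it runs on its own; the lower-case names are illustrative copies of the constants above.

```python
import numpy as np

restypes = list("ARNDCQEGHILKMFPSTWYV")
# HHblits ids 0..21 in order, matching ID_TO_HHBLITS_AA above.
id_to_hhblits_aa = dict(enumerate("ACDEFGHIKLMNPQRSTVWYX-"))

ours = restypes + ["X", "-"]
hhblits_to_ours = np.array(
    [ours.index(id_to_hhblits_aa[i]) for i in range(len(ours))]
)

# HHblits ids for A, C, R and the gap character.
hhblits_ids = np.array([0, 1, 14, 21])
our_ids = hhblits_to_ours[hhblits_ids]
```

Indexing the table with an integer array converts a whole alignment in one step.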
948
+ def _make_standard_atom_mask() -> np.ndarray:
949
+ """Returns [num_res_types, num_atom_types] mask array."""
950
+ # +1 to account for unknown (all 0s).
951
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
952
+ for restype, restype_letter in enumerate(restypes):
953
+ restype_name = restype_1to3[restype_letter]
954
+ atom_names = residue_atoms[restype_name]
955
+ for atom_name in atom_names:
956
+ atom_type = atom_order[atom_name]
957
+ mask[restype, atom_type] = 1
958
+ return mask
959
+
960
+
961
+ STANDARD_ATOM_MASK = _make_standard_atom_mask()
962
+
963
+
964
+ # A one hot representation for the first and second atoms defining the axis
965
+ # of rotation for each chi-angle in each residue.
966
+ def chi_angle_atom(atom_index: int) -> np.ndarray:
967
+ """Define chi-angle rigid groups via one-hot representations."""
968
+ chi_angles_index = {}
969
+ one_hots = []
970
+
971
+ for k, v in chi_angles_atoms.items():
972
+ indices = [atom_types.index(s[atom_index]) for s in v]
973
+ indices.extend([-1] * (4 - len(indices)))
974
+ chi_angles_index[k] = indices
975
+
976
+ for r in restypes:
977
+ res3 = restype_1to3[r]
978
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
979
+ one_hots.append(one_hot)
980
+
981
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
982
+ one_hot = np.stack(one_hots, axis=0)
983
+ one_hot = np.transpose(one_hot, [0, 2, 1])
984
+
985
+ return one_hot
986
+
987
+
988
+ chi_atom_1_one_hot = chi_angle_atom(1)
989
+ chi_atom_2_one_hot = chi_angle_atom(2)
990
+
991
+ # An array like chi_angles_atoms, padded to four chi groups per residue. The name-to-index conversion below is commented out, so entries remain atom names.
992
+ chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
993
+ # chi_angles_atom_indices = tree.map_structure(
994
+ # lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
995
+ # )
996
+ chi_angles_atom_indices = np.array(
997
+ [
998
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
999
+ for chi_atoms in chi_angles_atom_indices
1000
+ ]
1001
+ )
1002
+
1003
+ # Mapping from (res_name, atom_name) pairs to the atom's chi group index
1004
+ # and atom index within that group.
1005
+ chi_groups_for_atom = collections.defaultdict(list)
1006
+ for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
1007
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
1008
+ for atom_i, atom in enumerate(chi_group):
1009
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
1010
+ chi_groups_for_atom = dict(chi_groups_for_atom)
1011
+
1012
+
1013
+ def _make_rigid_transformation_4x4(ex, ey, translation):
1014
+ """Create a rigid 4x4 transformation matrix from two axes and a translation."""
1015
+ # Normalize ex.
1016
+ ex_normalized = ex / np.linalg.norm(ex)
1017
+
1018
+ # make ey perpendicular to ex
1019
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
1020
+ ey_normalized /= np.linalg.norm(ey_normalized)
1021
+
1022
+ # compute ez as cross product
1023
+ eznorm = np.cross(ex_normalized, ey_normalized)
1024
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
1025
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
1026
+ return m
1027
+
1028
+
1029
+ # create an array with (restype, atomtype) --> rigid_group_idx
1030
+ # and an array with (restype, atomtype, coord) for the atom positions
1031
+ # and compute affine transformation matrices (4,4) from one rigid group to the
1032
+ # previous group
1033
+ restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
1034
+ restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
1035
+ restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
1036
+ restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
1037
+ restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
1038
+ restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
1039
+ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
1040
+
1041
+
1042
+ def _make_rigid_group_constants():
1043
+ """Fill the arrays above."""
1044
+ for restype, restype_letter in enumerate(restypes_with_x):
1045
+ resname = restype_1to3[restype_letter]
1046
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
1047
+ atomtype = atom_order[atomname]
1048
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
1049
+ restype_atom37_mask[restype, atomtype] = 1
1050
+ restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
1051
+
1052
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
1053
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
1054
+ restype_atom14_mask[restype, atom14idx] = 1
1055
+ restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position
1056
+
1057
+ for restype, restype_letter in enumerate(restypes_with_x):
1058
+ resname = restype_1to3[restype_letter]
1059
+ atom_positions = {
1060
+ name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
1061
+ }
1062
+
1063
+ # backbone to backbone is the identity transform
1064
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
1065
+
1066
+ # pre-omega-frame to backbone (currently dummy identity matrix)
1067
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
1068
+
1069
+ # phi-frame to backbone
1070
+ mat = _make_rigid_transformation_4x4(
1071
+ ex=atom_positions["N"] - atom_positions["CA"],
1072
+ ey=np.array([1.0, 0.0, 0.0]),
1073
+ translation=atom_positions["N"],
1074
+ )
1075
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
1076
+
1077
+ # psi-frame to backbone
1078
+ mat = _make_rigid_transformation_4x4(
1079
+ ex=atom_positions["C"] - atom_positions["CA"],
1080
+ ey=atom_positions["CA"] - atom_positions["N"],
1081
+ translation=atom_positions["C"],
1082
+ )
1083
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
1084
+
1085
+ # chi1-frame to backbone
1086
+ if chi_angles_mask[restype][0]:
1087
+ base_atom_names = chi_angles_atoms[resname][0]
1088
+ base_atom_positions = [atom_positions[name] for name in base_atom_names]
1089
+ mat = _make_rigid_transformation_4x4(
1090
+ ex=base_atom_positions[2] - base_atom_positions[1],
1091
+ ey=base_atom_positions[0] - base_atom_positions[1],
1092
+ translation=base_atom_positions[2],
1093
+ )
1094
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
1095
+
1096
+ # chi2-frame to chi1-frame
1097
+ # chi3-frame to chi2-frame
1098
+ # chi4-frame to chi3-frame
1099
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
1100
+ # previous frame
1101
+ for chi_idx in range(1, 4):
1102
+ if chi_angles_mask[restype][chi_idx]:
1103
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
1104
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
1105
+ mat = _make_rigid_transformation_4x4(
1106
+ ex=axis_end_atom_position,
1107
+ ey=np.array([-1.0, 0.0, 0.0]),
1108
+ translation=axis_end_atom_position,
1109
+ )
1110
+ restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
1111
+
1112
+
1113
+ _make_rigid_group_constants()
1114
+
1115
+
1116
+ def make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15.0):
1117
+ """Compute upper and lower bounds for bonds to assess violations."""
1118
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
1119
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
1120
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
1121
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
1122
+ for restype, restype_letter in enumerate(restypes):
1123
+ resname = restype_1to3[restype_letter]
1124
+ atom_list = restype_name_to_atom14_names[resname]
1125
+
1126
+ # create lower and upper bounds for clashes
1127
+ for atom1_idx, atom1_name in enumerate(atom_list):
1128
+ if not atom1_name:
1129
+ continue
1130
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
1131
+ for atom2_idx, atom2_name in enumerate(atom_list):
1132
+ if (not atom2_name) or atom1_idx == atom2_idx:
1133
+ continue
1134
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
1135
+ lower = atom1_radius + atom2_radius - overlap_tolerance
1136
+ upper = 1e10
1137
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
1138
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
1139
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
1140
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
1141
+
1142
+ # overwrite lower and upper bounds for bonds and angles
1143
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
1144
+ atom1_idx = atom_list.index(b.atom1_name)
1145
+ atom2_idx = atom_list.index(b.atom2_name)
1146
+ lower = b.length - bond_length_tolerance_factor * b.stddev
1147
+ upper = b.length + bond_length_tolerance_factor * b.stddev
1148
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
1149
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
1150
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
1151
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
1152
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
1153
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
1154
+ return {
1155
+ "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14)
1156
+ "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14)
1157
+ "stddev": restype_atom14_bond_stddev, # shape (21,14,14)
1158
+ }
1159
+
1160
+
1161
+ restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
1162
+ restype_atom14_ambiguous_atoms_swap_idx = np.tile(np.arange(14, dtype=int), (21, 1))
1163
+
1164
+
1165
+ def _make_atom14_ambiguity_feats():
1166
+ for res, pairs in residue_atom_renaming_swaps.items():
1167
+ res_idx = restype_order[restype_3to1[res]]
1168
+ for atom1, atom2 in pairs.items():
1169
+ atom1_idx = restype_name_to_atom14_names[res].index(atom1)
1170
+ atom2_idx = restype_name_to_atom14_names[res].index(atom2)
1171
+ restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
1172
+ restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
1173
+ restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
1174
+ restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx
1175
+
1176
+
1177
+ _make_atom14_ambiguity_feats()
1178
+
1179
+
1180
+ def aatype_to_str_sequence(aatype):
1181
+ return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))])
1182
+
1183
+
1184
+ # NOTE(thayes): These are computed based on the average CA->C and CA->N norm from rigid_group_atom_positions
1185
+ CA_TO_N_NORM = 1.4591
1186
+ CA_TO_C_NORM = 1.5252
1187
+
1188
+
1189
+ def _make_restype_atom37_to_atom14():
1190
+ """Map from atom37 to atom14 per residue type."""
1191
+ restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14
1192
+ for rt in restypes:
1193
+ atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
1194
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
1195
+ restype_atom37_to_atom14.append(
1196
+ [
1197
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
1198
+ for name in atom_types
1199
+ ]
1200
+ )
1201
+
1202
+ restype_atom37_to_atom14.append([0] * 37)
1203
+ restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32)
1204
+ return restype_atom37_to_atom14
1205
+
1206
+
1207
+ def _make_restype_atom14_to_atom37():
1208
+ """Map from atom14 to atom37 per residue type."""
1209
+ restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37
1210
+ for rt in restypes:
1211
+ atom_names = restype_name_to_atom14_names[restype_1to3[rt]]
1212
+ restype_atom14_to_atom37.append(
1213
+ [(atom_order[name] if name else 0) for name in atom_names]
1214
+ )
1215
+ # Add dummy mapping for restype 'UNK'
1216
+ restype_atom14_to_atom37.append([0] * 14)
1217
+ restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32)
1218
+ return restype_atom14_to_atom37
1219
+
1220
+
1221
+ RESTYPE_ATOM14_TO_ATOM37 = _make_restype_atom14_to_atom37()
1222
+ RESTYPE_ATOM37_TO_ATOM14 = _make_restype_atom37_to_atom14()
1223
+ CHAIN_BREAK_TOKEN = "|"
1224
+
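Reviewer note: the frame construction in `_make_rigid_transformation_4x4` above is a Gram-Schmidt step followed by a cross product. The sketch below restates it without NumPy so the column layout of the result is easy to check by hand. The function name without the leading underscore and the sample vectors are illustrative assumptions, not part of the commit.

```python
import math


def make_rigid_transformation_4x4(ex, ey, translation):
    """Dependency-free sketch: rows of a 4x4 rigid transform from two axes and a translation."""

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    def normalize(v):
        length = math.sqrt(dot(v, v))
        return [c / length for c in v]

    def cross(a, b):
        return [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ]

    # Normalize ex, then remove its component from ey (Gram-Schmidt).
    ex_n = normalize(ex)
    proj = dot(ey, ex_n)
    ey_n = normalize([e - proj * x for e, x in zip(ey, ex_n)])
    # The third axis completes a right-handed frame.
    ez_n = cross(ex_n, ey_n)

    # Columns are ex, ey, ez, translation; the last row is homogeneous.
    rows = [[ex_n[i], ey_n[i], ez_n[i], translation[i]] for i in range(3)]
    rows.append([0.0, 0.0, 0.0, 1.0])
    return rows


# Illustrative input: ey is not perpendicular to ex and gets orthogonalized.
frame = make_rigid_transformation_4x4(
    ex=[2.0, 0.0, 0.0], ey=[1.0, 3.0, 0.0], translation=[1.0, 2.0, 3.0]
)
```

With these inputs the rotation part comes out as the identity and the last column carries the translation, which matches the layout the NumPy version builds with `np.stack(...).transpose()` followed by the appended `[0, 0, 0, 1]` row.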
esmfold2_sequential_dataclass.py ADDED
@@ -0,0 +1,158 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass, fields, replace
3
+ from typing import TypeVar
4
+
5
+ import numpy as np
6
+
7
+ from .esmfold2_misc import concat_objects, slice_any_object
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ @dataclass(frozen=True)
13
+ class SequentialDataclass(ABC):
14
+ """
15
+ This is a builder on a dataclass that allows for automatic slicing and concatenation.
16
+
17
+ When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
18
+
19
+ When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
20
+
21
+ We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
22
+
23
+ The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
24
+
25
+ This is done through the `metadata` field, which accepts three keys:
26
+ `sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
27
+ `sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
28
+ `join_token` (Any): What token to use to join when concatenating elements. Default: None.
29
+
30
+
31
+ Example:
32
+
33
+ @dataclass(frozen=True)
34
+ class Foo(SequentialDataclass):
35
+ id: str
36
+ sequence: str = field(metadata={"sequence": True, "join_token": "|"})
37
+ tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
38
+
39
+ def __len__(self):
40
+ # Must implement the __len__ method
41
+ return len(self.sequence)
42
+
43
+ >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
44
+ Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
45
+
46
+ >>> foo[1:4]
47
+ Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
48
+
49
+ >>> foo[np.arange(5) < 3]
50
+ Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
51
+
52
+ >>> Foo.concat([foo[:2], foo[3:]])
53
+ Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
54
+
55
+ # Trying to create a type where the sequence lengths do not match raises an error
56
+ >>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
57
+ ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
58
+
59
+ """
60
+
61
+ def __post_init__(self):
62
+ self._check_sequence_lengths_match()
63
+
64
+ @abstractmethod
65
+ def __len__(self):
66
+ raise NotImplementedError
67
+
68
+ def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
69
+ updated_fields = {}
70
+ if isinstance(idx, int):
71
+ # Wrap a single index in a list so that sliced fields remain sequences.
72
+ idx = [idx]
73
+
74
+ for fld in fields(self):
75
+ if fld.metadata.get("sequence", False):
76
+ # this is a sequence, should be the same length as all other sequences
77
+ sequence_dim = fld.metadata.get("sequence_dim", 0)
78
+ value = getattr(self, fld.name)
79
+ if value is None:
80
+ continue
81
+ match sequence_dim:
82
+ case 0:
83
+ # sequence is first dimension
84
+ value = getattr(self, fld.name)
85
+ value = slice_any_object(value, idx)
86
+ updated_fields[fld.name] = value
87
+ case 1:
88
+ new_value = [slice_any_object(item, idx) for item in value]
89
+ updated_fields[fld.name] = value.__class__(new_value)
90
+ case _:
91
+ raise NotImplementedError(
92
+ "Arbitrary slicing for different sequence length fields is not implemented"
93
+ )
94
+
95
+ return replace(self, **updated_fields)
96
+
97
+ def _check_sequence_lengths_match(self):
98
+ """Checks if sequence lengths of all "sequence" fields match."""
99
+ for fld in fields(self):
100
+ if fld.metadata.get("sequence", False) and fld.name != "complex":
101
+ # this is a sequence, should be the same length as all other sequences
102
+ sequence_dim = fld.metadata.get("sequence_dim", 0)
103
+ value = getattr(self, fld.name)
104
+ if value is None:
105
+ continue
106
+ match sequence_dim:
107
+ case 0:
108
+ # sequence is first dimension
109
+ value = getattr(self, fld.name)
110
+ if len(value) != len(self):
111
+ raise ValueError(
112
+ f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
113
+ )
114
+ case 1:
115
+ for item in value:
116
+ if len(item) != len(self):
117
+ raise ValueError(
118
+ f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
119
+ )
120
+ case _:
121
+ raise NotImplementedError(
122
+ "Arbitrary matching for different sequence length fields is not implemented"
123
+ )
124
+
125
+ @classmethod
126
+ def concat(cls, items: list[T], **kwargs) -> T:
127
+ updated_fields = {}
128
+ for fld in fields(cls):
129
+ if fld.metadata.get("sequence", False):
130
+ # this is a sequence, should be the same length as all other sequences
131
+ sequence_dim = fld.metadata.get("sequence_dim", 0)
132
+ join_value = fld.metadata.get("join_token", None)
133
+ if getattr(items[0], fld.name) is None:
134
+ continue
135
+ values = [getattr(item, fld.name) for item in items]
136
+ match sequence_dim:
137
+ case 0:
138
+ # sequence is first dimension
139
+ value = concat_objects(values, join_value)
140
+ updated_fields[fld.name] = value
141
+ case 1:
142
+ new_value = [
143
+ concat_objects(item, join_value) for item in zip(*values)
144
+ ]
145
+ updated_fields[fld.name] = getattr(
146
+ items[0], fld.name
147
+ ).__class__(new_value)
148
+ case _:
149
+ raise NotImplementedError(
150
+ "Arbitrary joining for different sequence length fields is not implemented"
151
+ )
152
+ updated_fields.update(kwargs)
153
+
154
+ return replace(
155
+ items[0], # type: ignore
156
+ **updated_fields,
157
+ )
158
+
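Reviewer note: the metadata-driven slicing and concatenation in `SequentialDataclass` can be illustrated with the standard library alone. The sketch below is a simplified stand-in, not the class from this commit: `Record`, its fields, and the `-1` join token are hypothetical, only slices are supported, and the real class delegates to `slice_any_object` and `concat_objects`, whose behavior is not shown in this diff.

```python
from dataclasses import dataclass, field, fields, replace


@dataclass(frozen=True)
class Record:
    """Hypothetical example type: one plain field and two sequence fields."""

    id: str
    sequence: str = field(metadata={"sequence": True, "join_token": "|"})
    scores: list = field(metadata={"sequence": True, "join_token": -1})

    def __len__(self):
        return len(self.sequence)

    def __getitem__(self, idx: slice):
        # Slice every field flagged as a sequence; leave the rest untouched.
        updated = {
            f.name: getattr(self, f.name)[idx]
            for f in fields(self)
            if f.metadata.get("sequence", False)
        }
        return replace(self, **updated)

    @classmethod
    def concat(cls, items):
        updated = {}
        for f in fields(cls):
            if not f.metadata.get("sequence", False):
                continue
            token = f.metadata.get("join_token")
            values = [getattr(item, f.name) for item in items]
            if isinstance(values[0], str):
                updated[f.name] = token.join(values)
            else:
                joined = []
                for i, value in enumerate(values):
                    if i > 0:
                        joined.append(token)  # mark the boundary between items
                    joined.extend(value)
                updated[f.name] = joined
        # Non-sequence fields (here: id) are taken from the first item.
        return replace(items[0], **updated)


rec = Record(id="r", sequence="ABCDE", scores=[1, 2, 3, 4, 5])
cropped = rec[1:4]
joined = Record.concat([rec[:2], rec[3:]])
```

Both operations go through `dataclasses.replace`, so a frozen dataclass is never mutated in place, which is the same design choice the class above makes.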
esmfold2_system.py ADDED
@@ -0,0 +1,46 @@
1
+ import io
2
+ import subprocess
3
+ import typing as T
4
+ from pathlib import Path
5
+
6
+ PathLike = T.Union[str, Path]
7
+ PathOrBuffer = T.Union[PathLike, io.StringIO]
8
+
9
+
10
+ def run_subprocess_with_errorcheck(
11
+ *popenargs,
12
+ capture_output: bool = False,
13
+ quiet: bool = False,
14
+ env: dict[str, str] | None = None,
15
+ shell: bool = False,
16
+ executable: str | None = None,
17
+ **kws,
18
+ ) -> subprocess.CompletedProcess:
19
+ """Run a command like subprocess.run, but include the captured stderr in
20
+ the error message when the command fails. This makes it significantly
21
+ easier to diagnose issues.
22
+ """
23
+ try:
24
+ if capture_output:
25
+ stdout = subprocess.PIPE
26
+ elif quiet:
27
+ stdout = subprocess.DEVNULL
28
+ else:
29
+ stdout = None
30
+
31
+ p = subprocess.run(
32
+ *popenargs,
33
+ stderr=subprocess.PIPE,
34
+ stdout=stdout,
35
+ check=True,
36
+ env=env,
37
+ shell=shell,
38
+ executable=executable,
39
+ **kws,
40
+ )
41
+ except subprocess.CalledProcessError as e:
42
+ raise RuntimeError(
43
+ f"Command failed with error code {e.returncode}.\n\n{e.stderr.decode()}"
44
+ ) from e
45
+ return p
46
+
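Reviewer note: the value of `run_subprocess_with_errorcheck` is that a failing command surfaces its stderr in the raised exception. The sketch below shows that behavior with a reduced stand-in. `run_checked` is a hypothetical name, it always captures stdout, and it assumes stderr arrives as bytes (that is, `text=True` is not passed).

```python
import subprocess
import sys


def run_checked(cmd):
    """Reduced sketch: run a command and surface stderr when it fails."""
    try:
        return subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # Fold the captured stderr into the message so the caller sees it.
        raise RuntimeError(
            f"Command failed with error code {e.returncode}.\n\n{e.stderr.decode()}"
        ) from e


# A child process that writes to stderr and exits with status 3.
failing = [sys.executable, "-c", "import sys; sys.stderr.write('boom'); sys.exit(3)"]

try:
    run_checked(failing)
except RuntimeError as err:
    message = str(err)
```

Using `sys.executable` keeps the example portable, since it does not depend on any particular shell command being installed.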
esmfold2_types.py ADDED
@@ -0,0 +1,34 @@
1
+ """Re-exports of the canonical SPI dataclasses from input_builder.
2
+
3
+ This module exists so the HF processor and downstream code can import the
4
+ ESMFold2 input types from a single namespace without picking up internal-only
5
+ sibling utilities. The actual definitions live in
6
+ ``esm.utils.structure.input_builder``.
7
+ """
8
+
9
+ from .esmfold2_msa import MSA
10
+ from .esmfold2_parsing import FastaEntry
11
+ from .esmfold2_input_builder import (
12
+ CovalentBond,
13
+ DistogramConditioning,
14
+ DNAInput,
15
+ LigandInput,
16
+ Modification,
17
+ ProteinInput,
18
+ RNAInput,
19
+ StructurePredictionInput,
20
+ )
21
+
22
+ __all__ = [
23
+ "FastaEntry",
24
+ "MSA",
25
+ "Modification",
26
+ "ProteinInput",
27
+ "RNAInput",
28
+ "DNAInput",
29
+ "LigandInput",
30
+ "DistogramConditioning",
31
+ "CovalentBond",
32
+ "StructurePredictionInput",
33
+ ]
34
+
esmfold2_utils_types.py ADDED
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ from cloudpathlib import CloudPath
9
+
10
+ PathLike = Union[str, Path, CloudPath]
11
+ PathOrBuffer = Union[PathLike, io.StringIO]
12
+
13
+
14
+ @dataclass
15
+ class FunctionAnnotation:
16
+ """Represents an annotation of a protein's function over a range of residues.
17
+
18
+ Fields:
19
+ label (str): An entry in either the function_tokens or residue_annotations tokenizer vocabs
20
+ start (int): Start index of this annotation. 1-indexed, inclusive.
21
+ end (int): End index of this annotation. 1-indexed, inclusive.
22
+ """
23
+
24
+ label: str
25
+ start: int
26
+ end: int
27
+
28
+ def to_tuple(self) -> tuple[str, int, int]:
29
+ return self.label, self.start, self.end
30
+
31
+ def __len__(self) -> int:
32
+ """Length of the annotation."""
33
+ return self.end - self.start + 1
34
+
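Reviewer note: `FunctionAnnotation` uses 1-indexed, inclusive coordinates, while Python slices are 0-indexed and end-exclusive. The sketch below shows the conversion. The dataclass mirrors the fields in the diff; `annotation_to_slice`, the label, and the sequence are hypothetical and only serve the illustration.

```python
from dataclasses import dataclass


@dataclass
class FunctionAnnotation:
    """Mirror of the fields in the diff: 1-indexed, inclusive start and end."""

    label: str
    start: int
    end: int

    def __len__(self) -> int:
        return self.end - self.start + 1


def annotation_to_slice(annotation: FunctionAnnotation) -> slice:
    """Hypothetical helper: convert 1-indexed inclusive bounds to a Python slice."""
    # Shift the start down by one; the inclusive end equals the exclusive stop.
    return slice(annotation.start - 1, annotation.end)


sequence = "MKTAYIAKQR"
annotation = FunctionAnnotation(label="hypothetical_label", start=3, end=6)
covered = sequence[annotation_to_slice(annotation)]
```

The length of the covered substring equals `len(annotation)`, which is a quick consistency check for off-by-one errors.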
modeling_esmc.py ADDED
@@ -0,0 +1,1667 @@
1
+ # Copyright 2026 Biohub. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch ESMC model."""
15
+
16
+ import importlib
17
+ import math
18
+ import re
19
+ from dataclasses import dataclass
20
+ from typing import Optional, cast
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+ from torch.nn import functional as F
26
+
27
+ from transformers.modeling_outputs import (
28
+ MaskedLMOutput,
29
+ ModelOutput,
30
+ SequenceClassifierOutput,
31
+ TokenClassifierOutput,
32
+ )
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import (
35
+ auto_docstring,
36
+ can_return_tuple,
37
+ is_flash_attn_2_available,
38
+ logging,
39
+ )
40
+ from .configuration_esmc import ESMCConfig
41
+ from .modeling_esmc_sae import _ESMCSAELayer
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ _CONFIG_FOR_DOC = "ESMCConfig"
46
+
47
+ # Optional accelerated kernels. Pure-PyTorch fallbacks below if absent.
48
+ if is_flash_attn_2_available():
49
+ flash_attn_module = importlib.import_module("flash_attn")
50
+ flash_bert_padding = importlib.import_module("flash_attn.bert_padding")
51
+ flash_attn_varlen_qkvpacked_func = (
52
+ flash_attn_module.flash_attn_varlen_qkvpacked_func
53
+ )
54
+ pad_input = flash_bert_padding.pad_input
55
+ unpad_input = flash_bert_padding.unpad_input
56
+
57
+ _flash_attn_available = True
58
+ else:
59
+ pad_input = unpad_input = flash_attn_varlen_qkvpacked_func = None
60
+ _flash_attn_available = False
61
+
62
+ try:
63
+ flash_rotary = importlib.import_module("flash_attn.ops.triton.rotary")
64
+ apply_triton_rotary = flash_rotary.apply_rotary
65
+
66
+ _flash_attn_rotary_available = torch.cuda.is_available()
67
+ except ImportError:
68
+ apply_triton_rotary = None # type: ignore[assignment]
69
+ _flash_attn_rotary_available = False
70
+
71
+ # Transformer Engine: fused LayerNorm+Linear / LayerNorm+MLP kernels with
72
+ # fp32 reduction inside the LayerNorm. Recommended on GPU for accurate bf16
73
+ # inference; without it the pure-PyTorch fallback drifts ~O(10) in fp32 and
74
+ # ~O(100) in bf16 on the unnormalized residual stream (perplexity stays
75
+ # within rounding noise).
76
+ try:
77
+ te = importlib.import_module("transformer_engine.pytorch")
78
+
79
+ _te_available = True
80
+ except ImportError:
81
+ te = None # type: ignore[assignment]
82
+ _te_available = False
83
+
84
+ # xformers: preferred SDPA implementation on GPU. Provides a fused
85
+ # bf16 attention kernel with deterministic reduction order. Flash
86
+ # Attention 2 and PyTorch's ``F.scaled_dot_product_attention`` are
87
+ # progressively-less-preferred fallbacks.
88
+ try:
89
+ xops = importlib.import_module("xformers.ops")
90
+
91
+ _xformers_available = True
92
+ except ImportError:
93
+ xops = None # type: ignore[assignment]
94
+ _xformers_available = False
95
+
96
+ # Flash Attention 2: secondary SDPA fallback. Used when xformers is not
97
+ # installed; fp16 / bf16 only.
98
+ if _flash_attn_available:
99
+ flash_attn_func = flash_attn_module.flash_attn_func
100
+ else:
101
+ flash_attn_func = None # type: ignore[assignment]
102
+
103
+ if not _te_available:
104
+ logger.warning(
105
+ "ESMC: transformer_engine is not installed; falling back to "
106
+ "pure-PyTorch LayerNorm+Linear / LayerNorm+MLP. Outputs will differ "
107
+ "numerically — measured on the unnormalized residual stream (before "
108
+ "the final LayerNorm), ~O(10) max-diff in fp32 and ~O(100) in bf16; "
109
+ "after the final LayerNorm these shrink to a few ULP and perplexity "
110
+ "stays within rounding noise. Install with "
111
+ "`pip install transformer-engine[pytorch]` to enable fused fp32-"
112
+ "reduction LayerNorm."
113
+ )
114
+
115
+ if not _xformers_available and not _flash_attn_available:
116
+ logger.warning(
117
+ "ESMC: neither xformers nor flash-attn is installed; falling back "
118
+ "to PyTorch ``F.scaled_dot_product_attention``. The attention "
119
+ "reduction order in bf16 differs from a fused kernel by ~1 bf16 "
120
+ "ULP per attention block; compounded across the 80-block stack "
121
+ "this reaches ~O(100) max-diff on the unnormalized residual stream. "
122
+ "Install xformers (preferred) with `pip install xformers` for a "
123
+ "fused attention kernel."
124
+ )
125
+
126
+ if torch.cuda.is_available() and not _flash_attn_rotary_available:
127
+ logger.warning(
128
+ "ESMC: flash-attn rotary kernel not installed; falling back to "
129
+ "pure-PyTorch RoPE. For faster GPU inference run `pip install flash-attn`."
130
+ )
131
+
132
+
133
+ # ---------------------------------------------------------------------------
134
+ # Output dataclasses
135
+ # ---------------------------------------------------------------------------
136
+
137
+
138
+ @dataclass
139
+ class ESMCOutput(ModelOutput):
140
+ """
141
+ Args:
142
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
143
+ Sequence of hidden states at the output of the last layer, after layer normalisation.
144
+ hidden_states (`torch.FloatTensor`, *optional*):
145
+ Stacked hidden states for all encoder layers.
146
+ Shape ``(n_layers, batch_size, sequence_length, d_model)``.
147
+ Returned when ``output_hidden_states=True``.
148
+ sae_outputs (`dict[str, torch.Tensor]`, *optional*):
149
+ SAE feature magnitudes keyed by SAE model name (sparse tensors).
150
+ Only populated when SAE models have been registered via
151
+ ``add_sae_models`` and ``compute_sae=True``.
152
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
153
+ Per-layer attention weights of shape
154
+ ``(batch_size, num_heads, sequence_length, sequence_length)``.
155
+ Returned when ``output_attentions=True``. Not available on the
156
+ ``flash_attention_2`` path.
157
+ """
158
+
159
+ last_hidden_state: torch.FloatTensor | None = None
160
+ hidden_states: torch.FloatTensor | None = None
161
+ sae_outputs: dict[str, torch.Tensor] | None = None
162
+ attentions: tuple[torch.FloatTensor, ...] | None = None
163
+
164
+
165
+ @dataclass
166
+ class ESMCMaskedLMOutput(MaskedLMOutput):
167
+ """
168
+ Args:
169
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
170
+ Masked language modelling loss. Returned when ``labels`` are provided.
171
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`):
172
+ Prediction scores of the language modelling head.
173
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
174
+ Final hidden states after layer normalisation.
175
+ hidden_states (`torch.FloatTensor`, *optional*):
176
+ Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
177
+ sae_outputs (`dict[str, torch.Tensor]`, *optional*):
178
+ SAE feature magnitudes keyed by SAE model name (sparse tensors).
179
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
180
+ Per-layer attention weights of shape
181
+ ``(batch_size, num_heads, sequence_length, sequence_length)``.
182
+ Returned when ``output_attentions=True``.
183
+ """
184
+
185
+ loss: torch.FloatTensor | None = None
186
+ logits: torch.FloatTensor | None = None
187
+ last_hidden_state: torch.FloatTensor | None = None
188
+ hidden_states: torch.FloatTensor | None = None
189
+ sae_outputs: dict[str, torch.Tensor] | None = None
190
+ attentions: tuple[torch.FloatTensor, ...] | None = None
191
+
192
+
193
+ @dataclass
194
+ class ESMCTokenClassifierOutput(TokenClassifierOutput):
195
+ """
196
+ Args:
197
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
198
+ Token classification loss. Returned when ``labels`` are provided.
199
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_labels)`):
200
+ Classification scores (before SoftMax).
201
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
202
+ Final hidden states after layer normalisation.
203
+ hidden_states (`torch.FloatTensor`, *optional*):
204
+ Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
205
+ sae_outputs (`dict[str, torch.Tensor]`, *optional*):
206
+ SAE feature magnitudes keyed by SAE model name (sparse tensors).
207
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
208
+ Per-layer attention weights of shape
209
+ ``(batch_size, num_heads, sequence_length, sequence_length)``.
210
+ Returned when ``output_attentions=True``.
211
+ """
212
+
213
+ loss: torch.FloatTensor | None = None
214
+ logits: torch.FloatTensor | None = None
215
+ last_hidden_state: torch.FloatTensor | None = None
216
+ hidden_states: torch.FloatTensor | None = None
217
+ sae_outputs: dict[str, torch.Tensor] | None = None
218
+ attentions: tuple[torch.FloatTensor, ...] | None = None
219
+
220
+
221
+ @dataclass
222
+ class ESMCSequenceClassifierOutput(SequenceClassifierOutput):
223
+ """
224
+ Args:
225
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
226
+ Sequence classification loss. Returned when ``labels`` are provided.
227
+ logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
228
+ Classification scores (before SoftMax).
229
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, d_model)`):
230
+ Final hidden states after layer normalisation.
231
+ hidden_states (`torch.FloatTensor`, *optional*):
232
+ Stacked hidden states. Shape ``(n_layers, batch_size, sequence_length, d_model)``.
233
+ sae_outputs (`dict[str, torch.Tensor]`, *optional*):
234
+ SAE feature magnitudes keyed by SAE model name (sparse tensors).
235
+ attentions (`tuple(torch.FloatTensor)`, *optional*):
236
+ Per-layer attention weights of shape
237
+ ``(batch_size, num_heads, sequence_length, sequence_length)``.
238
+ Returned when ``output_attentions=True``.
239
+ """
240
+
241
+ loss: torch.FloatTensor | None = None
242
+ logits: torch.FloatTensor | None = None
243
+ last_hidden_state: torch.FloatTensor | None = None
244
+ hidden_states: torch.FloatTensor | None = None
245
+ sae_outputs: dict[str, torch.Tensor] | None = None
246
+ attentions: tuple[torch.FloatTensor, ...] | None = None
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Rotary position embedding helpers
251
+ # ---------------------------------------------------------------------------
252
+
253
+
254
+ def _rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
255
+ if not interleaved:
256
+ x1, x2 = x.chunk(2, dim=-1)
257
+ return torch.cat((-x2, x1), dim=-1)
258
+ x1, x2 = x[..., ::2], x[..., 1::2]
259
+ return torch.stack((-x2, x1), dim=-1).flatten(-2, -1)
260
+
261
+
262
+ def _apply_rotary_emb_torch(
263
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
264
+ ) -> torch.Tensor:
265
+ """Apply rotary position embeddings (pure PyTorch, no Triton dependency).
266
+
267
+ Args:
268
+ x: ``(batch, seqlen, n_heads, head_dim)``
269
+ cos: ``(seqlen, rotary_dim / 2)``
270
+ sin: ``(seqlen, rotary_dim / 2)``
271
+ """
272
+ ro_dim = cos.shape[-1] * 2
273
+ seqlen = x.size(1)
274
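+ # Half-split pairs (i, i + d/2) tile the angles; interleaved pairs (2i, 2i + 1) need each angle repeated in place.
+ cos = (cos[:seqlen].repeat_interleave(2, dim=-1) if interleaved else cos[:seqlen].repeat(1, 2)).unsqueeze(1)
274
+ sin = (sin[:seqlen].repeat_interleave(2, dim=-1) if interleaved else sin[:seqlen].repeat(1, 2)).unsqueeze(1)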
276
+ return torch.cat(
277
+ [
278
+ x[..., :ro_dim] * cos + _rotate_half(x[..., :ro_dim], interleaved) * sin,
279
+ x[..., ro_dim:],
280
+ ],
281
+ dim=-1,
282
+ )
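The two pairing conventions handled by `_rotate_half` are easiest to see on a single head vector. The following is a minimal stdlib-only sketch, not the module's code path: `rotate_half`, `apply_rope` and `dot` are hypothetical helpers that work on plain Python lists, and the interleaved branch repeats each angle element-wise so that both members of an adjacent pair share one angle.

```python
import math


def rotate_half(x, interleaved=False):
    """Map each pair (a, b) of a head vector to (-b, a)."""
    if not interleaved:
        # Half-split pairing: element i is paired with element i + len(x) // 2.
        half = len(x) // 2
        return [-v for v in x[half:]] + list(x[:half])
    # Interleaved pairing: element 2i is paired with element 2i + 1.
    out = []
    for a, b in zip(x[0::2], x[1::2]):
        out.extend((-b, a))
    return out


def apply_rope(x, position, base=10000.0, interleaved=False):
    """Rotate one even-length head vector to `position` (hypothetical helper)."""
    dim = len(x)
    # One angle per pair: position * base ** (-2i / dim).
    angles = [position * base ** (-2 * i / dim) for i in range(dim // 2)]
    if interleaved:
        per_element = [a for a in angles for _ in range(2)]  # a0, a0, a1, a1
    else:
        per_element = angles + angles  # a0, a1, a0, a1
    rotated = rotate_half(x, interleaved)
    return [
        v * math.cos(a) + r * math.sin(a)
        for v, r, a in zip(x, rotated, per_element)
    ]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.75, -0.5, 1.0]
# Both score computations use a query/key offset of 2 positions.
score_near = dot(apply_rope(q, 5), apply_rope(k, 3))
score_far = dot(apply_rope(q, 105), apply_rope(k, 103))
```

Because every pair is rotated by a plain 2D rotation, vector norms are preserved and the query/key score depends only on the position offset, which is the property attention relies on.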
283
+
284
+
285
+ class RotaryEmbedding(nn.Module):
286
+ """Rotary position embeddings (RoPE) as described in `RoFormer`_.
287
+
288
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
289
+
290
+ Args:
291
+ dim: Size of a single attention head.
292
+ base: Frequency base for the sinusoidal positions.
293
+ interleaved: If ``True`` rotate adjacent pairs (GPT-J style) instead of
294
+ splitting the head dimension in half (GPT-NeoX style).
295
+ scaling_factor: Position indices are divided by this factor (linear position scaling).
296
+ pos_idx_in_fp32: Compute position indices in float32 to avoid bf16
297
+ rounding errors at large sequence lengths.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ dim: int,
303
+ base: float = 10000.0,
304
+ interleaved: bool = False,
305
+ scale_base: float | None = None,
306
+ scaling_factor: float = 1.0,
307
+ pos_idx_in_fp32: bool = True,
308
+ device=None,
309
+ ):
310
+ super().__init__()
311
+ self.dim = dim
312
+ self.base = base
313
+ self.interleaved = interleaved
314
+ self.scale_base = scale_base
315
+ self.scaling_factor = scaling_factor
316
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
317
+
318
+ self._seq_len_cached = 0
319
+ self._cos_cached: torch.Tensor | None = None
320
+ self._sin_cached: torch.Tensor | None = None
321
+ self._cos_k_cached: torch.Tensor | None = None
322
+ self._sin_k_cached: torch.Tensor | None = None
323
+
324
+ self.reset_parameters(device=device)
325
+
326
+ def reset_parameters(self, device=None):
327
+ inv_freq = self._compute_inv_freq(device)
328
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
329
+ arange = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
330
+ scale = (
331
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
332
+ if self.scale_base is not None
333
+ else None
334
+ )
335
+ self.register_buffer("scale", scale, persistent=False)
336
+
337
+ def _compute_inv_freq(self, device=None) -> torch.Tensor:
338
+ return 1.0 / (
339
+ self.base
340
+ ** (
341
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
342
+ / self.dim
343
+ )
344
+ )
345
+
346
+ def _update_cos_sin_cache(self, seqlen: int, device=None, dtype=None):
347
+ if self.inv_freq.is_meta:
348
+ self.reset_parameters(device=device)
349
+ if (
350
+ seqlen > self._seq_len_cached
351
+ or self._cos_cached is None
352
+ or self._cos_cached.device != device
353
+ or self._cos_cached.dtype != dtype
354
+ or (self.training and self._cos_cached.is_inference())
355
+ ):
356
+ self._seq_len_cached = seqlen
357
+ if self.pos_idx_in_fp32:
358
+ t = (
359
+ torch.arange(seqlen, device=device, dtype=torch.float32)
360
+ / self.scaling_factor
361
+ )
362
+ inv_freq = (
363
+ self.inv_freq.to(torch.float32)
364
+ if self.inv_freq.dtype != torch.float32
365
+ else self.inv_freq
366
+ )
367
+ else:
368
+ t = (
369
+ torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # type: ignore[call-overload]
370
+ / self.scaling_factor
371
+ )
372
+ inv_freq = self.inv_freq
373
+ freqs = torch.outer(t, inv_freq) # type: ignore[arg-type]
374
+
375
+ if self.scale is None:
376
+ self._cos_cached = torch.cos(freqs).to(dtype)
377
+ self._sin_cached = torch.sin(freqs).to(dtype)
378
+ else:
379
+ _scale: torch.Tensor = self.scale # type: ignore[assignment]
380
+ power = (
381
+ torch.arange(seqlen, dtype=_scale.dtype, device=_scale.device)
382
+ - seqlen // 2
383
+ ) / self.scale_base # type: ignore[operator]
384
+ scale = _scale.to(device=power.device) ** power.unsqueeze(-1)
385
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
386
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
387
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
388
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
389
+
390
+ def _apply(self, fn, recurse=True):
391
+ if self.inv_freq.is_meta:
392
+ self.reset_parameters(device="cpu")
393
+ result = super()._apply(fn, recurse=recurse)
394
+ # Recompute inv_freq on the new device: CPU vs CUDA ``pow`` differ by
395
+ # ~1 fp32 ULP, which compounds across attention layers. Keep this
396
+ # buffer fp32 even when the module is cast to bf16/fp16; otherwise the
397
+ # rounded RoPE frequencies drift from the internal ESMC path.
398
+ new_inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
399
+ self.register_buffer("inv_freq", new_inv_freq, persistent=False)
400
+ self._seq_len_cached = 0
401
+ self._cos_cached = None
402
+ self._sin_cached = None
403
+ self._cos_k_cached = None
404
+ self._sin_k_cached = None
405
+ return result
406
+
407
+ def forward(
408
+ self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0
409
+ ) -> tuple[torch.Tensor, torch.Tensor]:
410
+ """Apply RoPE to query and key tensors.
411
+
412
+ Args:
413
+ q: ``(batch, seqlen, n_heads, head_dim)``
414
+ k: ``(batch, seqlen, n_heads, head_dim)``
415
+ seqlen_offset: Offset used in incremental decoding.
416
+
417
+ Returns:
418
+ Tuple of rotated ``(q, k)`` tensors with the same shape as the inputs.
419
+ """
420
+ self._update_cos_sin_cache(
421
+ q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype
422
+ )
423
+ assert self._cos_cached is not None and self._sin_cached is not None
424
+
425
+ if self.scale is not None:
426
+ raise NotImplementedError("XPos scaling is not supported in this path.")
427
+
428
+ cos = self._cos_cached[seqlen_offset:]
429
+ sin = self._sin_cached[seqlen_offset:]
430
+
431
+ if _flash_attn_rotary_available and q.device.type == "cuda":
432
+ q_rot = apply_triton_rotary(q, cos, sin, interleaved=self.interleaved) # type: ignore[misc]
433
+ k_rot = apply_triton_rotary(k, cos, sin, interleaved=self.interleaved) # type: ignore[misc]
434
+ else:
435
+ q_rot = _apply_rotary_emb_torch(q, cos, sin, self.interleaved)
436
+ k_rot = _apply_rotary_emb_torch(k, cos, sin, self.interleaved)
437
+ return q_rot, k_rot
438
+
439
+
440
+ class _TritonRotaryEmbedding(RotaryEmbedding):
441
+ """RoPE variant that delegates to the Flash-Attention Triton kernel.
442
+
443
+ Only used inside :class:`_FlashMultiHeadAttention` when Flash Attention 2
444
+ is available. The ``forward`` signature differs from :class:`RotaryEmbedding`
445
+ because Flash Attention packs Q, K, V together.
446
+ """
447
+
448
+ def forward(
449
+ self, qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int
450
+ ) -> torch.Tensor: # type: ignore[override]
451
+ """Apply RoPE in-place to a packed ``(N, 3, n_heads, head_dim)`` tensor."""
452
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
453
+ assert self._cos_cached is not None and self._sin_cached is not None
454
+ assert apply_triton_rotary is not None
455
+
456
+ apply_triton_rotary(
457
+ qkv[:, 0],
458
+ self._cos_cached,
459
+ self._sin_cached,
460
+ cu_seqlens=cu_seqlens,
461
+ max_seqlen=max_seqlen,
462
+ inplace=True,
463
+ )
464
+ apply_triton_rotary(
465
+ qkv[:, 1],
466
+ self._cos_cached,
467
+ self._sin_cached,
468
+ cu_seqlens=cu_seqlens,
469
+ max_seqlen=max_seqlen,
470
+ inplace=True,
471
+ )
472
+ return qkv
473
+
474
+
475
+ # ---------------------------------------------------------------------------
476
+ # Feed-forward network helpers
477
+ # ---------------------------------------------------------------------------
478
+
479
+
480
+ def _swiglu_hidden_dim(expansion_ratio: float, d_model: int) -> int:
481
+ """Round the hidden dim up to the next multiple of 256 after applying expansion_ratio."""
482
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
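The `(x + 255) // 256 * 256` idiom is a ceiling to a multiple of 256, not a round-to-nearest. A small sketch of the same arithmetic, where `swiglu_hidden_dim` is a local copy for illustration and the `d_model` values are arbitrary rather than ESMC configurations:

```python
def swiglu_hidden_dim(expansion_ratio, d_model):
    """Same arithmetic as `_swiglu_hidden_dim`: ceil to a multiple of 256."""
    return int(((expansion_ratio * d_model) + 255) // 256 * 256)


# Arbitrary illustrative sizes (assumptions, not taken from a model config).
exact = swiglu_hidden_dim(4.0, 512)        # product is already a multiple of 256
rounded_up = swiglu_hidden_dim(8 / 3, 1000)  # product is about 2666.67
```

In the second call the nearest multiple of 256 would be 2560, but the helper returns 2816 because it always rounds up.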
483
+
484
+
485
+ class _SwiGLU(nn.Module):
486
+ """SwiGLU activation: ``silu(x1) * x2`` where ``x`` is split along the last dim."""
487
+
488
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
489
+ x1, x2 = x.chunk(2, dim=-1)
490
+ return F.silu(x1) * x2
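As a scalar-level illustration of the gating above, here is a stdlib-only sketch. `silu` and `swiglu` are hypothetical list-based helpers, assuming the first half of the input is the gate and the second half is the value, as in `x.chunk(2, dim=-1)`.

```python
import math


def silu(v):
    """SiLU (swish): v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(x):
    """Split `x` in half and gate: silu(first half) * second half."""
    half = len(x) // 2
    gate, value = x[:half], x[half:]
    return [silu(g) * v for g, v in zip(gate, value)]


out = swiglu([0.0, 1.0, 5.0, 7.0])  # gate = [0, 1], value = [5, 7]
```

The output has half the width of the input, which is why `fc1_weight` in the fallback MLP projects to `2 * ffn_hidden_size`.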
491
+
492
+
493
+ class _PyTorchLayerNormLinear(nn.Module):
494
+ """LayerNorm followed by a Linear projection, sharing the parameter
495
+ names ``layer_norm_weight``, ``layer_norm_bias`` and ``weight`` so the
496
+ state-dict layout matches the accelerated TE module loaded on GPU.
497
+ """
498
+
499
+ def __init__(self, d_in: int, d_out: int, eps: float = 1e-5) -> None:
500
+ super().__init__()
501
+ self.d_in = d_in
502
+ self.eps = eps
503
+ self.layer_norm_weight = nn.Parameter(torch.ones(d_in))
504
+ self.layer_norm_bias = nn.Parameter(torch.zeros(d_in))
505
+ self.weight = nn.Parameter(torch.empty(d_out, d_in))
506
+ nn.init.normal_(self.weight, std=0.02)
507
+
508
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
509
+ x = F.layer_norm(
510
+ x, (self.d_in,), self.layer_norm_weight, self.layer_norm_bias, self.eps
511
+ )
512
+ return F.linear(x, self.weight)
513
+
514
+
515
+ class _PyTorchLayerNormMLP(nn.Module):
516
+ """LayerNorm + SwiGLU MLP, sharing the parameter names
517
+ ``layer_norm_weight``, ``layer_norm_bias``, ``fc1_weight``,
518
+ ``fc2_weight`` so the state-dict layout matches the accelerated TE
519
+ module loaded on GPU.
520
+ """
521
+
522
+ def __init__(
523
+ self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5
524
+ ) -> None:
525
+ super().__init__()
526
+ self.hidden_size = hidden_size
527
+ self.ffn_hidden_size = ffn_hidden_size
528
+ self.eps = eps
529
+ self.layer_norm_weight = nn.Parameter(torch.ones(hidden_size))
530
+ self.layer_norm_bias = nn.Parameter(torch.zeros(hidden_size))
531
+ self.fc1_weight = nn.Parameter(torch.empty(2 * ffn_hidden_size, hidden_size))
532
+ self.fc2_weight = nn.Parameter(torch.empty(hidden_size, ffn_hidden_size))
533
+ nn.init.normal_(self.fc1_weight, std=0.02)
534
+ nn.init.normal_(self.fc2_weight, std=0.02)
535
+
536
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
537
+ x = F.layer_norm(
538
+ x,
539
+ (self.hidden_size,),
540
+ self.layer_norm_weight,
541
+ self.layer_norm_bias,
542
+ self.eps,
543
+ )
544
+ x = F.linear(x, self.fc1_weight)
545
+ x1, x2 = x.chunk(2, dim=-1)
546
+ x = F.silu(x1) * x2
547
+ return F.linear(x, self.fc2_weight)
548
+
549
+
550
+ def _swiglu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Module:
551
+ """LayerNorm + SwiGLU MLP. Uses Transformer Engine's fused LN+MLP when
552
+ available; otherwise returns the pure-PyTorch fallback with matching
553
+ state-dict layout."""
554
+ assert not bias, "ESMC was trained with bias=False; bias=True not supported"
555
+ hidden = _swiglu_hidden_dim(expansion_ratio, d_model)
556
+ if _te_available:
557
+ return te.LayerNormMLP( # type: ignore[union-attr]
558
+ hidden_size=d_model,
559
+ ffn_hidden_size=hidden,
560
+ bias=bias,
561
+ activation="swiglu",
562
+ init_method=None,
563
+ output_layer_init_method=None,
564
+ )
565
+ return _PyTorchLayerNormMLP(hidden_size=d_model, ffn_hidden_size=hidden)
566
+
567
+
568
+ def _make_attn_layernorm_qkv(d_model: int, bias: bool) -> nn.Module:
569
+ """LayerNorm + fused QKV projection. Uses Transformer Engine when
570
+ available; pure-PyTorch fallback otherwise."""
571
+ assert not bias, "ESMC was trained with bias=False; bias=True not supported"
572
+ if _te_available:
573
+ return te.LayerNormLinear( # type: ignore[union-attr]
574
+ d_model, d_model * 3, bias=bias, init_method=None
575
+ )
576
+ return _PyTorchLayerNormLinear(d_model, d_model * 3)
577
+
578
+
579
+ def _make_attn_out_proj(d_model: int, bias: bool) -> nn.Module:
580
+ """Attention output projection. Uses Transformer Engine when available;
581
+ pure-PyTorch ``nn.Linear`` otherwise."""
582
+ if _te_available:
583
+ return te.Linear( # type: ignore[union-attr]
584
+ d_model, d_model, bias=bias, init_method=None
585
+ )
586
+ return nn.Linear(d_model, d_model, bias=bias)
587
+
588
+
589
+ def _gelu_ln_ffn(d_model: int, expansion_ratio: float, bias: bool) -> nn.Sequential:
590
+ hidden = int(expansion_ratio * d_model)
591
+ return nn.Sequential(
592
+ nn.LayerNorm(d_model),
593
+ nn.Linear(d_model, hidden, bias=bias),
594
+ nn.GELU(),
595
+ nn.Linear(hidden, d_model, bias=bias),
596
+ )
597
+
598
+
599
+ # ---------------------------------------------------------------------------
600
+ # Attention
601
+ # ---------------------------------------------------------------------------
602
+
603
+
604
+ def _scaled_dot_product_attention(
605
+ q: torch.Tensor,
606
+ k: torch.Tensor,
607
+ v: torch.Tensor,
608
+ *,
609
+ n_heads: int,
610
+ d_head: int,
611
+ seq_id: torch.Tensor | None,
612
+ ) -> torch.Tensor:
613
+ """Scaled dot-product attention with optional chain-aware mask.
614
+
615
+ Dispatches in order of preference:
616
+ 1. xformers ``memory_efficient_attention`` — preferred fused kernel,
617
+ requires ``xformers``, no chain mask.
618
+ 2. Flash Attention 2 (``flash_attn.flash_attn_func``) — secondary
619
+ fused kernel, requires ``flash-attn``, no chain mask, fp16 /
620
+ bf16 only.
621
+ 3. PyTorch's ``F.scaled_dot_product_attention`` — last-resort path;
622
+ also handles the chain-aware mask when ``seq_id`` is present
623
+ and the fp32 path that Flash Attention 2 does not support.
624
+ """
625
+ if seq_id is None and _xformers_available:
626
+ b, s, _ = q.shape
627
+ q4 = q.view(b, s, n_heads, d_head)
628
+ k4 = k.view(b, s, n_heads, d_head)
629
+ v4 = v.view(b, s, n_heads, d_head)
630
+ context = xops.memory_efficient_attention( # type: ignore[union-attr]
631
+ q4, k4, v4, attn_bias=None, scale=d_head**-0.5
632
+ )
633
+ return context.reshape(b, s, n_heads * d_head)
634
+ if (
635
+ seq_id is None
636
+ and _flash_attn_available
637
+ and q.dtype in (torch.float16, torch.bfloat16)
638
+ ):
639
+ b, s, _ = q.shape
640
+ q4 = q.view(b, s, n_heads, d_head)
641
+ k4 = k.view(b, s, n_heads, d_head)
642
+ v4 = v.view(b, s, n_heads, d_head)
643
+ context = flash_attn_func( # type: ignore[misc]
644
+ q4, k4, v4, dropout_p=0.0, softmax_scale=d_head**-0.5
645
+ )
646
+ return context.reshape(b, s, n_heads * d_head) # type: ignore[union-attr]
647
+ b, s, _ = q.shape
648
+ q = q.view(b, s, n_heads, -1).transpose(1, 2)
649
+ k = k.view(b, s, n_heads, -1).transpose(1, 2)
650
+ v = v.view(b, s, n_heads, -1).transpose(1, 2)
651
+ if seq_id is not None:
652
+ mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)
653
+ context = F.scaled_dot_product_attention(q, k, v, mask)
654
+ else:
655
+ context = F.scaled_dot_product_attention(q, k, v)
656
+ _, h, _, d_out = context.shape
657
+ return context.transpose(1, 2).reshape(b, s, h * d_out)
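The chain-aware mask built from `seq_id` can be sketched without tensors. The helpers below (`chain_mask`, `masked_softmax`) are hypothetical and operate on one unbatched sequence; they assume, as the equality test in the source implies, that padding tokens marked `-1` only match other padding tokens.

```python
import math


def chain_mask(seq_id):
    """mask[i][j] is True when tokens i and j carry the same chain id."""
    return [[a == b for b in seq_id] for a in seq_id]


def masked_softmax(scores, mask_row):
    """Softmax over one row of scores, ignoring positions where mask is False."""
    kept = [s for s, m in zip(scores, mask_row) if m]
    top = max(kept)  # subtract the max for numerical stability
    exps = [math.exp(s - top) if m else 0.0 for s, m in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]


seq_id = [0, 0, 1, 1, -1]  # two chains, then one padding token (-1)
mask = chain_mask(seq_id)
weights = masked_softmax([0.3, 1.2, 4.0, -0.5, 2.0], mask[0])
```

Token 0 ends up with non-zero weight only on tokens 0 and 1, even though token 2 has the largest raw score.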
658
+
659
+
660
+ class MultiHeadAttention(nn.Module):
661
+ """Multi-head self-attention with QK LayerNorm and RoPE.
662
+
663
+ Args:
664
+ d_model: Model hidden dimension.
665
+ n_heads: Number of attention heads.
666
+ bias: Whether to use bias in linear layers.
667
+ qk_layernorm: Whether to apply LayerNorm to queries and keys before
668
+ computing attention scores.
669
+ """
670
+
671
+ def __init__(
672
+ self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
673
+ ):
674
+ super().__init__()
675
+ self.d_model = d_model
676
+ self.n_heads = n_heads
677
+ self.d_head = d_model // n_heads
678
+
679
+ assert not bias, "ESMC was trained with bias=False; bias=True not supported"
680
+ self.layernorm_qkv = _make_attn_layernorm_qkv(d_model, bias)
681
+ self.out_proj = _make_attn_out_proj(d_model, bias)
682
+
683
+ if qk_layernorm:
684
+ self.q_ln = nn.LayerNorm(d_model, bias=bias)
685
+ self.k_ln = nn.LayerNorm(d_model, bias=bias)
686
+ else:
687
+ self.q_ln = nn.Identity()
688
+ self.k_ln = nn.Identity()
689
+
690
+ self.rotary = RotaryEmbedding(d_model // n_heads)
691
+
692
+ def _apply_rotary(
693
+ self, q: torch.Tensor, k: torch.Tensor
694
+ ) -> tuple[torch.Tensor, torch.Tensor]:
695
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
696
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
697
+ q, k = self.rotary(q, k)
698
+ q = q.flatten(-2, -1)
699
+ k = k.flatten(-2, -1)
700
+ return q, k
701
+
702
+ def forward(
703
+ self,
704
+ x: torch.Tensor,
705
+ seq_id: torch.Tensor | None,
706
+ output_attentions: bool = False,
707
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
708
+ """Return ``(context, attn_weights)``.
709
+
710
+ ``attn_weights`` is ``None`` unless ``output_attentions=True``. The
711
+ fused SDPA backends (xformers, flash-attn 2, ``F.scaled_dot_product_attention``)
712
+ don't expose attention probabilities, so capturing them forces a manual
713
+ ``softmax(Q @ K.T / sqrt(d)) @ V`` path that materializes the full
714
+ ``(B, H, L, L)`` weight tensor.
715
+ """
716
+ qkv = self.layernorm_qkv(x)
717
+ q, k, v = torch.chunk(qkv, 3, dim=-1)
718
+ q = self.q_ln(q).to(q.dtype)
719
+ k = self.k_ln(k).to(q.dtype)
720
+ q, k = self._apply_rotary(q, k)
721
+
722
+ b, s, _ = q.shape
723
+
724
+ if output_attentions:
725
+ # Manual SDPA so attention probabilities are observable.
726
+ q4 = q.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
727
+ k4 = k.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
728
+ v4 = v.view(b, s, self.n_heads, self.d_head).transpose(1, 2)
729
+ scale = self.d_head**-0.5
730
+ attn_scores = (q4 @ k4.transpose(-2, -1)) * scale
731
+ if seq_id is not None:
732
+ mask = (seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)).unsqueeze(1)
733
+ attn_scores = attn_scores.masked_fill(~mask, float("-inf"))
734
+ attn_weights = torch.softmax(attn_scores, dim=-1)
735
+ context = (attn_weights @ v4).transpose(1, 2).reshape(b, s, -1)
736
+ return self.out_proj(context), attn_weights
737
+
738
+ context = _scaled_dot_product_attention(
739
+ q, k, v, n_heads=self.n_heads, d_head=self.d_head, seq_id=seq_id
740
+ )
741
+ return self.out_proj(context), None
742
+
743
+
744
+ class _FlashMultiHeadAttention(MultiHeadAttention):
745
+ """Flash-Attention 2 variant of :class:`MultiHeadAttention`."""
746
+
747
+ def __init__(
748
+ self, d_model: int, n_heads: int, bias: bool = False, qk_layernorm: bool = True
749
+ ):
750
+ super().__init__(
751
+ d_model=d_model, n_heads=n_heads, bias=bias, qk_layernorm=qk_layernorm
752
+ )
753
+ self.rotary = _TritonRotaryEmbedding(d_model // n_heads)
754
+
755
+ def forward(
756
+ self,
757
+ x: torch.Tensor,
758
+ seq_id: torch.Tensor | None,
759
+ output_attentions: bool = False,
760
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
761
+ if output_attentions:
762
+ raise ValueError(
763
+ "output_attentions=True is not supported with "
764
+ "attn_implementation='flash_attention_2'. "
765
+ "Re-load the model with attn_implementation='sdpa' (or 'eager')."
766
+ )
767
+ assert seq_id is not None and seq_id.dtype == torch.bool
768
+
769
+ seqlens = seq_id.sum(dim=-1, dtype=torch.int32)
770
+ cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
771
+ max_seqlen = int(seqlens.max().item())
772
+
773
+ qkv = self.layernorm_qkv(x)
774
+ q, k, v = torch.chunk(qkv, 3, dim=-1)
775
+ q = self.q_ln(q).to(q.dtype)
776
+ k = self.k_ln(k).to(q.dtype)
777
+
778
+ # ``q``/``k``/``v`` are 2D ``(T, D)`` here: the parent ``ESMCModel.forward``
779
+ # calls ``unpad_input`` before the transformer stack to produce the
780
+ # varlen-flat layout that ``flash_attn_varlen_qkvpacked_func`` requires.
781
+ T = q.shape[0]
782
+ qkv_packed = torch.stack([q, k, v], dim=1).view(T, 3, self.n_heads, self.d_head)
783
+ qkv_packed = self.rotary(qkv_packed, cu_seqlens, max_seqlen)
784
+
785
+ context = flash_attn_varlen_qkvpacked_func( # type: ignore[misc]
786
+ qkv_packed, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5
787
+ )
788
+ n_out, h_out, d_out = context.shape # type: ignore[union-attr]
789
+ return (
790
+ self.out_proj(context.reshape(n_out, h_out * d_out)), # type: ignore[union-attr]
791
+ None,
792
+ )
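The `cu_seqlens` bookkeeping used by the varlen Flash Attention path is a prefix sum over per-sequence token counts. A stdlib-only sketch on a toy padding mask (the mask values and the `spans` name are made up for illustration):

```python
from itertools import accumulate

# Boolean padding mask, True for real tokens (toy batch of three sequences).
padding_mask = [
    [True, True, True, False],
    [True, True, True, True],
    [True, False, False, False],
]

seqlens = [sum(row) for row in padding_mask]  # tokens per sequence
cu_seqlens = [0] + list(accumulate(seqlens))  # prefix sums with a leading 0
max_seqlen = max(seqlens)

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the flat layout.
spans = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
```

The leading zero plays the role of `F.pad(..., (1, 0))` in the source, so that consecutive entries delimit each sequence in the unpadded `(T, D)` tensor.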
793
+
794
+
795
+ # ---------------------------------------------------------------------------
796
+ # Transformer blocks
797
+ # ---------------------------------------------------------------------------
798
+
799
+
800
+ class UnifiedTransformerBlock(nn.Module):
801
+ """Single transformer block: pre-norm attention + pre-norm FFN with residual scaling.
802
+
803
+ Args:
804
+ d_model: Hidden dimension.
805
+ n_heads: Number of attention heads.
806
+ use_flash_attn: Use Flash Attention 2 kernel if available.
807
+ bias: Whether linear layers include bias terms.
808
+ expansion_ratio: Hidden-dim expansion ratio for the FFN.
809
+ residue_scaling_factor: Divisor applied to each residual branch to stabilise
810
+ deep networks (the ESM3 scheme divides by ``sqrt(n_layers / 36)``).
811
+ qk_layernorm: Whether to apply QK LayerNorm in attention.
812
+ ffn_type: Feed-forward activation: ``"swiglu"`` or ``"gelu"``.
813
+ """
814
+
815
+ def __init__(
816
+ self,
817
+ d_model: int,
818
+ n_heads: int,
819
+ use_flash_attn: bool = False,
820
+ bias: bool = False,
821
+ expansion_ratio: float = 4.0,
822
+ residue_scaling_factor: float = 1.0,
823
+ qk_layernorm: bool = True,
824
+ ffn_type: str = "swiglu",
825
+ ):
826
+ super().__init__()
827
+
828
+ attn_cls = _FlashMultiHeadAttention if use_flash_attn else MultiHeadAttention
829
+ self.attn = attn_cls(d_model, n_heads, bias=bias, qk_layernorm=qk_layernorm)
830
+
831
+ if ffn_type == "swiglu":
832
+ self.ffn = _swiglu_ln_ffn(d_model, expansion_ratio, bias)
833
+ elif ffn_type == "gelu":
834
+ self.ffn = _gelu_ln_ffn(d_model, expansion_ratio, bias)
835
+ else:
836
+ raise ValueError(
837
+ f"Unknown ffn_type: {ffn_type!r}. Choose 'swiglu' or 'gelu'."
838
+ )
839
+
840
+ self.scaling_factor = residue_scaling_factor
841
+
842
+ def forward(
843
+ self,
844
+ x: torch.Tensor,
845
+ sequence_id: torch.Tensor | None,
846
+ output_attentions: bool = False,
847
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
848
+ """
849
+ Args:
850
+ x: ``(batch, seq_len, d_model)``
851
+ sequence_id: ``(batch, seq_len)`` chain-ID tensor used to restrict
852
+ attention to tokens within the same chain. SDPA blocks accept
853
+ an integer tensor (``-1`` marks padding); the flash-attn block
854
+ takes a ``bool`` padding mask — the caller selects which.
855
+ ``None`` skips chain-aware masking entirely (fast path).
856
+ output_attentions: When ``True``, returns the per-head attention
857
+ weights for this block alongside the residual output.
858
+
859
+ Returns:
860
+ ``(output, attn_weights_or_None)``. Shape of ``output`` is
861
+ ``(batch, seq_len, d_model)``; ``attn_weights`` shape is
862
+ ``(batch, num_heads, seq_len, seq_len)`` or ``None``.
863
+ """
864
+ attn_out, attn_weights = self.attn(
865
+ x, sequence_id, output_attentions=output_attentions
866
+ )
867
+ x = x + attn_out / self.scaling_factor
868
+ x = x + self.ffn(x) / self.scaling_factor
869
+ return x, attn_weights
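A scalar sketch of the residual scaling may help here: the block divides each branch output by `scaling_factor` before adding it back. The helper names and layer counts below are illustrative assumptions; the divisor formula is the `math.sqrt(n_layers / 36)` expression used when the stack is built.

```python
import math


def residue_scaling_factor(n_layers, scale_residue=True):
    """Divisor used for both residual branches of every block."""
    return math.sqrt(n_layers / 36) if scale_residue else 1.0


def residual_update(x, branch_out, factor):
    """One residual step on scalars: x + branch_out / factor."""
    return x + branch_out / factor


shallow = residue_scaling_factor(36)   # divisor of 1: plain residual
deep = residue_scaling_factor(144)     # divisor of 2: branches are halved
updated = residual_update(1.0, 0.5, deep)
```

Deeper stacks therefore damp each block's contribution, while a 36-layer stack behaves like an unscaled pre-norm transformer.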
870
+
871
+
872
+ class TransformerStack(nn.Module):
873
+ """Stack of :class:`UnifiedTransformerBlock` layers with a final LayerNorm.
874
+
875
+ Args:
876
+ d_model: Hidden dimension.
877
+ n_heads: Number of attention heads.
878
+ n_layers: Number of transformer blocks.
879
+ scale_residue: When ``True``, divide each block's residual branches by
880
+ ``sqrt(n_layers / 36)`` (ESM3 residue scaling).
881
+ bias: Bias flag forwarded to every sub-module.
882
+ qk_layernorm: QK LayerNorm flag forwarded to every block.
883
+ ffn_type: FFN activation type (``"swiglu"`` or ``"gelu"``).
884
+ expansion_ratio: FFN expansion ratio.
885
+ use_flash_attn: Use Flash Attention 2 kernel when available.
886
+ """
887
+
888
+ def __init__(
889
+ self,
890
+ d_model: int,
891
+ n_heads: int,
892
+ n_layers: int,
893
+ scale_residue: bool = True,
894
+ bias: bool = False,
895
+ qk_layernorm: bool = True,
896
+ ffn_type: str = "swiglu",
897
+ expansion_ratio: float = 8 / 3,
898
+ use_flash_attn: bool = False,
899
+ ):
900
+ super().__init__()
901
+ self.blocks = nn.ModuleList(
902
+ [
903
+ UnifiedTransformerBlock(
904
+ d_model,
905
+ n_heads,
906
+ use_flash_attn=use_flash_attn,
907
+ residue_scaling_factor=math.sqrt(n_layers / 36)
908
+ if scale_residue
909
+ else 1.0,
910
+ expansion_ratio=expansion_ratio,
911
+ bias=bias,
912
+ qk_layernorm=qk_layernorm,
913
+ ffn_type=ffn_type,
914
+ )
915
+ for _ in range(n_layers)
916
+ ]
917
+ )
918
+ self.norm = nn.LayerNorm(d_model, bias=False)
919
+
920
+ def forward(
921
+ self,
922
+ x: torch.Tensor,
923
+ sequence_id: torch.Tensor | None = None,
924
+ layers_to_collect: list[int] | None = None,
925
+ output_attentions: bool = False,
926
+ ) -> tuple[
927
+ torch.Tensor,
928
+ torch.Tensor,
929
+ tuple[torch.Tensor, ...],
930
+ tuple[torch.Tensor, ...] | None,
931
+ ]:
932
+ """Run the full transformer stack.
933
+
934
+ Args:
935
+ x: ``(batch, seq_len, d_model)``
936
+ sequence_id: Optional chain-id tensor forwarded to each block.
937
+ layers_to_collect: Indices of the hidden states to return. Index
938
+ ``i < n_layers`` selects the input to block ``i``; index ``n_layers``
939
+ selects the post-norm output.
940
+ output_attentions: When ``True``, collects the per-block attention
941
+ weights and returns them as the fourth tuple element.
942
+
943
+ Returns:
944
+ ``(post_norm, pre_norm, hidden_states, attentions)`` where
945
+ ``hidden_states`` is a (possibly empty) tuple of tensors and
946
+ ``attentions`` is a tuple of per-block ``(B, H, L, L)`` tensors
947
+ or ``None`` when ``output_attentions`` is ``False``.
948
+ """
949
+ if layers_to_collect is None:
950
+ layers_to_collect = []
951
+
952
+ collected: list[torch.Tensor] = []
953
+ all_attentions: list[torch.Tensor] = []
954
+ for layer_idx, block in enumerate(self.blocks):
955
+ if layer_idx in layers_to_collect:
956
+ collected.append(x)
957
+ x, attn_weights = block(x, sequence_id, output_attentions=output_attentions)
958
+ if output_attentions and attn_weights is not None:
959
+ all_attentions.append(attn_weights)
960
+
961
+ norm_x = self.norm(x)
962
+ if len(self.blocks) in layers_to_collect:
963
+ collected.append(norm_x)
964
+
965
+ attentions = tuple(all_attentions) if output_attentions else None
966
+ return norm_x, x, tuple(collected), attentions
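The indexing convention of `layers_to_collect` can be checked with a toy stand-in for the loop above. In this sketch `run_stack` is hypothetical, the "blocks" and "norm" are plain callables on numbers, and attention outputs are left out.

```python
def run_stack(x, blocks, final_norm, layers_to_collect=()):
    """Toy stand-in for the stack loop, using plain callables on numbers."""
    collected = []
    for layer_idx, block in enumerate(blocks):
        if layer_idx in layers_to_collect:
            collected.append(x)  # input to block `layer_idx`
        x = block(x)
    norm_x = final_norm(x)
    if len(blocks) in layers_to_collect:
        collected.append(norm_x)  # post-norm output
    return norm_x, x, tuple(collected)


blocks = [lambda v: v + 1] * 3  # three toy "blocks"
post_norm, pre_norm, hidden = run_stack(
    10, blocks, final_norm=lambda v: v * 100, layers_to_collect=[0, 2, 3]
)
```

Index `0` yields the embedding output fed to the first block, and index `n_layers` (here `3`) is the only entry that has passed through the final norm.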
967
+
968
+
969
+ # ---------------------------------------------------------------------------
970
+ # Pre-trained model base class
971
+ # ---------------------------------------------------------------------------
972
+
973
+
974
+ @auto_docstring
975
+ class ESMCPreTrainedModel(PreTrainedModel):
976
+ """Base class for ESMC models.
977
+
978
+ Handles weight initialisation and declares module-level capabilities.
979
+ """
980
+
981
+ config_class = ESMCConfig
982
+ base_model_prefix = "esmc"
983
+ supports_gradient_checkpointing = False
984
+ _supports_sdpa = True
985
+ _supports_flash_attn = True
986
+ _supports_attention_backend = True
987
+ _no_split_modules = ["UnifiedTransformerBlock"]
988
+ _keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
989
+
990
+ def _init_weights(self, module: nn.Module):
991
+ std = self.config.initializer_range
992
+ if isinstance(module, nn.Linear):
993
+ module.weight.data.normal_(mean=0.0, std=std)
994
+ if module.bias is not None:
995
+ module.bias.data.zero_()
996
+ elif isinstance(module, RotaryEmbedding):
997
+ module.reset_parameters(device=self.device)
998
+
999
+
1000
+ # ---------------------------------------------------------------------------
1001
+ # Base encoder model
1002
+ # ---------------------------------------------------------------------------
1003
+
1004
+
1005
+ @auto_docstring
1006
+ class ESMCModel(ESMCPreTrainedModel):
1007
+ """The bare ESMC encoder outputting raw hidden states.
1008
+
1009
+ ESMC is a protein language model trained by EvolutionaryScale using a
1010
+ masked-token objective over amino acid sequences. The architecture is a
1011
+ standard Transformer encoder with RoPE positional embeddings, QK LayerNorm,
1012
+ and SwiGLU feed-forward networks.
1013
+
1014
+ Args:
1015
+ config: An :class:`ESMCConfig` instance.
1016
+ """
1017
+
1018
+ def __init__(self, config: ESMCConfig):
1019
+ super().__init__(config)
1020
+ self._use_flash_attn = (
1021
+ _flash_attn_available and config._attn_implementation == "flash_attention_2"
1022
+ )
1023
+ self.embed = nn.Embedding(config.vocab_size, config.d_model)
1024
+ self.transformer = TransformerStack(
1025
+ config.d_model,
1026
+ config.n_heads,
1027
+ config.n_layers,
1028
+ use_flash_attn=self._use_flash_attn,
1029
+ )
1030
+ self._sae_models: nn.ModuleDict = nn.ModuleDict()
1031
+ self.post_init()
1032
+
1033
+ def get_input_embeddings(self) -> nn.Embedding:
1034
+ return self.embed
1035
+
1036
+ def set_input_embeddings(self, value: nn.Embedding):
1037
+ self.embed = value
1038
+
1039
+ def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
1040
+ """Register one or more SAEs obtained from an :class:`ESMCSAEModel`.
1041
+
1042
+ Each is keyed by ``f"layer{N}"`` (the backbone-layer index ``N`` the
1043
+ SAE is trained against, set by
1044
+ :meth:`ESMCSAEModel.initialize_layers`). Attaching two SAEs for the
1045
+ same backbone layer raises — only one SAE per layer can be active.
1046
+
1047
+ Example::
1048
+
1049
+ sae = ESMCSAEModel.from_pretrained(
1050
+ "biohub/esmc-600m-2024-12-sae-k64-codebook16384"
1051
+ )
1052
+ sae.initialize_layers([27, 33])
1053
+ model.add_sae_models([sae.layers["27"], sae.layers["33"]])
1054
+ """
1055
+ for layer in sae_models:
1056
+ assert isinstance(layer, _ESMCSAELayer), (
1057
+ f"Expected an SAE layer (model.layers['<idx>']), got "
1058
+ f"{type(layer).__name__}."
1059
+ )
1060
+ key = f"layer{int(layer.layer)}"
1061
+ if key in self._sae_models:
1062
+ raise ValueError(
1063
+ f"An SAE is already registered at {key!r}. Only one SAE "
1064
+ "per backbone layer can be active — pick a different "
1065
+ "layer on one of them, or attach in a fresh model."
1066
+ )
1067
+ self._sae_models[key] = layer
1068
+
1069
+ _SAE_KEY_RE = re.compile(r"layer(\d+)")
1070
+
1071
+ def _get_sae_layer_num_requested(self, model_name: str) -> int:
1072
+ """Recover the backbone-layer index from a key written by
1073
+ :meth:`add_sae_models` (``"layer{N}"`` → ``N``)."""
1074
+ match = self._SAE_KEY_RE.fullmatch(model_name)
1075
+ assert (
1076
+ match is not None
1077
+ ), f"Unexpected SAE key {model_name!r}; expected 'layer{{N}}'."
1078
+ return int(match.group(1))
1079
+
1080
+ def _validate_sae_inputs(self, input_ids: torch.Tensor) -> None:
1081
+ assert torch.all(input_ids != self.config.mask_token_id), (
1082
+ "SAE inputs must not contain mask tokens. "
1083
+ "SAEs were trained on unmasked sequences."
1084
+ )
1085
+
1086
+ def _get_sae_outputs(
1087
+ self,
1088
+ hidden_states: torch.Tensor,
1089
+ layers_to_collect: list[int],
1090
+ token_mask: torch.Tensor,
1091
+ normalize_sae: bool = False,
1092
+ ) -> dict[str, torch.Tensor]:
1093
+ """Run all registered SAEs and return their feature magnitudes.
1094
+
1095
+ Args:
1096
+ hidden_states: Stacked tensor of shape
1097
+ ``(len(layers_to_collect), batch, seq_len, d_model)``.
1098
+ layers_to_collect: The ESMC layer indices that were collected,
1099
+ in the same order as the first dim of ``hidden_states``.
1100
+ token_mask: Boolean mask ``(batch, seq_len)`` — ``True`` for
1101
+ real (non-padding) tokens.
1102
+ normalize_sae: When ``True``, scale features by ``idf / max``
1103
+ using the per-feature stats trained alongside each SAE.
1104
+ """
1105
+ layer_to_idx = {layer: idx for idx, layer in enumerate(layers_to_collect)}
1106
+ sae_outputs: dict[str, torch.Tensor] = {}
1107
+
1108
+ for model_name, sae_module in self._sae_models.items():
1109
+ # `nn.ModuleDict` only stores `nn.Module`s at the type level;
1110
+ # ``add_sae_models`` enforces that each entry is an ``_ESMCSAELayer``.
1111
+ assert isinstance(sae_module, _ESMCSAELayer)
1112
+ layer: _ESMCSAELayer = sae_module
1113
+ requested_layer = self._get_sae_layer_num_requested(model_name)
1114
+ layer_idx = layer_to_idx[requested_layer]
1115
+ layer_states = hidden_states[layer_idx].clone().to(self.device)
1116
+
1117
+ sae_out = layer.get_sae_output(layer_states, token_mask)
1118
+ features = sae_out.feature_magnitudes.detach()
1119
+
1120
+ if normalize_sae:
1121
+ # ``register_buffer`` is typed as ``Tensor | Module`` on
1122
+ # ``nn.Module``; narrow here since these are Tensors.
1123
+ idf = cast(torch.Tensor, layer.idf)
1124
+ max_val = cast(torch.Tensor, layer.max)
1125
+ features = (features / max_val) * idf
1126
+
1127
+ sae_outputs[model_name] = features.to_sparse()
1128
+
1129
+ return sae_outputs
1130
+
1131
+ @can_return_tuple
1132
+ @auto_docstring
1133
+ def forward(
1134
+ self,
1135
+ input_ids: Optional[torch.Tensor] = None,
1136
+ attention_mask: Optional[torch.Tensor] = None,
1137
+ sequence_id: Optional[torch.Tensor] = None,
1138
+ output_hidden_states: Optional[bool] = None,
1139
+ output_attentions: Optional[bool] = None,
1140
+ return_dict: Optional[bool] = None,
1141
+ compute_sae: bool = True,
1142
+ normalize_sae: bool = False,
1143
+ ) -> tuple[torch.Tensor, ...] | ESMCOutput:
1144
+ r"""
1145
+ sequence_id (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1146
+ Integer chain-ID tensor for chain-aware attention masking. Tokens with the same
1147
+ non-negative integer value can attend to each other; tokens with different values
1148
+ cannot (cross-chain masking). Padding positions should be set to ``-1``.
1149
+ When provided, ``attention_mask`` is ignored. The ``flash_attention_2`` backend
1150
+ only supports single-chain inputs (all non-padding values must be ``0``); pass
1151
+ multi-chain ``sequence_id`` with ``attn_implementation='sdpa'`` (or ``'eager'``).
1152
+ output_attentions (`bool`, *optional*):
1153
+ Whether to return the per-block attention weights of shape
1154
+ ``(batch_size, num_heads, sequence_length, sequence_length)``.
1155
+ Forces a manual-SDPA path inside :class:`MultiHeadAttention` so the
1156
+ attention probabilities are observable; raises on the
1157
+ ``flash_attention_2`` path.
1158
+ compute_sae (`bool`, *optional*, defaults to ``True``):
1159
+ Whether to run any SAE models registered via :meth:`add_sae_models`.
1160
+ Has no effect when no SAEs are registered.
1161
+ normalize_sae (`bool`, *optional*, defaults to ``False``):
1162
+ When ``True``, rescale SAE feature magnitudes as ``features / max * idf``. This is a
1163
+ no-op for SAEs whose normalization buffers are left at their default of ones.
1164
+
1165
+ Examples:
1166
+
1167
+ ```python
1168
+ >>> from transformers import AutoTokenizer, ESMCModel
1169
+
1170
+ >>> model = ESMCModel.from_pretrained("Biohub/ESMC-600M-2024-12")
1171
+ >>> tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-600M-2024-12")
1172
+ >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt")
1173
+ >>> outputs = model(**inputs)
1174
+ >>> outputs.last_hidden_state.shape
1175
+ torch.Size([1, 12, 1152])
1176
+ ```
1177
+ """
1178
+ output_hidden_states = (
1179
+ output_hidden_states
1180
+ if output_hidden_states is not None
1181
+ else self.config.output_hidden_states
1182
+ )
1183
+ output_attentions = (
1184
+ output_attentions
1185
+ if output_attentions is not None
1186
+ else self.config.output_attentions
1187
+ )
1188
+ return_dict = (
1189
+ return_dict if return_dict is not None else self.config.use_return_dict
1190
+ )
1191
+
1192
+ output_sae = compute_sae and len(self._sae_models) > 0
1193
+
1194
+ # Determine which intermediate layers to collect. When SAEs are
1195
+ # registered we must collect at least the layers they target, even if
1196
+ # the caller did not ask for all hidden states.
1197
+ if output_hidden_states:
1198
+ layers_to_collect: list[int] = list(range(self.config.n_layers + 1))
1199
+ elif output_sae:
1200
+ layers_to_collect = sorted(
1201
+ {self._get_sae_layer_num_requested(name) for name in self._sae_models}
1202
+ )
1203
+ else:
1204
+ layers_to_collect = []
1205
+
1206
+ user_supplied_sequence_id = sequence_id is not None
1207
+ if sequence_id is not None:
1208
+ bool_mask = sequence_id >= 0
1209
+ else:
1210
+ if attention_mask is None:
1211
+ attention_mask = input_ids != self.config.pad_token_id
1212
+ assert attention_mask is not None
1213
+ bool_mask = attention_mask.bool()
1214
+ sequence_id = bool_mask.to(torch.long) - 1
1215
+
1216
+ x = self.embed(input_ids)
1217
+ b, l_ = x.shape[:2]
1218
+
1219
+ if self._use_flash_attn:
1220
+ if user_supplied_sequence_id and (sequence_id > 0).any():
1221
+ raise ValueError(
1222
+ "Multi-chain ``sequence_id`` (any value > 0) is not "
1223
+ "supported with attn_implementation='flash_attention_2'. "
1224
+ "Re-load the model with attn_implementation='sdpa' (or "
1225
+ "'eager') for chain-aware attention masking."
1226
+ )
1227
+ assert unpad_input is not None
1228
+ x, indices, *_ = unpad_input(x, bool_mask)
1229
+ else:
1230
+ indices = None
1231
+
1232
+ if self._use_flash_attn:
1233
+ trans_seq_id = bool_mask
1234
+ elif user_supplied_sequence_id:
1235
+ trans_seq_id = sequence_id
1236
+ elif bool_mask.all() and not output_attentions:
1237
+ # Fused SDPA fast path (xformers / flash) is correct only when the
1238
+ # mask is uniform; output_attentions forces the manual branch.
1239
+ trans_seq_id = None
1240
+ else:
1241
+ trans_seq_id = sequence_id
1242
+ last_hidden_state, _, collected, attentions = self.transformer(
1243
+ x,
1244
+ sequence_id=trans_seq_id,
1245
+ layers_to_collect=layers_to_collect,
1246
+ output_attentions=output_attentions,
1247
+ )
1248
+
1249
+ if self._use_flash_attn:
1250
+ assert indices is not None and pad_input is not None
1251
+ last_hidden_state = pad_input(last_hidden_state, indices, b, l_)
1252
+ collected = [pad_input(h, indices, b, l_) for h in collected]
1253
+
1254
+ # Stack once; reused for both SAE and hidden-state output.
1255
+ collected_tensor: torch.Tensor | None = (
1256
+ torch.stack(collected, dim=0) if collected else None # type: ignore[arg-type]
1257
+ )
1258
+
1259
+ sae_outputs: dict[str, torch.Tensor] | None = None
1260
+ if output_sae and collected_tensor is not None:
1261
+ assert input_ids is not None
1262
+ self._validate_sae_inputs(input_ids)
1263
+ sae_outputs = self._get_sae_outputs(
1264
+ collected_tensor, layers_to_collect, bool_mask, normalize_sae
1265
+ )
1266
+
1267
+ hidden_states_tensor = collected_tensor if output_hidden_states else None
1268
+
1269
+ if not return_dict:
1270
+ return tuple(
1271
+ v
1272
+ for v in [
1273
+ last_hidden_state,
1274
+ hidden_states_tensor,
1275
+ sae_outputs,
1276
+ attentions,
1277
+ ]
1278
+ if v is not None
1279
+ )
1280
+
1281
+ return ESMCOutput(
1282
+ last_hidden_state=last_hidden_state,
1283
+ hidden_states=hidden_states_tensor,
1284
+ sae_outputs=sae_outputs,
1285
+ attentions=attentions,
1286
+ )
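The `sequence_id` encoding handled by `ESMCModel.forward` can be illustrated without PyTorch. What follows is a minimal plain-Python sketch, not the model's implementation: the helper names `sequence_id_from_mask` and `can_attend` are hypothetical, and the rule in `can_attend` only paraphrases the docstring (two positions attend to each other when they share the same non-negative chain ID). How padding behaves as a query position is an assumption here.

```python
def sequence_id_from_mask(attention_mask):
    """Mirror ``bool_mask.to(torch.long) - 1``: real tokens -> 0, padding -> -1."""
    return [[int(bool(m)) - 1 for m in row] for row in attention_mask]


def can_attend(seq_id_row, i, j):
    """Docstring rule (paraphrased): same non-negative chain ID on both positions."""
    return seq_id_row[i] >= 0 and seq_id_row[i] == seq_id_row[j]


# One sequence of 3 real tokens followed by 2 padding positions.
single_chain = sequence_id_from_mask([[1, 1, 1, 0, 0]])[0]

# A hand-written two-chain input: chain 0 (2 tokens), chain 1 (2 tokens), 1 pad.
two_chains = [0, 0, 1, 1, -1]

# The flash_attention_2 path rejects user-supplied IDs containing any value > 0.
is_multi_chain = any(v > 0 for v in two_chains)
```

In this sketch a plain attention mask always yields a single chain, which is why only a user-supplied `sequence_id` can trigger the multi-chain error on the `flash_attention_2` path.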
1287
+
1288
+
1289
+ # ---------------------------------------------------------------------------
1290
+ # LM head
1291
+ # ---------------------------------------------------------------------------
1292
+
1293
+
1294
+ def _esmc_lm_head(
1295
+ d_model: int, output_dim: int, hidden_dim: int | None = None
1296
+ ) -> nn.Sequential:
1297
+ """Linear → GELU → LayerNorm → Linear projection head for masked LM."""
1298
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
1299
+ return nn.Sequential(
1300
+ nn.Linear(d_model, hidden_dim),
1301
+ nn.GELU(),
1302
+ nn.LayerNorm(hidden_dim),
1303
+ nn.Linear(hidden_dim, output_dim),
1304
+ )
1305
+
1306
+
1307
+ # ---------------------------------------------------------------------------
1308
+ # Masked language model
1309
+ # ---------------------------------------------------------------------------
1310
+
1311
+
1312
+ @auto_docstring
1313
+ class ESMCForMaskedLM(ESMCPreTrainedModel):
1314
+ """ESMC with a masked language modelling head.
1315
+
1316
+ This is the primary pre-training objective of ESMC. The LM head consists
1317
+ of a single hidden layer with GELU activation followed by LayerNorm and a
1318
+ linear projection to ``vocab_size``.
1319
+ """
1320
+
1321
+ def __init__(self, config: ESMCConfig):
1322
+ super().__init__(config)
1323
+ self.esmc = ESMCModel(config)
1324
+ self.lm_head = _esmc_lm_head(config.d_model, config.vocab_size)
1325
+ self.post_init()
1326
+
1327
+ def get_output_embeddings(self) -> nn.Linear:
1328
+ return self.lm_head[-1] # type: ignore[return-value]
1329
+
1330
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1331
+ self.lm_head[-1] = new_embeddings
1332
+
1333
+ def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
1334
+ """Proxy to :meth:`ESMCModel.add_sae_models`."""
1335
+ self.esmc.add_sae_models(sae_models)
1336
+
1337
+ @can_return_tuple
1338
+ @auto_docstring
1339
+ def forward(
1340
+ self,
1341
+ input_ids: Optional[torch.Tensor] = None,
1342
+ attention_mask: Optional[torch.Tensor] = None,
1343
+ sequence_id: Optional[torch.Tensor] = None,
1344
+ output_hidden_states: Optional[bool] = None,
1345
+ output_attentions: Optional[bool] = None,
1346
+ return_dict: Optional[bool] = None,
1347
+ labels: Optional[torch.Tensor] = None,
1348
+ compute_sae: bool = True,
1349
+ normalize_sae: bool = False,
1350
+ ) -> tuple[torch.Tensor, ...] | ESMCMaskedLMOutput:
1351
+ r"""
1352
+ sequence_id (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1353
+ Integer chain-ID tensor forwarded to the encoder for chain-aware
1354
+ attention masking. See :meth:`ESMCModel.forward` for the encoding.
1355
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1356
+ Labels for masked language modelling loss. Positions with label ``-100``
1357
+ are ignored. Other positions must be in ``[0, config.vocab_size)``.
1358
+ output_attentions (`bool`, *optional*):
1359
+ Whether to return per-block attention weights. Forwarded to the
1360
+ backbone; raises on the ``flash_attention_2`` path.
1361
+ compute_sae (`bool`, *optional*, defaults to ``True``):
1362
+ Whether to run registered SAE models. Has no effect when none are registered.
1363
+ normalize_sae (`bool`, *optional*, defaults to ``False``):
1364
+ When ``True``, scale SAE features by ``idf / max`` normalization buffers.
1365
+
1366
+ Examples:
1367
+
1368
+ ```python
1369
+ >>> from transformers import AutoTokenizer, ESMCForMaskedLM
1370
+ >>> import torch
1371
+
1372
+ >>> model = ESMCForMaskedLM.from_pretrained("Biohub/ESMC-600M-2024-12")
1373
+ >>> tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-600M-2024-12")
1374
+ >>> inputs = tokenizer(["MLKNVQ<mask>LV"], return_tensors="pt")
1375
+ >>> outputs = model(**inputs)
1376
+ >>> outputs.logits.shape
1377
+ torch.Size([1, 11, 64])
1378
+ ```
1379
+ """
1380
+ return_dict = (
1381
+ return_dict if return_dict is not None else self.config.use_return_dict
1382
+ )
1383
+
1384
+ encoder_outputs = self.esmc(
1385
+ input_ids=input_ids,
1386
+ attention_mask=attention_mask,
1387
+ sequence_id=sequence_id,
1388
+ output_hidden_states=output_hidden_states,
1389
+ output_attentions=output_attentions,
1390
+ return_dict=True,
1391
+ compute_sae=compute_sae,
1392
+ normalize_sae=normalize_sae,
1393
+ )
1394
+
1395
+ logits = self.lm_head(encoder_outputs.last_hidden_state)
1396
+
1397
+ loss: torch.Tensor | None = None
1398
+ if labels is not None:
1399
+ loss = CrossEntropyLoss(ignore_index=-100)(
1400
+ logits.view(-1, self.config.vocab_size), labels.view(-1)
1401
+ )
1402
+
1403
+ if not return_dict:
1404
+ return tuple(
1405
+ v
1406
+ for v in [
1407
+ loss,
1408
+ logits,
1409
+ encoder_outputs.last_hidden_state,
1410
+ encoder_outputs.hidden_states,
1411
+ encoder_outputs.sae_outputs,
1412
+ encoder_outputs.attentions,
1413
+ ]
1414
+ if v is not None
1415
+ )
1416
+
1417
+ return ESMCMaskedLMOutput(
1418
+ loss=loss,
1419
+ logits=logits,
1420
+ last_hidden_state=encoder_outputs.last_hidden_state,
1421
+ hidden_states=encoder_outputs.hidden_states,
1422
+ sae_outputs=encoder_outputs.sae_outputs,
1423
+ attentions=encoder_outputs.attentions,
1424
+ )
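The masked-LM loss above relies on `CrossEntropyLoss(ignore_index=-100)` with its default mean reduction, so only labelled positions contribute. The sketch below reproduces that behaviour in plain Python for flattened logits; `masked_lm_loss` is a hypothetical helper written for illustration, and the toy vocabulary of three tokens is an assumption.

```python
import math


def masked_lm_loss(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label is not ``ignore_index``."""
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unlabelled position: contributes nothing to the loss
        # Numerically stable log-sum-exp over the vocabulary dimension.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        losses.append(log_norm - row[label])
    return sum(losses) / len(losses)


# Three flattened positions over a toy 3-token vocabulary; only the middle
# position carries a label, so the other two rows do not affect the result.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 5.0, 0.0]]
labels = [-100, 1, -100]
loss = masked_lm_loss(logits, labels)
```

The labelled row has uniform logits, so the loss equals `log(3)` regardless of what the ignored rows contain.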
1425
+
1426
+
1427
+ # ---------------------------------------------------------------------------
1428
+ # Classification heads
1429
+ # ---------------------------------------------------------------------------
1430
+
1431
+
1432
+ class _ESMCClassificationHead(nn.Module):
1433
+ """Dense classification head applied to the ``<cls>`` token representation."""
1434
+
1435
+ def __init__(self, config: ESMCConfig):
1436
+ super().__init__()
1437
+ self.dense = nn.Linear(config.d_model, config.d_model)
1438
+ self.dropout = nn.Dropout(config.classifier_dropout)
1439
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
1440
+
1441
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1442
+ x = hidden_states[:, 0, :] # <cls> token
1443
+ x = self.dropout(x)
1444
+ x = torch.tanh(self.dense(x))
1445
+ x = self.dropout(x)
1446
+ return self.out_proj(x)
1447
+
1448
+
1449
+ # ---------------------------------------------------------------------------
1450
+ # Sequence classification
1451
+ # ---------------------------------------------------------------------------
1452
+
1453
+
1454
+ @auto_docstring
1455
+ class ESMCForSequenceClassification(ESMCPreTrainedModel):
1456
+ """ESMC with a sequence-level classification head.
1457
+
1458
+ A dense + tanh head with dropout is applied to the ``<cls>`` token representation.
1459
+ Supports regression (``num_labels == 1``), single-label classification,
1460
+ and multi-label classification.
1461
+ """
1462
+
1463
+ def __init__(self, config: ESMCConfig):
1464
+ super().__init__(config)
1465
+ self.num_labels = config.num_labels
1466
+ self.esmc = ESMCModel(config)
1467
+ self.classifier = _ESMCClassificationHead(config)
1468
+ self.post_init()
1469
+
1470
+ def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
1471
+ """Proxy to :meth:`ESMCModel.add_sae_models`."""
1472
+ self.esmc.add_sae_models(sae_models)
1473
+
1474
+ @can_return_tuple
1475
+ @auto_docstring
1476
+ def forward(
1477
+ self,
1478
+ input_ids: Optional[torch.LongTensor] = None,
1479
+ attention_mask: Optional[torch.Tensor] = None,
1480
+ output_hidden_states: Optional[bool] = None,
1481
+ output_attentions: Optional[bool] = None,
1482
+ return_dict: Optional[bool] = None,
1483
+ labels: Optional[torch.Tensor] = None,
1484
+ compute_sae: bool = True,
1485
+ normalize_sae: bool = False,
1486
+ ) -> tuple[torch.Tensor, ...] | ESMCSequenceClassifierOutput:
1487
+ r"""
1488
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1489
+ Labels for sequence classification loss. Indices must be in
1490
+ ``[0, config.num_labels - 1]``. For regression pass a float
1491
+ tensor of shape ``(batch_size,)``.
1492
+ output_attentions (`bool`, *optional*):
1493
+ Whether to return per-block attention weights. Forwarded to the
1494
+ backbone; raises on the ``flash_attention_2`` path.
1495
+ compute_sae (`bool`, *optional*, defaults to ``True``):
1496
+ Whether to run registered SAE models. Has no effect when none are registered.
1497
+ normalize_sae (`bool`, *optional*, defaults to ``False``):
1498
+ When ``True``, scale SAE features by ``idf / max`` normalization buffers.
1499
+ """
1500
+ return_dict = (
1501
+ return_dict if return_dict is not None else self.config.use_return_dict
1502
+ )
1503
+
1504
+ encoder_outputs = self.esmc(
1505
+ input_ids,
1506
+ attention_mask=attention_mask,
1507
+ output_hidden_states=output_hidden_states,
1508
+ output_attentions=output_attentions,
1509
+ return_dict=True,
1510
+ compute_sae=compute_sae,
1511
+ normalize_sae=normalize_sae,
1512
+ )
1513
+ logits = self.classifier(encoder_outputs.last_hidden_state)
1514
+
1515
+ loss: torch.Tensor | None = None
1516
+ if labels is not None:
1517
+ labels = labels.to(logits.device)
1518
+
1519
+ if self.config.problem_type is None:
1520
+ if self.num_labels == 1:
1521
+ self.config.problem_type = "regression"
1522
+ elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
1523
+ self.config.problem_type = "single_label_classification"
1524
+ else:
1525
+ self.config.problem_type = "multi_label_classification"
1526
+
1527
+ if self.config.problem_type == "regression":
1528
+ loss_fct = MSELoss()
1529
+ loss = loss_fct(
1530
+ logits.squeeze() if self.num_labels == 1 else logits,
1531
+ labels.squeeze() if self.num_labels == 1 else labels,
1532
+ )
1533
+ elif self.config.problem_type == "single_label_classification":
1534
+ loss = CrossEntropyLoss()(
1535
+ logits.view(-1, self.num_labels), labels.view(-1)
1536
+ )
1537
+ elif self.config.problem_type == "multi_label_classification":
1538
+ loss = BCEWithLogitsLoss()(logits, labels)
1539
+
1540
+ if not return_dict:
1541
+ return tuple(
1542
+ v
1543
+ for v in [
1544
+ loss,
1545
+ logits,
1546
+ encoder_outputs.last_hidden_state,
1547
+ encoder_outputs.hidden_states,
1548
+ encoder_outputs.sae_outputs,
1549
+ encoder_outputs.attentions,
1550
+ ]
1551
+ if v is not None
1552
+ )
1553
+
1554
+ return ESMCSequenceClassifierOutput(
1555
+ loss=loss,
1556
+ logits=logits,
1557
+ last_hidden_state=encoder_outputs.last_hidden_state,
1558
+ hidden_states=encoder_outputs.hidden_states,
1559
+ sae_outputs=encoder_outputs.sae_outputs,
1560
+ attentions=encoder_outputs.attentions,
1561
+ )
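The fallback that fills in `config.problem_type` when it is unset can be summarised as a small decision function. This is a sketch of the branching shown above, with plain strings standing in for torch dtypes; `infer_problem_type` is a hypothetical name.

```python
def infer_problem_type(num_labels, label_dtype):
    """Replicate the fallback used when ``config.problem_type`` is None.

    ``label_dtype`` is a plain string standing in for a torch dtype.
    """
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and label_dtype in ("long", "int"):
        return "single_label_classification"
    return "multi_label_classification"


cases = {
    "scalar target": infer_problem_type(1, "float"),
    "class indices": infer_problem_type(5, "long"),
    "multi-hot floats": infer_problem_type(5, "float"),
}
```

Because the inferred value is written back to `self.config.problem_type`, the first labelled batch decides which loss later calls use. Setting `problem_type` explicitly in the config avoids depending on the dtype of that first batch.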
1562
+
1563
+
1564
+ # ---------------------------------------------------------------------------
1565
+ # Token classification
1566
+ # ---------------------------------------------------------------------------
1567
+
1568
+
1569
+ @auto_docstring
1570
+ class ESMCForTokenClassification(ESMCPreTrainedModel):
1571
+ """ESMC with a per-token classification head.
1572
+
1573
+ Useful for per-residue tasks such as secondary structure prediction. Pairwise
1574
+ targets such as contact maps need a head over residue pairs instead.
1575
+ """
1576
+
1577
+ def __init__(self, config: ESMCConfig):
1578
+ super().__init__(config)
1579
+ self.num_labels = config.num_labels
1580
+ self.esmc = ESMCModel(config)
1581
+ self.dropout = nn.Dropout(config.classifier_dropout)
1582
+ self.classifier = nn.Linear(config.d_model, config.num_labels)
1583
+ self.post_init()
1584
+
1585
+ def add_sae_models(self, sae_models: list[_ESMCSAELayer]) -> None:
1586
+ """Proxy to :meth:`ESMCModel.add_sae_models`."""
1587
+ self.esmc.add_sae_models(sae_models)
1588
+
1589
+ @can_return_tuple
1590
+ @auto_docstring
1591
+ def forward(
1592
+ self,
1593
+ input_ids: Optional[torch.Tensor] = None,
1594
+ attention_mask: Optional[torch.Tensor] = None,
1595
+ output_hidden_states: Optional[bool] = None,
1596
+ output_attentions: Optional[bool] = None,
1597
+ return_dict: Optional[bool] = None,
1598
+ labels: Optional[torch.Tensor] = None,
1599
+ compute_sae: bool = True,
1600
+ normalize_sae: bool = False,
1601
+ ) -> tuple[torch.Tensor, ...] | ESMCTokenClassifierOutput:
1602
+ r"""
1603
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1604
+ Per-token labels. Indices must be in ``[0, config.num_labels - 1]``.
1605
+ Positions with index ``-100`` are ignored in the loss.
1606
+ output_attentions (`bool`, *optional*):
1607
+ Whether to return per-block attention weights. Forwarded to the
1608
+ backbone; raises on the ``flash_attention_2`` path.
1609
+ compute_sae (`bool`, *optional*, defaults to ``True``):
1610
+ Whether to run registered SAE models. Has no effect when none are registered.
1611
+ normalize_sae (`bool`, *optional*, defaults to ``False``):
1612
+ When ``True``, scale SAE features by ``idf / max`` normalization buffers.
1613
+ """
1614
+ return_dict = (
1615
+ return_dict if return_dict is not None else self.config.use_return_dict
1616
+ )
1617
+
1618
+ encoder_outputs = self.esmc(
1619
+ input_ids=input_ids,
1620
+ attention_mask=attention_mask,
1621
+ output_hidden_states=output_hidden_states,
1622
+ output_attentions=output_attentions,
1623
+ return_dict=True,
1624
+ compute_sae=compute_sae,
1625
+ normalize_sae=normalize_sae,
1626
+ )
1627
+
1628
+ sequence_output = self.dropout(encoder_outputs.last_hidden_state)
1629
+ logits = self.classifier(sequence_output)
1630
+
1631
+ loss: torch.Tensor | None = None
1632
+ if labels is not None:
1633
+ loss = CrossEntropyLoss(ignore_index=-100)(
1634
+ logits.view(-1, self.num_labels), labels.to(logits.device).view(-1)
1635
+ )
1636
+
1637
+ if not return_dict:
1638
+ return tuple(
1639
+ v
1640
+ for v in [
1641
+ loss,
1642
+ logits,
1643
+ encoder_outputs.last_hidden_state,
1644
+ encoder_outputs.hidden_states,
1645
+ encoder_outputs.sae_outputs,
1646
+ encoder_outputs.attentions,
1647
+ ]
1648
+ if v is not None
1649
+ )
1650
+
1651
+ return ESMCTokenClassifierOutput(
1652
+ loss=loss,
1653
+ logits=logits,
1654
+ last_hidden_state=encoder_outputs.last_hidden_state,
1655
+ hidden_states=encoder_outputs.hidden_states,
1656
+ sae_outputs=encoder_outputs.sae_outputs,
1657
+ attentions=encoder_outputs.attentions,
1658
+ )
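All task heads share the same `return_dict=False` branch, which drops every `None` field before building the tuple. The sketch below, using strings in place of tensors and a hypothetical helper `as_tuple`, shows the practical consequence: positions in the tuple shift depending on which optional outputs are present.

```python
def as_tuple(loss, logits, last_hidden_state, hidden_states, sae_outputs, attentions):
    """Mirror the ``return_dict=False`` branch: keep non-None fields, in order."""
    fields = [loss, logits, last_hidden_state, hidden_states, sae_outputs, attentions]
    return tuple(v for v in fields if v is not None)


# Strings stand in for tensors.
with_labels = as_tuple("loss", "logits", "last", None, None, None)
without_labels = as_tuple(None, "logits", "last", None, None, None)
with_attn = as_tuple(None, "logits", "last", None, None, "attn")
```

Here `logits` sits at index 1 when labels are passed and at index 0 otherwise, so code that indexes the tuple positionally has to know which outputs were requested. The dataclass-style outputs returned with `return_dict=True` avoid that ambiguity.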
1659
+
1660
+
1661
+ __all__ = [
1662
+ "ESMCModel",
1663
+ "ESMCForMaskedLM",
1664
+ "ESMCForSequenceClassification",
1665
+ "ESMCForTokenClassification",
1666
+ "ESMCPreTrainedModel",
1667
+ ]
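The SAE bookkeeping in `ESMCModel` comes down to two small pieces of logic: the `"layer{N}"` key round trip and the choice of which hidden states to collect. The sketch below restates both in plain Python. The function names are hypothetical, and `n_layers=36` is only an illustrative value, not a claim about any checkpoint.

```python
import re

SAE_KEY_RE = re.compile(r"layer(\d+)")


def sae_key(layer_index):
    """Key format used when an SAE is registered."""
    return f"layer{int(layer_index)}"


def layer_from_key(key):
    """Inverse of ``sae_key``; rejects anything that is not exactly 'layer<N>'."""
    match = SAE_KEY_RE.fullmatch(key)
    if match is None:
        raise ValueError(f"Unexpected SAE key {key!r}")
    return int(match.group(1))


def layers_to_collect(n_layers, registered_keys, output_hidden_states, compute_sae):
    """Pick which hidden states the transformer stack has to keep."""
    if output_hidden_states:
        # n_layers + 1 entries, matching ``range(self.config.n_layers + 1)``.
        return list(range(n_layers + 1))
    if compute_sae and registered_keys:
        return sorted({layer_from_key(k) for k in registered_keys})
    return []


keys = [sae_key(33), sae_key(27)]
only_sae = layers_to_collect(36, keys, output_hidden_states=False, compute_sae=True)
everything = layers_to_collect(36, keys, output_hidden_states=True, compute_sae=True)
nothing = layers_to_collect(36, keys, output_hidden_states=False, compute_sae=False)
```

Collecting only the SAE-targeted layers keeps memory use proportional to the number of registered SAEs when the caller has not asked for all hidden states.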
modeling_esmc_sae.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright 2026 Biohub. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch ESMC SAE (Sparse Autoencoder) model.
15
+
16
+ * :class:`ESMCSAEModel` — the published HF container, one repo per
17
+ ``(backbone, codebook_dim, k)`` group. Each backbone layer ships as a
18
+ ``layer_{i}.safetensors`` shard; ``from_pretrained`` downloads the whole
19
+ snapshot but loads no weights — callers materialize the layers they need
20
+ via :meth:`initialize_layers`. Single-layer repos auto-load so bare
21
+ ``forward(x)`` works.
22
+ * :class:`_ESMCSAELayer` — internal ``nn.Module`` that holds the weights for
23
+ one ``(backbone, codebook_dim, k, layer)`` SAE. Not a published HF artifact;
24
+ obtained only via ``model.layers["<idx>"]``.
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import os
30
+ from dataclasses import dataclass
31
+ from pathlib import Path
32
+ from typing import Optional
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from safetensors.torch import load_file, save_file
38
+
39
+ from transformers.modeling_outputs import ModelOutput
40
+ from transformers.modeling_utils import PreTrainedModel
41
+ from transformers.utils import auto_docstring
42
+ from .configuration_esmc_sae import ESMCSAEConfig, ESMCSAEParams
43
+
44
+
45
+ @dataclass
46
+ @auto_docstring(
47
+ custom_intro="""
48
+ Output type of [`ESMCSAEModel`].
49
+ """
50
+ )
51
+ class ESMCSAEOutput(ModelOutput):
52
+ feature_magnitudes: torch.Tensor
53
+ reconstruction_loss: Optional[torch.Tensor] = None
54
+
55
+ def to_sparse(self) -> None:
56
+ self.feature_magnitudes = self.feature_magnitudes.to_sparse()
57
+
58
+
59
+ class _ESMCSAELayer(nn.Module):
60
+ """One backbone layer's SAE — internal building block of :class:`ESMCSAEModel`.
61
+
62
+ Not exposed via ``AutoModel`` and not loadable on its own. Obtain one
63
+ via ``model.layers["<layer_idx>"]`` after calling ``initialize_layers``.
64
+ """
65
+
66
+ def __init__(self, params: ESMCSAEParams):
67
+ super().__init__()
68
+ self.params = params
69
+
70
+ self.W_enc = nn.Parameter(torch.empty(params.d_model, params.codebook_dim))
71
+ self.W_dec = nn.Parameter(torch.empty(params.codebook_dim, params.d_model))
72
+ self.b_dec = nn.Parameter(torch.zeros(params.d_model))
73
+ # Per-feature normalization stats. Trained alongside the SAE for some
74
+ # variants; for variants that don't ship them, leaving these as ones
75
+ # makes ``_get_sae_outputs``'s ``features / max * idf`` a no-op.
76
+ self.register_buffer("idf", torch.ones(params.codebook_dim))
77
+ self.register_buffer("max", torch.ones(params.codebook_dim))
78
+
79
+ @property
80
+ def layer(self) -> int:
81
+ """Backbone-layer index this SAE is trained against."""
82
+ return self.params.layer
83
+
84
+ def forward(self, x: torch.Tensor, **_kwargs: object) -> ESMCSAEOutput:
85
+ del _kwargs
86
+ x = self._zscore_normalize_representation(x)
87
+
88
+ x_with_pre_encoder_bias = x - self.b_dec
89
+ preactivations = F.relu(x_with_pre_encoder_bias @ self.W_enc)
90
+
91
+ topk = torch.topk(preactivations, self.params.k, dim=-1)
92
+ feature_magnitudes = torch.zeros_like(preactivations).scatter(
93
+ -1, topk.indices, topk.values
94
+ )
95
+
96
+ reconstructed = feature_magnitudes @ self.W_dec + self.b_dec
97
+
98
+ reconstruction_loss = (reconstructed - x).pow(2).mean(dim=-1)
99
+
100
+ return ESMCSAEOutput(
101
+ feature_magnitudes=feature_magnitudes,
102
+ reconstruction_loss=reconstruction_loss,
103
+ )
104
+
105
+ def get_sae_output(
106
+ self, layer_states: torch.Tensor, token_mask: torch.Tensor
107
+ ) -> ESMCSAEOutput:
108
+ _, _, v_len = layer_states.shape
109
+ nonpad_states = layer_states[token_mask].view(-1, v_len)
110
+ return self(nonpad_states)
111
+
112
+ def _zscore_normalize_representation(self, x: torch.Tensor) -> torch.Tensor:
113
+ x_mean = x.mean(dim=-1, keepdim=True)
114
+ x = x - x_mean
115
+ x_std = x.std(dim=-1, keepdim=True)
116
+ return x / (x_std + 1e-5)
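The forward pass of `_ESMCSAELayer` follows a top-k sparse autoencoder recipe: z-score the input, subtract the decoder bias, encode with ReLU, keep the `k` largest activations, decode, and score the reconstruction against the normalized input. The following plain-Python sketch walks through those steps for a single vector. The helper names and the toy weights are assumptions made for illustration; the real layer operates on batched tensors.

```python
import statistics


def zscore(x, eps=1e-5):
    """Centre and scale one vector; sample std (n - 1), like ``torch.Tensor.std``."""
    mean = sum(x) / len(x)
    centred = [v - mean for v in x]
    std = statistics.stdev(centred)
    return [v / (std + eps) for v in centred]


def topk_sae(x, w_enc, w_dec, b_dec, k):
    """Encode one vector, keep the k largest activations, decode, score."""
    x = zscore(x)
    shifted = [xi - bi for xi, bi in zip(x, b_dec)]
    # ReLU(shifted @ W_enc); w_enc has shape (d_model, codebook_dim).
    pre = [
        max(0.0, sum(shifted[i] * w_enc[i][j] for i in range(len(shifted))))
        for j in range(len(w_enc[0]))
    ]
    # Indices of the k largest activations; everything else is zeroed.
    keep = sorted(range(len(pre)), key=lambda j: pre[j], reverse=True)[:k]
    features = [pre[j] if j in keep else 0.0 for j in range(len(pre))]
    # features @ W_dec + b_dec; w_dec has shape (codebook_dim, d_model).
    recon = [
        sum(features[j] * w_dec[j][i] for j in range(len(features))) + b_dec[i]
        for i in range(len(b_dec))
    ]
    # Mean squared error against the normalized input, as in the layer above.
    loss = sum((r - xi) ** 2 for r, xi in zip(recon, x)) / len(x)
    return features, loss


# Toy weights: d_model = 3, codebook_dim = 4, k = 2 (all values illustrative).
w_enc = [[1.0, 0.0, -1.0, 0.5], [0.0, 1.0, 0.0, 0.5], [-1.0, 0.0, 1.0, 0.5]]
w_dec = [[0.5, 0.0, -0.5], [0.0, 1.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]]
features, loss = topk_sae([3.0, 1.0, -1.0], w_enc, w_dec, [0.0, 0.0, 0.0], k=2)
```

With these hand-picked weights a single feature reconstructs the normalized input almost exactly, so at most `k` entries of `features` are non-zero and the loss is close to zero.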
117
+
118
+
119
+ @auto_docstring
120
+ class ESMCSAEPreTrainedModel(PreTrainedModel):
121
+ config_class = ESMCSAEConfig
122
+ base_model_prefix = "esmc_sae"
123
+
124
+
125
+ @auto_docstring(
126
+ custom_intro="""
127
+ HF container holding one SAE per backbone layer, all sharing the same
128
+ ``(d_model, codebook_dim, k)``.
129
+
130
+ ``from_pretrained`` downloads the entire repo (every ``layer_{i}.safetensors``)
131
+ into the local HF cache but does **not** load any weights into memory.
132
+ Callers materialize the layers they actually need by calling
133
+ :meth:`initialize_layers`. The full set is available on disk after the
134
+ first call, so subsequent layer switches read from the local cache without
135
+ re-downloading.
136
+
137
+ Examples::
138
+
139
+ model = ESMCSAEModel.from_pretrained(
140
+ "biohub/esmc-6b-2024-12-sae-k64-codebook16384"
141
+ )
142
+ model.initialize_layers([60]) # ~2.5 GB into memory
143
+ out = model(layer_states, layer=60) # forward through layer 60
144
+ model.initialize_layers([45]) # add layer 45 (cached locally)
145
+ model.release_layer(60) # free layer 60
146
+ """
147
+ )
148
+ class ESMCSAEModel(ESMCSAEPreTrainedModel):
149
+ def __init__(self, config: ESMCSAEConfig):
150
+ super().__init__(config)
151
+ # Layers are populated lazily by ``initialize_layers``; the container
152
+ # starts empty so ``from_pretrained`` doesn't materialize hundreds of
153
+ # GB of unused parameters.
154
+ self.layers = nn.ModuleDict()
155
+ # Zero-element buffer that rides along with ``.to(device/dtype)``.
156
+ # ``initialize_layers`` reads its current device/dtype so SAEs added
157
+ # after ``model.to("cuda")`` land on CUDA without re-passing ``device=``.
158
+ self.register_buffer("_device_marker", torch.empty(0), persistent=False)
159
+ self._snapshot_dir: Optional[str] = None
160
+ self.post_init()
161
+
162
+ @classmethod
163
+ def from_pretrained( # type: ignore[override]
164
+ cls, pretrained_model_name_or_path: str | os.PathLike, *model_args, **kwargs
165
+ ) -> "ESMCSAEModel":
166
+ """Download (or reuse cached) the full repo and return the model.
167
+
168
+ By default no weights are read into memory and the caller must invoke
169
+ :meth:`initialize_layers` before running :meth:`forward`. The single
170
+ exception is when the repo ships exactly one layer: that layer is
171
+ auto-loaded (honoring ``torch_dtype`` / ``device`` if passed) so the
172
+ bare ``forward(x)`` call just works.
173
+
174
+ Honored kwargs: ``revision``, ``cache_dir``, ``token``,
175
+ ``allow_patterns``, ``local_files_only``, ``force_download`` (forwarded
176
+ to ``snapshot_download``); ``torch_dtype`` and ``device`` (used by the
177
+ single-layer auto-load path; otherwise pass them to
178
+ :meth:`initialize_layers`). Kwargs implying work we do not perform
179
+ (``device_map``, ``low_cpu_mem_usage``, ``quantization_config``,
180
+ ``attn_implementation``, ``max_memory``, ``offload_folder``,
181
+ ``offload_state_dict``) raise ``TypeError``. Other HF housekeeping kwargs
182
+ (``config``, ``trust_remote_code``, ``adapter_kwargs``, …) are accepted and
183
+ ignored; they only matter for the standard loader, which we bypass.
184
+ """
185
+ del model_args
186
+ torch_dtype = kwargs.pop("torch_dtype", None)
187
+ device = kwargs.pop("device", None)
188
+ local_dir = _resolve_snapshot_dir(pretrained_model_name_or_path, kwargs)
189
+ unsupported = {
190
+ "device_map",
191
+ "low_cpu_mem_usage",
192
+ "quantization_config",
193
+ "attn_implementation",
194
+ "max_memory",
195
+ "offload_folder",
196
+ "offload_state_dict",
197
+ } & kwargs.keys()
198
+ if unsupported:
199
+ raise TypeError(
200
+ f"Unsupported kwargs to ESMCSAEModel.from_pretrained: "
201
+ f"{sorted(unsupported)}. The standard HF loader is bypassed —"
202
+ " call initialize_layers(..., device=, dtype=) instead."
203
+ )
204
+ config = ESMCSAEConfig.from_pretrained(local_dir)
205
+ model = cls(config)
206
+ model._snapshot_dir = str(local_dir)
207
+ if device is not None:
208
+ model.to(device)
209
+ if torch_dtype is not None:
210
+ model.to(torch_dtype)
211
+ if len(config.available_layers) == 1:
212
+ model.initialize_layers(list(config.available_layers))
213
+ return model
214
+
215
+ def initialize_layers(
216
+ self,
217
+ layers: list[int],
218
+ *,
219
+ device: torch.device | str | None = None,
220
+ dtype: torch.dtype | None = None,
221
+ ) -> None:
222
+ """Load the requested layers from the local snapshot into memory.
223
+
224
+ Layers already present in :attr:`self.layers` are skipped — calling
225
+ ``initialize_layers([23])`` twice is idempotent. ``device`` / ``dtype``
226
+ default to wherever the model itself lives (via the ``_device_marker``
227
+ buffer that moves with ``.to(...)``), so the common pattern of
228
+ ``model.to("cuda"); model.initialize_layers([7])`` Just Works.
229
+ """
230
+ assert self._snapshot_dir is not None, (
231
+ "ESMCSAEModel has no snapshot directory — call "
232
+ "from_pretrained first, or set _snapshot_dir manually."
233
+ )
234
+ if device is None:
235
+ device = self._device_marker.device
236
+ if dtype is None:
237
+ dtype = self._device_marker.dtype
238
+ snapshot_dir = Path(self._snapshot_dir)
239
+ available = set(self.config.available_layers)
240
+ for layer_idx in layers:
241
+ key = str(layer_idx)
242
+ if key in self.layers:
243
+ continue
244
+ if layer_idx not in available:
245
+ raise KeyError(
246
+ f"Layer {layer_idx} is not in this repo. "
247
+ f"available_layers={sorted(available)}"
248
+ )
249
+ shard = snapshot_dir / f"layer_{layer_idx}.safetensors"
250
+ if not shard.exists():
251
+ raise FileNotFoundError(
252
+ f"Missing layer file {shard} — config lists layer "
253
+ f"{layer_idx} as available but the shard is not on disk."
254
+ )
255
+ params = ESMCSAEParams(
256
+ d_model=self.config.d_model,
257
+ codebook_dim=self.config.codebook_dim,
258
+ k=self.config.k,
259
+ layer=layer_idx,
260
+ )
261
+ # Build on the meta device so we don't allocate weights that
262
+ # ``load_state_dict`` would immediately overwrite.
263
+ with torch.device("meta"):
264
+ layer = _ESMCSAELayer(params)
265
+ layer.to_empty(device=device)
266
+ layer.load_state_dict(load_file(str(shard)))
267
+ layer.to(dtype=dtype)
268
+ self.layers[key] = layer
269
+
270
+ def release_layer(self, layer: int) -> None:
271
+ """Drop the named layer from memory. No-op if not loaded."""
272
+ key = str(layer)
273
+ if key in self.layers:
274
+ del self.layers[key]
275
+
276
+ def loaded_layers(self) -> list[int]:
277
+ """Sorted list of layer indices currently materialized in memory."""
278
+ return sorted(int(k) for k in self.layers.keys())
279
+
280
+ def forward(
281
+ self, x: torch.Tensor, layer: int | None = None, **kwargs: object
282
+ ) -> ESMCSAEOutput:
283
+ if layer is None:
284
+ if len(self.layers) == 1:
285
+ # Unambiguous: exactly one layer loaded → use it.
286
+ ((_only_key, only_layer),) = self.layers.items()
287
+ return only_layer(x, **kwargs)
288
+ if len(self.layers) == 0:
289
+ raise RuntimeError(
290
+ "No layers loaded — call "
291
+ "initialize_layers([...]) first. "
292
+ f"available_layers={self.config.available_layers}"
293
+ )
294
+ raise RuntimeError(
295
+ "Multiple layers are loaded — please select one via "
296
+ f"forward(x, layer=<idx>). Loaded layers: {self.loaded_layers()}"
297
+ )
298
+ key = str(layer)
299
+ if key not in self.layers:
300
+ raise KeyError(
301
+ f"Layer {layer} is not loaded. Call "
302
+ f"initialize_layers([{layer}]) first. Loaded layers: "
303
+ f"{self.loaded_layers()}"
304
+ )
305
+ return self.layers[key](x, **kwargs)
306
+
307
+ def save_pretrained( # type: ignore[override]
308
+ self, save_directory: str | os.PathLike, *args, **kwargs
309
+ ) -> None:
310
+ """Write ``config.json`` plus one ``layer_{i}.safetensors`` per loaded layer.
311
+
312
+ Only layers currently in :attr:`self.layers` are written.
313
+ ``available_layers`` in the saved config is synced to what's actually
314
+ on disk so a ``release_layer`` + ``save_pretrained`` round-trip never
315
+ advertises a layer whose shard is missing.
316
+ """
317
+ del args, kwargs
318
+ save_directory = Path(save_directory)
319
+ save_directory.mkdir(parents=True, exist_ok=True)
320
+ # Sync available_layers to what we're about to write, so the saved config
321
+ # never advertises a missing shard. Note: this mutates ``self.config`` in place.
322
+ self.config.available_layers = self.loaded_layers()
323
+ self.config.save_pretrained(str(save_directory))
324
+ for key, layer in self.layers.items():
325
+ shard = save_directory / f"layer_{key}.safetensors"
326
+ save_file(
327
+ {
328
+ k: v.detach().cpu().contiguous()
329
+ for k, v in layer.state_dict().items()
330
+ },
331
+ str(shard),
332
+ )
333
+
334
+
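The load-on-demand and dispatch rules of the container above can be sketched with plain callables in place of SAE layers. Everything here is illustrative: `LazyLayerStore` and its `loader` argument are hypothetical names, and the sketch only mirrors the control flow (idempotent loading, release, and the "exactly one layer loaded" shortcut), not weight handling.

```python
from typing import Callable, Dict, List, Optional

Layer = Callable[[float], float]


class LazyLayerStore:
    """Hypothetical stand-in that materializes per-layer callables on demand."""

    def __init__(self, available: List[int], loader: Callable[[int], Layer]):
        self.available = set(available)
        self._loader = loader
        self.layers: Dict[str, Layer] = {}

    def initialize_layers(self, layers: List[int]) -> None:
        for idx in layers:
            key = str(idx)
            if key in self.layers:
                continue  # already loaded, so repeated calls are idempotent
            if idx not in self.available:
                raise KeyError(f"Layer {idx} is not available")
            self.layers[key] = self._loader(idx)

    def release_layer(self, layer: int) -> None:
        self.layers.pop(str(layer), None)  # no-op if not loaded

    def loaded_layers(self) -> List[int]:
        return sorted(int(k) for k in self.layers)

    def forward(self, x: float, layer: Optional[int] = None) -> float:
        if layer is None:
            if len(self.layers) != 1:
                raise RuntimeError("zero or several layers loaded; pass layer=")
            (only,) = self.layers.values()
            return only(x)
        key = str(layer)
        if key not in self.layers:
            raise KeyError(f"Layer {layer} is not loaded")
        return self.layers[key](x)


# Toy loader: "layer i" multiplies its input by i.
store = LazyLayerStore([45, 60], loader=lambda idx: (lambda x: x * idx))
store.initialize_layers([60])
single = store.forward(2.0)             # unambiguous: only layer 60 is loaded
store.initialize_layers([45])
explicit = store.forward(2.0, layer=45)
store.release_layer(60)
```

Keying the dictionary by `str(idx)` follows the constraint that `nn.ModuleDict` keys are strings.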
335
+ def _resolve_snapshot_dir(
336
+ pretrained_model_name_or_path: str | os.PathLike, kwargs: dict
337
+ ) -> str:
338
+ """Local dir → return as-is; hub id → ``snapshot_download`` it.
339
+
340
+ A directory only counts as "local" if it actually contains ``config.json``,
341
+ so a stale subdir named like a hub id (``./biohub/esmc-...``)
342
+ doesn't accidentally shadow the hub fetch.
343
+
344
+ Pops the standard ``snapshot_download`` keyword args from ``kwargs`` so
345
+ callers can forward them via ``from_pretrained``.
346
+ """
347
+ path = Path(pretrained_model_name_or_path)
348
+ if path.is_dir() and (path / "config.json").exists():
349
+ return str(path)
350
+ from huggingface_hub import snapshot_download
351
+
352
+ return snapshot_download(
353
+ repo_id=str(pretrained_model_name_or_path),
354
+ revision=kwargs.pop("revision", None),
355
+ cache_dir=kwargs.pop("cache_dir", None),
356
+ token=kwargs.pop("token", None),
357
+ allow_patterns=kwargs.pop("allow_patterns", None),
358
+ local_files_only=kwargs.pop("local_files_only", False),
359
+ force_download=kwargs.pop("force_download", False),
360
+ )
361
+
362
+
363
+ __all__ = ["ESMCSAEModel", "ESMCSAEOutput", "ESMCSAEPreTrainedModel"]
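The "local directory only if it contains `config.json`" rule used by `_resolve_snapshot_dir` can be checked in isolation. A minimal sketch using a temporary directory; `looks_like_local_snapshot` and the `biohub/some-model` path are hypothetical names chosen for the illustration, and no hub call is made.

```python
import tempfile
from pathlib import Path


def looks_like_local_snapshot(name_or_path: str) -> bool:
    """True only for an existing directory that holds a config.json."""
    path = Path(name_or_path)
    return path.is_dir() and (path / "config.json").exists()


with tempfile.TemporaryDirectory() as tmp:
    # A stale directory whose name resembles a hub id (hypothetical).
    stale = Path(tmp) / "biohub" / "some-model"
    stale.mkdir(parents=True)
    before = looks_like_local_snapshot(str(stale))  # no config.json yet
    (stale / "config.json").write_text("{}")
    after = looks_like_local_snapshot(str(stale))   # now counts as local
```

With this check an empty directory falls through to the hub download path instead of shadowing it.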
modeling_esmfold2.py ADDED
@@ -0,0 +1,1288 @@
1
+ """PyTorch ESMFold2 model — the standard released architecture.
2
+
3
+ Quickstart::
4
+
5
+ from transformers import ESMFold2Model
6
+
7
+ model = ESMFold2Model.from_pretrained("biohub/ESMFold2").cuda().eval()
8
+ open("ubq.pdb", "w").write(model.infer_protein_as_pdb("MQIFVKTLTGKT..."))
9
+
10
+ For multi-chain / ligand / MSA inputs see ``ESMFold2InputBuilder`` in the
11
+ companion ``esm`` package.
12
+ """
13
+
14
+ import importlib
15
+ import math
16
+ import sys
17
+ from contextlib import contextmanager
18
+ from pathlib import Path
19
+ from typing import Any, cast
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch import Tensor
25
+
26
+ try:
27
+ te = importlib.import_module("transformer_engine.pytorch")
28
+ te_recipe = importlib.import_module("transformer_engine.common.recipe")
29
+ DelayedScaling = te_recipe.DelayedScaling
30
+ Format = te_recipe.Format
31
+
32
+ TE_AVAILABLE = True
33
+ except ImportError:
34
+ te = None # type: ignore[assignment]
35
+ DelayedScaling = None # type: ignore[assignment]
36
+ Format = None # type: ignore[assignment]
37
+ TE_AVAILABLE = False
38
+
39
+ from transformers.modeling_utils import PreTrainedModel
40
+ from .configuration_esmc import ESMCConfig as _FastPLMSESMCConfig
41
+ from .configuration_esmc_sae import ESMCSAEConfig as _FastPLMSESMCSAEConfig
42
+ from .configuration_esmfold2 import ESMFold2Config
43
+ from .modeling_esmc import ESMCModel as _FastPLMSESMCModel
44
+ from .modeling_esmc_sae import _ESMCSAELayer as _FastPLMSESMCSAELayer
45
+ from .modeling_esmfold2_common import (
46
+ CHAR_VOCAB_SIZE,
47
+ MAX_ATOMIC_NUMBER,
48
+ NUM_RES_TYPES,
49
+ DiffusionStructureHead,
50
+ FoldingTrunk,
51
+ InputsEmbedder,
52
+ LanguageModelShim,
53
+ MSAPairWeightedAveraging,
54
+ OuterProductMean,
55
+ ResIdxAsymIdSymIdEntityIdEncoding,
56
+ RowAttentionPooling,
57
+ SwiGLUMLP,
58
+ TriangleMultiplicativeUpdate,
59
+ _categorical_mean,
60
+ _compute_intra_token_idx,
61
+ compute_lm_hidden_states,
62
+ gather_rep_atom_coords,
63
+ gather_token_to_atom,
64
+ )
65
+ from .esmfold2_affine3d import Affine3D as _FastPLMSESMFold2Affine3D
66
+ from .esmfold2_aligner import Aligner as _FastPLMSESMFold2Aligner
67
+ from .esmfold2_atom_indexer import AtomIndexer as _FastPLMSESMFold2AtomIndexer
68
+ from .esmfold2_conformers import load_ccd as _fastplms_esmfold2_load_ccd
69
+ from .esmfold2_constants import ELEMENT_NUMBER_TO_SYMBOL as _FASTPLMS_ESMFOLD2_ELEMENT_NUMBER_TO_SYMBOL
70
+ from .esmfold2_constants_esm3 import CHAIN_BREAK_STR as _FASTPLMS_ESMFOLD2_CHAIN_BREAK_STR
71
+ from .esmfold2_input_builder import StructurePredictionInput as _FastPLMSESMFold2StructurePredictionInput
72
+ from .esmfold2_metrics import compute_rmsd as _fastplms_esmfold2_compute_rmsd
73
+ from .esmfold2_misc import slice_any_object as _fastplms_esmfold2_slice_any_object
74
+ from .esmfold2_mmcif_parsing import MmcifWrapper as _FastPLMSESMFold2MmcifWrapper
75
+ from .esmfold2_molecular_complex import MolecularComplex as _FastPLMSESMFold2MolecularComplex
76
+ from .esmfold2_msa import MSA as _FastPLMSESMFold2MSA
77
+ from .esmfold2_msa_filter_sequences import greedy_select_indices as _fastplms_esmfold2_greedy_select_indices
78
+ from .esmfold2_normalize_coordinates import normalize_coordinates as _fastplms_esmfold2_normalize_coordinates
79
+ from .esmfold2_output import build_molecular_complex_from_features as _fastplms_esmfold2_build_molecular_complex_from_features
80
+ from .esmfold2_paired_msa import construct_paired_msa as _fastplms_esmfold2_construct_paired_msa
81
+ from .esmfold2_parsing import FastaEntry as _FastPLMSESMFold2FastaEntry
82
+ from .esmfold2_predicted_aligned_error import compute_tm as _fastplms_esmfold2_compute_tm
83
+ from .esmfold2_prepare_input import prepare_esmfold2_input as _fastplms_esmfold2_prepare_esmfold2_input
84
+ from .esmfold2_processor import ESMFold2InputBuilder as _FastPLMSESMFold2InputBuilder
85
+ from .esmfold2_protein_chain import ProteinChain as _FastPLMSESMFold2ProteinChain
86
+ from .esmfold2_protein_complex import ProteinComplex as _FastPLMSESMFold2ProteinComplex
87
+ from .esmfold2_protein_structure import index_by_atom_name as _fastplms_esmfold2_index_by_atom_name
88
+ from .esmfold2_residue_constants import restypes as _FASTPLMS_ESMFOLD2_RESTYPES
89
+ from .esmfold2_sequential_dataclass import SequentialDataclass as _FastPLMSESMFold2SequentialDataclass
90
+ from .esmfold2_system import run_subprocess_with_errorcheck as _fastplms_esmfold2_run_subprocess_with_errorcheck
91
+ from .esmfold2_types import ProteinInput as _FastPLMSESMFold2ProteinInput
92
+ from .esmfold2_utils_types import PathOrBuffer as _FastPLMSESMFold2PathOrBuffer
93
+
94
+ _EPS = 1e-6
95
+ _NONPOLYMER_ID = 4
96
+
97
+ # Default for the triangle / OPM / pair-transition L² ops. Caps peak memory
98
+ # so L≈2k folds on an 80 GB GPU (~76 GB peak at chunk=128 for L=1438;
99
+ # chunk=64 leaves headroom for the largest foldbench targets). Override via
100
+ # ``model.set_chunk_size(...)``; pass None to disable chunking (faster for
101
+ # short L but OOM-prone past ~600).
102
+ _DEFAULT_CHUNK_SIZE = 64
103
+
104
+
105
+ def _ensure_vendored_esm_alias() -> None:
106
+ package = __package__
107
+ assert package is not None
108
+ vendored_esm = importlib.import_module(f"{package}.esm")
109
+ sys.modules["esm"] = vendored_esm
110
+
111
+
112
+ class PairTransition(nn.Module):
113
+ """LayerNorm + SwiGLU feed-forward residual block on the pair representation."""
114
+
115
+ def __init__(self, d_model: int, expansion_ratio: int = 4) -> None:
116
+ super().__init__()
117
+ self.norm = nn.LayerNorm(d_model)
118
+ self.ffn = SwiGLUMLP(d_model, expansion_ratio=expansion_ratio, bias=False)
119
+ self._chunk_size: int | None = _DEFAULT_CHUNK_SIZE
120
+
121
+ def set_chunk_size(self, chunk_size: int | None) -> None:
122
+ self._chunk_size = chunk_size
123
+
124
+ def forward(self, x: Tensor) -> Tensor:
125
+ if self._chunk_size is None or x.shape[1] <= self._chunk_size:
126
+ return self.ffn(self.norm(x))
127
+ out: list[Tensor] = []
128
+ for s in range(0, x.shape[1], self._chunk_size):
129
+ e = min(s + self._chunk_size, x.shape[1])
130
+ sl = x[:, s:e]
131
+ out.append(self.ffn(self.norm(sl)))
132
+ return torch.cat(out, dim=1)
133
+
134
+
135
+ class ConfidenceHead(nn.Module):
136
+ """Predicts pLDDT, PAE, PDE, resolved-atom probability and distogram bins."""
137
+
138
+ boundaries: Tensor
139
+
140
+ def __init__(self, config: "ESMFold2Config") -> None:
141
+ super().__init__()
142
+ ch = config.confidence_head
143
+ d_single = config.d_single
144
+ d_pair = config.d_pair
145
+ d_inputs = config.inputs.d_inputs
146
+
147
+ boundaries = torch.linspace(ch.min_dist, ch.max_dist, ch.distogram_bins - 1)
148
+ self.register_buffer("boundaries", boundaries)
149
+ self.dist_bin_pairwise_embed = nn.Embedding(ch.distogram_bins, d_pair)
150
+
151
+ self.s_norm = nn.LayerNorm(d_single)
152
+ self.s_inputs_to_single = nn.Linear(d_inputs, d_single, bias=False)
153
+ self.s_to_z = nn.Linear(d_inputs, d_pair, bias=False)
154
+ self.s_to_z_transpose = nn.Linear(d_inputs, d_pair, bias=False)
155
+ self.s_to_z_prod_in1 = nn.Linear(d_inputs, d_pair, bias=False)
156
+ self.s_to_z_prod_in2 = nn.Linear(d_inputs, d_pair, bias=False)
157
+ self.s_to_z_prod_out = nn.Linear(d_pair, d_pair, bias=False)
158
+ self.s_input_to_s = nn.Linear(d_inputs, d_single, bias=False)
159
+ self.s_inputs_norm = nn.LayerNorm(d_inputs)
160
+ self.z_norm = nn.LayerNorm(d_pair)
161
+
162
+ self.row_attention_pooling = RowAttentionPooling(
163
+ d_pair=d_pair, d_single=d_single
164
+ )
165
+
166
+ pf = ch.folding_trunk
167
+ self.folding_trunk = FoldingTrunk(
168
+ n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
169
+ )
170
+
171
+ # Heads.
172
+ self.plddt_ln = nn.LayerNorm(d_single)
173
+ max_atoms_per_token = 23
174
+ self.plddt_weight = nn.Parameter(
175
+ torch.zeros(max_atoms_per_token, d_single, ch.num_plddt_bins)
176
+ )
177
+
178
+ self.pae_ln = nn.LayerNorm(d_pair)
179
+ self.pae_head = nn.Linear(d_pair, ch.num_pae_bins, bias=False)
180
+
181
+ self.pde_ln = nn.LayerNorm(d_pair)
182
+ self.pde_head = nn.Linear(d_pair, ch.num_pde_bins, bias=False)
183
+
184
+ self.resolved_ln = nn.LayerNorm(d_single)
185
+ # 2 = resolved logits ([unresolved, resolved]).
186
+ self.resolved_weight = nn.Parameter(
187
+ torch.zeros(max_atoms_per_token, d_single, 2)
188
+ )
189
+
190
+ def set_kernel_backend(self, backend: str | None) -> None:
191
+ self.folding_trunk.set_kernel_backend(backend)
192
+
193
+ def set_chunk_size(self, chunk_size: int | None) -> None:
194
+ self.folding_trunk.set_chunk_size(chunk_size)
195
+
196
+ @staticmethod
197
+ def _repeat_batch(x: Tensor, num_diffusion_samples: int) -> Tensor:
198
+ return (
199
+ x
200
+ if num_diffusion_samples == 1
201
+ else x.repeat_interleave(num_diffusion_samples, 0)
202
+ )
203
+
204
+ @staticmethod
205
+ def _flatten_sample_axis(x: Tensor) -> Tensor:
206
+ if x.ndim == 4:
207
+ b, mult, n, c = x.shape
208
+ return x.reshape(b * mult, n, c)
209
+ return x
210
+
211
+ def forward(
212
+ self,
213
+ s_inputs: Tensor,
214
+ z: Tensor,
215
+ x_pred: Tensor,
216
+ distogram_atom_idx: Tensor,
217
+ token_attention_mask: Tensor,
218
+ atom_to_token: Tensor,
219
+ atom_attention_mask: Tensor,
220
+ asym_id: Tensor,
221
+ mol_type: Tensor,
222
+ num_diffusion_samples: int = 1,
223
+ relative_position_encoding: Tensor | None = None,
224
+ token_bonds_encoding: Tensor | None = None,
225
+ ) -> dict[str, Tensor]:
226
+ s_inputs_normed = self.s_inputs_norm(s_inputs)
227
+
228
+ z_base = self.z_norm(z)
229
+ if relative_position_encoding is not None:
230
+ z_base = z_base + relative_position_encoding
231
+ if token_bonds_encoding is not None:
232
+ z_base = z_base + token_bonds_encoding
233
+ z_base = z_base + self.s_to_z(s_inputs_normed).unsqueeze(2)
234
+ z_base = z_base + self.s_to_z_transpose(s_inputs_normed).unsqueeze(1)
235
+ z_base = z_base + self.s_to_z_prod_out(
236
+ self.s_to_z_prod_in1(s_inputs_normed)[:, :, None, :]
237
+ * self.s_to_z_prod_in2(s_inputs_normed)[:, None, :, :]
238
+ )
239
+
240
+ pair = self._repeat_batch(z_base, num_diffusion_samples)
241
+ x_pred_flat = self._flatten_sample_axis(x_pred)
242
+ atom_to_token_m = self._repeat_batch(atom_to_token, num_diffusion_samples)
243
+ atom_mask_m = self._repeat_batch(atom_attention_mask, num_diffusion_samples)
244
+ rep_idx_m = self._repeat_batch(distogram_atom_idx, num_diffusion_samples).long()
245
+ mask = self._repeat_batch(token_attention_mask, num_diffusion_samples)
246
+ Bm = pair.shape[0]
247
+
248
+ rep_coords = gather_rep_atom_coords(x_pred_flat, rep_idx_m)
249
+ rep_distances = torch.cdist(
250
+ rep_coords, rep_coords, compute_mode="donot_use_mm_for_euclid_dist"
251
+ )
252
+ distogram_bins = (
253
+ (rep_distances.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
254
+ )
255
+ pair = pair + self.dist_bin_pairwise_embed(distogram_bins)
256
+
257
+ pair_mask = mask[:, :, None].float() * mask[:, None, :].float()
258
+
259
+ # FoldingTrunk handles the bf16 cast internally during inference so
260
+ # each block's fused trimul engages. In-place residual avoids an
261
+ # extra fp32 pair allocation.
262
+ with torch.amp.autocast("cuda", enabled=pair.is_cuda, dtype=torch.bfloat16):
263
+ pair_delta = self.folding_trunk(pair, pair_attention_mask=pair_mask)
264
+ pair.add_(pair_delta.float())
265
+ del pair_delta
266
+ single = self.row_attention_pooling(pair, mask)
267
+
268
+ atom_mask_f = atom_mask_m.float()
269
+ s_at_atoms = gather_token_to_atom(single, atom_to_token_m)
270
+ s_at_atoms_ln = self.plddt_ln(s_at_atoms)
271
+
272
+ intra_idx = _compute_intra_token_idx(atom_to_token_m)
273
+ intra_idx = intra_idx.clamp(max=self.plddt_weight.shape[0] - 1)
274
+ w_plddt = self.plddt_weight[intra_idx]
275
+ plddt_logits = torch.einsum("...c,...cb->...b", s_at_atoms_ln, w_plddt)
276
+ plddt_per_atom = _categorical_mean(plddt_logits, start=0.0, end=1.0)
277
+
278
+ L = single.shape[1]
279
+ plddt_sum = torch.zeros(Bm, L, device=single.device, dtype=plddt_per_atom.dtype)
280
+ atom_count = torch.zeros(
281
+ Bm, L, device=single.device, dtype=plddt_per_atom.dtype
282
+ )
283
+ atom_mask_t = atom_mask_f.to(plddt_per_atom.dtype)
284
+ plddt_sum.scatter_add_(1, atom_to_token_m, plddt_per_atom * atom_mask_t)
285
+ atom_count.scatter_add_(1, atom_to_token_m, atom_mask_t)
286
+ plddt = plddt_sum / atom_count.clamp(min=1e-6)
287
+
288
+ complex_plddt = (plddt_per_atom * atom_mask_f).sum(dim=-1) / (
289
+ atom_mask_f.sum(dim=-1) + _EPS
290
+ )
291
+
292
+ expanded_type = self._repeat_batch(mol_type, num_diffusion_samples)
293
+ expanded_asym = self._repeat_batch(asym_id, num_diffusion_samples)
294
+ is_ligand = (expanded_type == _NONPOLYMER_ID).float()
295
+ inter_chain = (
296
+ expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
297
+ ).float()
298
+ near_contact = (rep_distances < 8).float()
299
+ interface_per_token = (
300
+ near_contact * inter_chain * (1.0 - is_ligand).unsqueeze(-1)
301
+ ).amax(dim=-1)
302
+ iplddt_weight = torch.where(
303
+ is_ligand.bool(),
304
+ torch.full_like(interface_per_token, 2.0),
305
+ interface_per_token,
306
+ )
307
+ iplddt_weight_atoms = gather_token_to_atom(
308
+ iplddt_weight.unsqueeze(-1), atom_to_token_m
309
+ ).squeeze(-1)
310
+ atom_iplddt_w = atom_mask_f * iplddt_weight_atoms
311
+ complex_iplddt = (plddt_per_atom * atom_iplddt_w).sum(dim=-1) / (
312
+ atom_iplddt_w.sum(dim=-1) + _EPS
313
+ )
314
+
315
+ plddt_ca = plddt_per_atom.gather(1, rep_idx_m)
316
+
317
+ # PAE
318
+ pae_logits = self.pae_head(self.pae_ln(pair))
319
+ pae = _categorical_mean(pae_logits, start=0.0, end=32.0).detach()
320
+
321
+ # PDE
322
+ pde_logits = self.pde_head(self.pde_ln(pair))
323
+ pde = _categorical_mean(pde_logits, start=0.0, end=32.0).detach()
324
+
325
+ # Resolved (per-atom binary).
326
+ s_at_atoms_res = self.resolved_ln(s_at_atoms)
327
+ w_res = self.resolved_weight[intra_idx]
328
+ resolved_logits = torch.einsum("...c,...cb->...b", s_at_atoms_res, w_res)
329
+
330
+ # pTM / ipTM from pae_logits.
331
+ n_bins = pae_logits.shape[-1]
332
+ bin_width = 32.0 / n_bins
333
+ bin_centers = torch.arange(
334
+ 0.5 * bin_width, 32.0, bin_width, device=pae_logits.device
335
+ )
336
+ mask_f = mask.float()
337
+ N_res = mask_f.sum(dim=-1, keepdim=True)
338
+ d0 = 1.24 * (N_res.clamp(min=19) - 15) ** (1 / 3) - 1.8
339
+ tm_per_bin = 1 / (1 + (bin_centers / d0) ** 2)
340
+ pae_probs = F.softmax(pae_logits, dim=-1)
341
+ tm_expected = (pae_probs * tm_per_bin[:, None, None, :]).sum(dim=-1)
342
+
343
+ pair_mask_2d = mask_f.unsqueeze(-1) * mask_f.unsqueeze(-2)
344
+ ptm_per_row = (tm_expected * pair_mask_2d).sum(dim=-1) / (
345
+ pair_mask_2d.sum(dim=-1) + _EPS
346
+ )
347
+ ptm = ptm_per_row.max(dim=-1).values
348
+
349
+ inter_chain_mask = (
350
+ expanded_asym.unsqueeze(-1) != expanded_asym.unsqueeze(-2)
351
+ ).float() * pair_mask_2d
352
+ iptm_per_row = (tm_expected * inter_chain_mask).sum(dim=-1) / (
353
+ inter_chain_mask.sum(dim=-1) + _EPS
354
+ )
355
+ iptm = iptm_per_row.max(dim=-1).values
356
+
357
+ max_chain_id = int(expanded_asym.max().item()) if Bm > 0 else 0
358
+ n_chains = max_chain_id + 1
359
+ pair_chains_iptm = torch.zeros(
360
+ Bm, n_chains, n_chains, device=tm_expected.device, dtype=tm_expected.dtype
361
+ )
362
+ for c1 in range(n_chains):
363
+ chain_c1 = (expanded_asym == c1).float() * mask_f
364
+ if chain_c1.sum() == 0:
365
+ continue
366
+ for c2 in range(n_chains):
367
+ chain_c2 = (expanded_asym == c2).float() * mask_f
368
+ pair_m = chain_c1.unsqueeze(-1) * chain_c2.unsqueeze(-2)
369
+ denom = pair_m.sum(dim=(-1, -2)) + _EPS
370
+ pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(
371
+ dim=(-1, -2)
372
+ ) / denom
373
+
374
+ return {
375
+ "plddt_logits": plddt_logits,
376
+ "plddt": plddt.detach(),
377
+ "plddt_per_atom": plddt_per_atom.detach(),
378
+ "plddt_ca": plddt_ca.detach(),
379
+ "complex_plddt": complex_plddt.detach(),
380
+ "complex_iplddt": complex_iplddt.detach(),
381
+ "pae_logits": pae_logits,
382
+ "pae": pae,
383
+ "pde_logits": pde_logits,
384
+ "pde": pde,
385
+ "resolved_logits": resolved_logits,
386
+ "ptm": ptm.detach(),
387
+ "iptm": iptm.detach(),
388
+ "pair_chains_iptm": pair_chains_iptm.detach(),
389
+ }
390
+
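The distogram binning in `ConfidenceHead.forward` counts how many boundaries a distance strictly exceeds, so `n - 1` boundaries yield `n` bins. A minimal sketch in plain Python; the range 2 to 22 and the choice of 6 bins are assumptions made for the example, not values from the config.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced values including both endpoints (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def bin_index(distance: float, boundaries: List[float]) -> int:
    """Number of boundaries strictly below the distance."""
    return sum(distance > b for b in boundaries)


# Hypothetical setup: 6 bins, hence 5 boundaries at 2, 7, 12, 17, 22.
boundaries = linspace(2.0, 22.0, 5)
indices = [bin_index(d, boundaries) for d in (1.0, 2.0, 7.5, 30.0)]
```

A distance equal to a boundary stays in the lower bin because the comparison is strict.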
391
+
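The pTM computation above turns a PAE bin distribution into an expected TM-score term using the usual `d0` length normalization. A minimal single-pair sketch; `tm_d0` and `expected_tm` are hypothetical names, and the 4-bin logits are made up for the illustration.

```python
import math
from typing import List


def tm_d0(n_res: int) -> float:
    """TM-score length scale, with the residue count clamped at 19."""
    return 1.24 * (max(n_res, 19) - 15) ** (1 / 3) - 1.8


def expected_tm(pae_logits: List[float], n_res: int,
                max_error: float = 32.0) -> float:
    """Expected TM term for one residue pair from its PAE logits."""
    n_bins = len(pae_logits)
    width = max_error / n_bins
    centers = [(i + 0.5) * width for i in range(n_bins)]
    d0 = tm_d0(n_res)
    # Numerically stable softmax over the bins.
    peak = max(pae_logits)
    weights = [math.exp(v - peak) for v in pae_logits]
    total = sum(weights)
    return sum((w / total) / (1.0 + (c / d0) ** 2)
               for w, c in zip(weights, centers))


confident = expected_tm([10.0, 0.0, 0.0, 0.0], n_res=100)  # mass on low error
uncertain = expected_tm([0.0, 0.0, 0.0, 10.0], n_res=100)  # mass on high error
```

The model code then averages such terms over each row of the pair matrix and takes the maximum over rows.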
392
+ def _inverse_softplus(value: float) -> float:
393
+ return value + math.log(-math.expm1(-value))
394
+
395
+
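The formula in `_inverse_softplus` follows from solving `y = log(1 + e^x)` for `x`: `x = log(e^y - 1) = y + log(1 - e^(-y))`, valid for `y > 0`. A small round-trip check using only the standard library; the `softplus` helper is added here for the test and is not part of the module.

```python
import math


def softplus(x: float) -> float:
    """log(1 + e^x), written with log1p for accuracy near zero."""
    return math.log1p(math.exp(x))


def inverse_softplus(value: float) -> float:
    """Inverse of softplus for value > 0, same expression as in the source."""
    return value + math.log(-math.expm1(-value))


targets = (0.1, 1.0, 5.0)
roundtrip = [softplus(inverse_softplus(v)) for v in targets]
```

Using `expm1` avoids the cancellation that `1 - exp(-value)` would suffer for small positive values.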
396
+ def _convert_te_modules_to_fp8_inplace(module: nn.Module) -> None:
397
+ """Re-init each TE module via quantized_model_init so weights live as fp8.
398
+
399
+ Must be called inside torch.no_grad(); covers nn.Linear, te.Linear,
400
+ te.LayerNormLinear, te.LayerNormMLP — the last two hold 99% of ESMC weight.
401
+ """
402
+ if not TE_AVAILABLE:
403
+ raise RuntimeError("transformer_engine is not available; cannot use fp8.")
404
+ quantized_model_init = importlib.import_module(
405
+ "transformer_engine.pytorch"
406
+ ).quantized_model_init
407
+
408
+ def _walk(mod: nn.Module) -> None:
409
+ for name, child in list(mod.named_children()):
410
+ replaced = False
411
+ if isinstance(child, nn.Linear):
412
+ in_f, out_f = child.in_features, child.out_features
413
+ has_bias = child.bias is not None
414
+ device = child.weight.device
415
+ dtype = child.weight.dtype
416
+ w = child.weight.data
417
+ b = child.bias.data if has_bias else None
418
+ setattr(mod, name, nn.Identity())
419
+ del child
420
+ torch.cuda.empty_cache()
421
+ with quantized_model_init(enabled=True):
422
+ new_mod = te.Linear( # type: ignore[union-attr]
423
+ in_f, out_f, bias=has_bias, params_dtype=dtype
424
+ ).to(device)
425
+ new_mod.weight.quantize_(w) # type: ignore[attr-defined,operator]
426
+ if has_bias:
427
+ assert b is not None
428
+ new_mod.bias.data.copy_(b) # type: ignore[union-attr]
429
+ del w, b
430
+ replaced = True
431
+ elif isinstance(child, te.Linear): # type: ignore[union-attr]
432
+ # te.Linear with bf16 weight → re-init inside quantized_model_init for fp8.
433
+ in_f, out_f = child.in_features, child.out_features
434
+ has_bias = child.bias is not None
435
+ device = child.weight.device
436
+ dtype = (
437
+ child.weight.dtype
438
+ if not hasattr(child.weight, "_data")
439
+ else torch.bfloat16
440
+ )
441
+ state = {k: v.detach().clone() for k, v in child.state_dict().items()}
442
+ setattr(mod, name, nn.Identity())
443
+ del child
444
+ torch.cuda.empty_cache()
445
+ with quantized_model_init(enabled=True):
446
+ new_mod = te.Linear( # type: ignore[union-attr]
447
+ in_f,
448
+ out_f,
449
+ bias=has_bias,
450
+ params_dtype=dtype, # type: ignore[arg-type]
451
+ ).to(device) # type: ignore[arg-type]
452
+ new_mod.load_state_dict(state, strict=False)
453
+ replaced = True
454
+ elif (
455
+ hasattr(te, "LayerNormLinear") and isinstance(child, te.LayerNormLinear) # type: ignore[union-attr]
456
+ ):
457
+ state = {k: v.detach().clone() for k, v in child.state_dict().items()}
458
+ hidden_size = child.in_features
459
+ out_features = child.out_features
460
+ has_bias = child.use_bias
461
+ device = next(child.parameters()).device
462
+ setattr(mod, name, nn.Identity())
463
+ del child
464
+ torch.cuda.empty_cache()
465
+ with quantized_model_init(enabled=True):
466
+ new_mod = te.LayerNormLinear( # type: ignore[union-attr]
467
+ hidden_size,
468
+ out_features,
469
+ bias=has_bias,
470
+ params_dtype=torch.bfloat16,
471
+ ).to(device)
472
+ new_mod.load_state_dict(state, strict=False)
473
+ replaced = True
474
+ elif (
475
+ hasattr(te, "LayerNormMLP") and isinstance(child, te.LayerNormMLP) # type: ignore[union-attr]
476
+ ):
477
+ state = {k: v.detach().clone() for k, v in child.state_dict().items()}
478
+ fc1_weight: Tensor = child.fc1_weight # type: ignore[attr-defined]
479
+ hidden_size = int(fc1_weight.shape[1])
480
+ # fc1 packed as (2*ffn_hidden_size, hidden_size) for swiglu.
481
+ ffn_hidden_size = int(fc1_weight.shape[0]) // 2
482
+ has_bias = (
483
+ getattr(child, "fc1_bias", None) is not None
484
+ and child.fc1_bias is not None # type: ignore[attr-defined]
485
+ )
486
+ device = fc1_weight.device
487
+ setattr(mod, name, nn.Identity())
488
+ del child
489
+ torch.cuda.empty_cache()
490
+ with quantized_model_init(enabled=True):
491
+ new_mod = te.LayerNormMLP( # type: ignore[union-attr]
492
+ hidden_size=hidden_size,
493
+ ffn_hidden_size=ffn_hidden_size,
494
+ bias=has_bias,
495
+ activation="swiglu",
496
+ params_dtype=torch.bfloat16,
497
+ ).to(device) # type: ignore[arg-type]
498
+ new_mod.load_state_dict(state, strict=False)
499
+ replaced = True
500
+
501
+ if replaced:
502
+ # Freeze via .eval()+.requires_grad_(False); per-param ops would unwrap Float8Tensor.
503
+ new_mod.eval().requires_grad_(False)
504
+ setattr(mod, name, new_mod)
505
+ torch.cuda.empty_cache()
506
+ else:
507
+ _walk(child)
508
+
509
+ _walk(module)
510
+ torch.cuda.empty_cache()
511
+
512
+
513
+ @contextmanager
+ def _lm_precision_context(fp8: bool):
+     """bf16 autocast (+ optional TE fp8 autocast) around the LM forward.
+
+     te.autocast keeps te.Linear outputs bf16 instead of the fp32 default
+     (~425 MB at L=1024 in the hidden-state cache).
+     """
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         if fp8 and TE_AVAILABLE:
+             fp8_recipe = DelayedScaling(  # type: ignore[misc]
+                 fp8_format=Format.HYBRID,  # type: ignore[union-attr]
+                 amax_history_len=1,
+                 amax_compute_algo="most_recent",
+             )
+             with te.autocast(enabled=True, recipe=fp8_recipe):  # type: ignore[union-attr]
+                 yield
+         else:
+             yield
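
The conditional nesting in `_lm_precision_context` (always enter the outer autocast, enter the fp8 autocast only when requested) can be illustrated without a GPU. The sketch below is a stand-in, not the module's code: `scope`, `precision_context`, and the `events` list are hypothetical names that only record enter/exit order.

```python
from contextlib import contextmanager

events: list[str] = []


@contextmanager
def scope(name: str):
    # Stand-in for torch.autocast / te.autocast: records entry and exit.
    events.append(f"enter {name}")
    try:
        yield
    finally:
        events.append(f"exit {name}")


@contextmanager
def precision_context(fp8: bool):
    # The outer scope is always active; the inner one only when fp8 is set.
    with scope("bf16"):
        if fp8:
            with scope("fp8"):
                yield
        else:
            yield


with precision_context(True):
    events.append("forward")
```

Because the `yield` sits inside both `with` blocks, the inner scope exits before the outer one when the caller's block finishes.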
531
+
532
+
533
+ class ESMFold2Model(PreTrainedModel):
534
+ """ESMFold2 — all-atom structure prediction with an ESMC PLM backbone.
535
+
536
+ This is the standard released ESMFold2 architecture (uses a linear-
537
+ recurrent trunk, internally referred to as "parcae").
538
+
539
+ Forward kwargs that callers commonly override:
540
+
541
+ * ``num_loops`` (default ``config.num_loops``): trunk refinement
542
+ loops.
543
+ * ``num_diffusion_samples`` (default ``config.num_diffusion_samples``):
544
+ parallel structure samples; the confidence head re-runs once per
545
+ sample, so memory scales linearly. Pass ``1`` for cheap inference.
546
+ * ``num_sampling_steps`` (default ``config.structure_head.inference_num_steps``):
547
+ diffusion ODE solver steps. Lower for speed, higher for quality.
548
+
549
+ Memory / perf knobs:
550
+
551
+ * ``model.set_chunk_size(int|None)``: caps L² ops (triangle / OPM /
552
+ pair transition) at this token-axis chunk. Default 64 — fits
553
+ L≈2k on an 80 GB GPU. Pass ``None`` for faster inference at L<600.
554
+ * ``model.set_kernel_backend(None | "fused" | "cuequivariance")``:
555
+ select kernel backend (None = reference path).
556
+ """
557
+
558
+ config_class = ESMFold2Config
559
+ _keys_to_ignore_on_load_unexpected = [r"\._extra_state$"]
560
+
561
+ def __init__(self, config: ESMFold2Config) -> None:
562
+ super().__init__(config)
563
+ d_inputs = config.inputs.d_inputs
564
+ d_pair = config.d_pair
565
+
566
+ self.inputs_embedder = InputsEmbedder(config)
567
+ self.z_init_1 = nn.Linear(d_inputs, d_pair, bias=False)
568
+ self.z_init_2 = nn.Linear(d_inputs, d_pair, bias=False)
569
+ self.rel_pos = ResIdxAsymIdSymIdEntityIdEncoding(
570
+ n_relative_residx_bins=config.n_relative_residx_bins,
571
+ n_relative_chain_bins=config.n_relative_chain_bins,
572
+ d_pair=d_pair,
573
+ )
574
+ self.token_bonds = nn.Linear(1, d_pair, bias=False)
575
+ self.language_model = LanguageModelShim(
576
+ d_z=d_pair, d_model=config.lm_d_model, num_layers=config.lm_num_layers
577
+ )
578
+ self._esmc: nn.Module | None = None
579
+ self._esmc_fp8: bool = False # set by load_esmc(fp8=True)
580
+ self._esmfold2_input_builder: Any | None = None
581
+
582
+ pf = config.folding_trunk
583
+ self.folding_trunk = FoldingTrunk(
584
+ n_layers=pf.n_layers, d_pair=d_pair, expansion_ratio=4
585
+ )
586
+ if config.lm_encoder.enabled:
587
+ self.lm_encoder: FoldingTrunk | None = FoldingTrunk(
588
+ n_layers=config.lm_encoder.n_layers, d_pair=d_pair, expansion_ratio=4
589
+ )
590
+ else:
591
+ self.lm_encoder = None
592
+
593
+ self.parcae_input_norm = nn.LayerNorm(d_pair)
594
+ self.parcae_log_a = nn.Parameter(torch.zeros(d_pair))
595
+ parcae_decay_init = math.sqrt(1.0 / 5.0)
596
+ parcae_delta_init = -math.log(parcae_decay_init)
597
+ self.parcae_log_delta = nn.Parameter(
598
+ torch.full(
599
+ (d_pair,), _inverse_softplus(parcae_delta_init), dtype=torch.float32
600
+ )
601
+ )
602
+ self.parcae_b_cont = nn.Parameter(torch.eye(d_pair))
603
+ self.parcae_readout = nn.Linear(d_pair, d_pair, bias=False)
604
+ nn.init.eye_(self.parcae_readout.weight)
605
+ self.parcae_coda = FoldingTrunk(
606
+ n_layers=config.parcae.coda_n_layers, d_pair=d_pair, expansion_ratio=4
607
+ )
608
+
609
+ # Heads --------------------------------------------------------------
610
+ self.structure_head = DiffusionStructureHead(config)
611
+ self.distogram_head = nn.Linear(
612
+ d_pair, config.structure_head.distogram_bins, bias=True
613
+ )
614
+ self.confidence_head = ConfidenceHead(config)
615
+
616
+ msa_cfg = config.msa_encoder
617
+ self.msa_encoder = None
618
+ if msa_cfg.enabled:
619
+ self.msa_encoder = MSAEncoder(
620
+ d_msa=msa_cfg.d_msa,
621
+ d_pair=d_pair,
622
+ d_inputs=d_inputs,
623
+ d_hidden=msa_cfg.d_hidden,
624
+ n_layers=msa_cfg.n_layers,
625
+ n_heads_msa=msa_cfg.n_heads_msa,
626
+ msa_head_width=msa_cfg.msa_head_width,
627
+ )
628
+
629
+ self.post_init()
630
+
631
+ def load_esmc(self, esmc_model_path: str, precision: str = "bf16") -> None:
632
+ """Load the ESMC LM.
633
+
634
+ ``precision``: ``"bf16"`` (default), ``"fp32"``, or ``"fp8"``.
635
+ ``"fp8"`` requires H100 + TransformerEngine ≥ 2.x and quantizes
636
+ every TE module's weights to fp8 storage.
637
+ """
638
+ from .modeling_esmc import ESMCModel
639
+
640
+ dtype_map = {
641
+ "bf16": torch.bfloat16,
642
+ "fp32": torch.float32,
643
+ "fp8": torch.bfloat16, # underlying weights stay bf16, TE re-quantizes to fp8
644
+ }
645
+ if precision not in dtype_map:
646
+ raise ValueError(
647
+ f"precision must be one of {list(dtype_map)}, got {precision!r}"
648
+ )
649
+ dtype = dtype_map[precision]
650
+
651
+ esmc = (
652
+ ESMCModel.from_pretrained(esmc_model_path)
653
+ .to(device=self.device, dtype=dtype)
654
+ .eval()
655
+ )
656
+ for p in esmc.parameters():
657
+ p.requires_grad_(False)
658
+
659
+ if precision == "fp8":
660
+ if not TE_AVAILABLE:
661
+ raise RuntimeError(
662
+ "transformer_engine is not available; cannot use fp8."
663
+ )
664
+ with torch.no_grad():
665
+ _convert_te_modules_to_fp8_inplace(esmc)
666
+ self._esmc_fp8 = True
667
+ else:
668
+ self._esmc_fp8 = False
669
+
670
+ self._esmc = esmc
671
+
672
+ @classmethod
673
+ def from_pretrained(
674
+ cls, pretrained_model_name_or_path, *args, load_esmc: bool = True, **kwargs
675
+ ):
676
+ if cls is ESMFold2Model and "config" not in kwargs:
677
+ config = ESMFold2Config.from_pretrained(
678
+ pretrained_model_name_or_path, **kwargs
679
+ )
680
+ if config.type == "experimental":
681
+ experimental_module = importlib.import_module(
682
+ f"{__package__}.modeling_esmfold2_experimental"
683
+ )
684
+ return experimental_module.ESMFold2ExperimentalModel.from_pretrained(
685
+ pretrained_model_name_or_path,
686
+ *args,
687
+ config=config,
688
+ load_esmc=load_esmc,
689
+ **kwargs,
690
+ )
691
+ kwargs["config"] = config
692
+ # Pop the precision knob before forwarding to the HF loader.
693
+ esmc_precision = kwargs.pop("esmc_precision", "bf16")
694
+ model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
695
+ if load_esmc:
696
+ model.load_esmc(model.config.esmc_id, precision=esmc_precision)
697
+ return model
698
+
699
+ def set_kernel_backend(self, backend: str | None) -> None:
700
+ """Select kernel backend.
701
+
702
+ Args:
703
+ backend: ``None`` (reference path), ``"fused"`` (vendored Triton
704
+ kernels), or ``"cuequivariance"`` (cuequivariance kernels
705
+ where applicable; vanilla python fallback otherwise).
706
+ """
707
+ self.folding_trunk.set_kernel_backend(backend)
708
+ if self.lm_encoder is not None:
709
+ self.lm_encoder.set_kernel_backend(backend)
710
+ self.parcae_coda.set_kernel_backend(backend)
711
+ self.confidence_head.set_kernel_backend(backend)
712
+ self.structure_head.set_kernel_backend(backend)
713
+
714
+ def apply_torch_compile(
715
+ self, mode: str = "fixed_seqlen", dynamic: bool | None = None
716
+ ) -> None:
717
+ """Compile L²-heavy blocks. ``mode='fixed_seqlen'`` recompiles per L; ``'dynamic_seqlen'`` compiles once.
718
+
719
+ Does NOT stack with our Triton kernels — call ``set_kernel_backend(None)``
720
+ before compiling.
721
+ """
722
+ import torch._dynamo
723
+
724
+ torch._dynamo.config.cache_size_limit = 512 # type: ignore[attr-defined]
725
+ torch._dynamo.config.accumulated_cache_size_limit = 512 # type: ignore[attr-defined]
726
+ # capture_scalar_outputs avoids graph breaks at .item() in atom-attention path.
727
+ torch._dynamo.config.capture_scalar_outputs = True # type: ignore[attr-defined]
728
+
729
+ if dynamic is None:
730
+ dynamic = mode == "dynamic_seqlen"
731
+ kwargs: dict = {"dynamic": dynamic}
732
+
733
+ from .modeling_esmfold2_common import (
734
+ DiffusionModule,
735
+ DiffusionTransformer,
736
+ PairUpdateBlock,
737
+ )
738
+
739
+ compile_targets = (
740
+ PairUpdateBlock,
741
+ DiffusionTransformer,
742
+ DiffusionModule,
743
+ MSAEncoderBlock,
744
+ )
745
+
746
+ def _maybe_compile(module: nn.Module) -> None:
747
+ if isinstance(module, compile_targets):
748
+ module.forward = torch.compile(module.forward, **kwargs) # type: ignore[assignment]
749
+
750
+ self.apply(_maybe_compile)
751
+
752
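+     def set_chunk_size(self, chunk_size: int | None) -> None:
+         """Set the token-axis chunk size used by the L² ops (triangle / OPM / pair transition).
+
+         Propagates to the folding trunk, the LM encoder (if enabled), the coda,
+         the confidence head, and the MSA encoder (if enabled). The structure
+         head is not touched here. See the class docstring for guidance on
+         choosing a value, including ``None``.
+         """
+         self.folding_trunk.set_chunk_size(chunk_size)
+         if self.lm_encoder is not None:
+             self.lm_encoder.set_chunk_size(chunk_size)
+         self.parcae_coda.set_chunk_size(chunk_size)
+         self.confidence_head.set_chunk_size(chunk_size)
+         if self.msa_encoder is not None:
+             self.msa_encoder.set_chunk_size(chunk_size)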
760
+
761
+ def _compute_lm_hidden_states(
762
+ self,
763
+ input_ids: Tensor,
764
+ asym_id: Tensor,
765
+ residue_index: Tensor,
766
+ mol_type: Tensor,
767
+ tok_mask: Tensor,
768
+ ) -> Tensor:
769
+ assert self._esmc is not None
770
+ # fp8 TE kernels require prod(shape[:-1]) % 8 == 0.
771
+ pad_to = 8 if self._esmc_fp8 else None
772
+ with _lm_precision_context(self._esmc_fp8):
773
+ return compute_lm_hidden_states(
774
+ self._esmc,
775
+ input_ids,
776
+ asym_id,
777
+ residue_index,
778
+ mol_type,
779
+ tok_mask,
780
+ pad_to_multiple=pad_to,
781
+ )
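
The `pad_to_multiple=8` argument follows from the comment above: fp8 TE kernels need `prod(shape[:-1]) % 8 == 0`, so rounding the token axis up to a multiple of 8 is sufficient for any batch size. How `compute_lm_hidden_states` actually pads (which axis, which pad token) is not shown in this diff; the helper below is a hypothetical sketch of the rounding arithmetic only.

```python
from __future__ import annotations


def padded_length(n_tokens: int, multiple: int | None) -> int:
    """Round n_tokens up to the next multiple; None means no padding."""
    if multiple is None:
        return n_tokens
    # Ceiling division written with floor division on the negated value.
    return -(-n_tokens // multiple) * multiple


def needs_padding(batch: int, n_tokens: int, multiple: int = 8) -> bool:
    # Mirrors the stated kernel constraint on the flattened leading dims.
    return (batch * n_tokens) % multiple != 0
```

For example, a sequence of 1021 tokens would be padded to 1024, while 1024 is left unchanged.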
782
+
783
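+     def _discretized_dynamics(self) -> tuple[Tensor, Tensor]:
+         # Per-channel step size delta > 0 (softplus keeps it positive).
+         delta = F.softplus(self.parcae_log_delta)
+         # Decay a = exp(-delta * exp(log_a)) lies in (0, 1) for every channel.
+         a = torch.exp(-delta * torch.exp(self.parcae_log_a))
+         # Row i of the continuous input matrix is scaled by delta[i].
+         b = delta[:, None] * self.parcae_b_cont
+         return a, b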
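
The discretization in `_discretized_dynamics` can be checked by hand at initialization. This is a scalar sketch in plain Python, assuming `_inverse_softplus(y)` (defined outside this diff) is the usual `log(exp(y) - 1)`; the function names below are illustrative and not part of the module.

```python
import math


def softplus(x: float) -> float:
    # Scalar analogue of F.softplus: log(1 + exp(x)).
    return math.log1p(math.exp(x))


def inverse_softplus(y: float) -> float:
    # Assumed definition: inverse of softplus for y > 0.
    return math.log(math.expm1(y))


def discretize(log_delta: float, log_a: float, b_cont: float) -> tuple[float, float]:
    delta = softplus(log_delta)
    a = math.exp(-delta * math.exp(log_a))  # decay factor
    b = delta * b_cont                      # scaled input weight
    return a, b


# Values taken from __init__: decay target sqrt(1/5), log_a = 0, B_cont = identity.
decay_init = math.sqrt(1.0 / 5.0)
delta_init = -math.log(decay_init)
a0, b0 = discretize(inverse_softplus(delta_init), log_a=0.0, b_cont=1.0)
```

Under that assumption the initial decay `a0` equals `sqrt(1/5)` (about 0.447) and the initial input scale `b0` equals `0.5 * ln(5)` (about 0.805) on the diagonal.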
788
+
789
+ def _init_pair_state(self, ref: Tensor) -> Tensor:
790
+ std = math.sqrt(2.0 / (5.0 * ref.shape[-1]))
791
+ state = torch.empty_like(ref, dtype=torch.float32)
792
+ nn.init.trunc_normal_(state, mean=0.0, std=std, a=-3 * std, b=3 * std)
793
+ return state.to(dtype=ref.dtype)
794
+
795
+ def _run_one_loop(
796
+ self,
797
+ z: Tensor,
798
+ z_init: Tensor,
799
+ lm_z: Tensor | None,
800
+ _msa_kwargs: dict | None,
801
+ pair_mask: Tensor,
802
+ a: Tensor,
803
+ b_mat: Tensor,
804
+ total_steps: int,
805
+ ) -> Tensor:
806
+ # Helper method (not inline) so per-iter locals free on return —
807
+ # otherwise leaks ~2 GB L²×c_z into distogram/sample scope.
808
+ # training=True forces dropout under eval(), matching the per-loop
809
+ # dropout strategy used at train time.
810
+ lm_cfg = self.config.lm_encoder
811
+ _per_loop_lm_dropout = (
812
+ lm_z is not None
813
+ and getattr(lm_cfg, "per_loop_lm_dropout", False)
814
+ and getattr(lm_cfg, "lm_dropout", 0.0) > 0.0
815
+ )
816
+ _lm_dropout_p = getattr(lm_cfg, "lm_dropout", 0.0)
817
+
818
+ for _ in range(total_steps):
819
+ if _per_loop_lm_dropout:
820
+ assert lm_z is not None # narrowed by _per_loop_lm_dropout
821
+ lm_z_i: Tensor | None = F.dropout(lm_z, p=_lm_dropout_p, training=True)
822
+ else:
823
+ lm_z_i = lm_z
824
+
825
+ refined_lm_z: Tensor | None = None
826
+ if lm_z_i is not None and self.lm_encoder is not None:
827
+ refined_lm_z = self.lm_encoder(
828
+ lm_z_i.to(z_init.dtype), pair_attention_mask=pair_mask
829
+ )
830
+
831
+ z_inject_pair = z_init
832
+ if lm_z_i is not None and self.lm_encoder is None:
833
+ z_inject_pair = z_inject_pair + lm_z_i.to(z_inject_pair.dtype)
834
+
835
+ if self.msa_encoder is not None and _msa_kwargs is not None:
836
+ msa_pair = self.msa_encoder(x_pair=z_inject_pair, **_msa_kwargs).to(
837
+ z_inject_pair.dtype
838
+ )
839
+ z_inject_pair = (
840
+ msa_pair
841
+ if self.config.msa_encoder_overwrite
842
+ else (z_inject_pair + msa_pair)
843
+ )
844
+
845
+ if refined_lm_z is not None:
846
+ z_inject_pair = z_inject_pair + refined_lm_z.to(z_inject_pair.dtype)
847
+
848
+ injected_pair = self.parcae_input_norm(z_inject_pair)
849
+ z = a * z + F.linear(injected_pair.to(z.dtype), b_mat)
850
+ z = self.folding_trunk(z, pair_attention_mask=pair_mask)
851
+
852
+ return z
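
The core of `_run_one_loop` is the update `z = a * z + F.linear(injected, b_mat)` followed by the folding trunk. A scalar sketch of the recurrence alone, with the trunk deliberately omitted (an assumption made purely for illustration) and a constant injected input `u`:

```python
def run_recurrence(z0: float, u: float, a: float, b: float, n_loops: int) -> float:
    # Same step count as forward(): at least one step, otherwise n_loops + 1.
    total_steps = max(1, n_loops + 1)
    z = z0
    for _ in range(total_steps):
        z = a * z + b * u  # decay the state, then inject the input
    return z


def fixed_point(u: float, a: float, b: float) -> float:
    # For 0 < a < 1 the trunk-free recurrence converges to b * u / (1 - a).
    return b * u / (1.0 - a)
```

With `a = b = 0.5` and `u = 1`, the state moves 0.5, 0.75, 0.875, 0.9375 toward the fixed point 1.0, so each extra loop halves the remaining gap in this simplified setting.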
853
+
854
+ @torch.inference_mode()
855
+ def forward(
856
+ self,
857
+ token_index: Tensor,
858
+ residue_index: Tensor,
859
+ asym_id: Tensor,
860
+ sym_id: Tensor,
861
+ entity_id: Tensor,
862
+ mol_type: Tensor,
863
+ res_type: Tensor,
864
+ token_bonds: Tensor,
865
+ token_attention_mask: Tensor,
866
+ ref_pos: Tensor,
867
+ ref_element: Tensor,
868
+ ref_charge: Tensor,
869
+ ref_atom_name_chars: Tensor,
870
+ ref_space_uid: Tensor,
871
+ atom_attention_mask: Tensor,
872
+ atom_to_token: Tensor,
873
+ distogram_atom_idx: Tensor,
874
+ deletion_mean: Tensor | None = None,
875
+ msa: Tensor | None = None,
876
+ has_deletion: Tensor | None = None,
877
+ deletion_value: Tensor | None = None,
878
+ msa_attention_mask: Tensor | None = None,
879
+ input_ids: Tensor | None = None,
880
+ lm_hidden_states: Tensor | None = None,
881
+ num_loops: int | None = None,
882
+ num_diffusion_samples: int | None = None,
883
+ num_sampling_steps: int | None = None,
884
+ **kwargs,
885
+ ) -> dict[str, Tensor]:
886
+ tok_mask = token_attention_mask
887
+ atm_mask = atom_attention_mask
888
+ disto_idx = distogram_atom_idx
889
+
890
+ n_loops: int = num_loops if num_loops is not None else self.config.num_loops
891
+ n_samples: int = (
892
+ num_diffusion_samples
893
+ if num_diffusion_samples is not None
894
+ else self.config.num_diffusion_samples
895
+ )
896
+ total_steps = max(1, n_loops + 1)
897
+
898
+ if res_type.dim() == 2:
899
+ res_type_oh = F.one_hot(res_type.long(), num_classes=NUM_RES_TYPES).float()
900
+ res_type_oh = res_type_oh * tok_mask.unsqueeze(-1).float()
901
+ else:
902
+ res_type_oh = res_type.float()
903
+
904
+ if msa is not None:
905
+ msa_oh_profile = F.one_hot(msa.long(), num_classes=NUM_RES_TYPES).float()
906
+ if msa_attention_mask is not None:
907
+ mask_f = msa_attention_mask.float().unsqueeze(-1)
908
+ msa_oh_profile = msa_oh_profile * mask_f
909
+ valid_seq_count = msa_attention_mask.float().sum(dim=1).clamp(min=1)
910
+ profile = msa_oh_profile.sum(dim=1) / valid_seq_count.unsqueeze(-1)
911
+ else:
912
+ profile = msa_oh_profile.mean(dim=1)
913
+ else:
914
+ profile = res_type_oh
915
+
916
+ if deletion_mean is None:
917
+ deletion_mean = torch.zeros(
918
+ res_type.shape[0], res_type.shape[1], device=res_type.device
919
+ )
920
+
921
+ ref_element_oh = F.one_hot(
922
+ ref_element.long(), num_classes=MAX_ATOMIC_NUMBER
923
+ ).float()
924
+ ref_atom_name_chars_oh = F.one_hot(
925
+ ref_atom_name_chars.long(), num_classes=CHAR_VOCAB_SIZE
926
+ ).float()
927
+ # Bias-free downstream Linears require zeroed padding.
928
+ atm_mask_f = atm_mask.float()
929
+ ref_element_oh = ref_element_oh * atm_mask_f.unsqueeze(-1)
930
+ ref_atom_name_chars_oh = ref_atom_name_chars_oh * atm_mask_f.unsqueeze(
931
+ -1
932
+ ).unsqueeze(-1)
933
+ atom_to_token = atom_to_token * atm_mask.long()
934
+
935
+ use_amp = ref_pos.device.type == "cuda"
936
+ with torch.amp.autocast("cuda", enabled=use_amp, dtype=torch.bfloat16):
937
+ x_inputs = self.inputs_embedder(
938
+ aatype=res_type_oh,
939
+ profile=profile.float(),
940
+ deletion_mean=deletion_mean.float(),
941
+ ref_pos=ref_pos,
942
+ atom_attention_mask=atm_mask,
943
+ ref_space_uid=ref_space_uid,
944
+ ref_charge=ref_charge,
945
+ ref_element=ref_element_oh,
946
+ ref_atom_name_chars=ref_atom_name_chars_oh,
947
+ atom_to_token=atom_to_token,
948
+ )
949
+
950
+ z_init = self.z_init_1(x_inputs).unsqueeze(2) + self.z_init_2(
951
+ x_inputs
952
+ ).unsqueeze(1)
953
+
954
+ relative_position_encoding = self.rel_pos(
955
+ residue_index=residue_index,
956
+ asym_id=asym_id,
957
+ sym_id=sym_id,
958
+ entity_id=entity_id,
959
+ token_index=token_index,
960
+ )
961
+ token_bonds_encoding = self.token_bonds(token_bonds.float())
962
+ z_init = z_init + relative_position_encoding + token_bonds_encoding
963
+
964
+ if (
965
+ lm_hidden_states is None
966
+ and input_ids is not None
967
+ and self._esmc is not None
968
+ ):
969
+ lm_hidden_states = self._compute_lm_hidden_states(
970
+ input_ids, asym_id, residue_index, mol_type, tok_mask
971
+ )
972
+ lm_z: Tensor | None = None
973
+ if lm_hidden_states is not None:
974
+ lm_z = self.language_model(lm_hidden_states.detach())
975
+ del lm_hidden_states
976
+
977
+ pair_mask = tok_mask[:, :, None].float() * tok_mask[:, None, :].float()
978
+
979
+ z = self._init_pair_state(z_init)
980
+
981
+ a, b = self._discretized_dynamics()
982
+ a = a.view(1, 1, 1, -1).to(device=z.device, dtype=z.dtype)
983
+ b_mat = b.to(device=z.device, dtype=z.dtype)
984
+
985
+ _msa_kwargs: dict | None = None
986
+ if self.msa_encoder is not None and msa is not None:
987
+ B_msa, M, L_msa = msa.shape
988
+ msa_oh = F.one_hot(
989
+ msa.permute(0, 2, 1).long(), num_classes=NUM_RES_TYPES
990
+ ).float()
991
+ msa_attn = (
992
+ msa_attention_mask.permute(0, 2, 1).float()
993
+ if msa_attention_mask is not None
994
+ else tok_mask[:, :, None].expand(-1, -1, M).float()
995
+ )
996
+ # Bias-free MSAEncoder.embed requires zeroed padding.
997
+ msa_oh = msa_oh * msa_attn.unsqueeze(-1)
998
+ hd = (
999
+ has_deletion.permute(0, 2, 1).float()
1000
+ if has_deletion is not None
1001
+ else torch.zeros(B_msa, L_msa, M, device=msa.device)
1002
+ )
1003
+ dv = (
1004
+ deletion_value.permute(0, 2, 1).float()
1005
+ if deletion_value is not None
1006
+ else torch.zeros(B_msa, L_msa, M, device=msa.device)
1007
+ )
1008
+ _msa_kwargs = dict(
1009
+ x_inputs=x_inputs,
1010
+ msa_oh=msa_oh,
1011
+ has_deletion=hd,
1012
+ deletion_value=dv,
1013
+ msa_attention_mask=msa_attn,
1014
+ )
1015
+
1016
+ # Method call (not inline loop) frees per-iter L²×c_z locals.
1017
+ z = self._run_one_loop(
1018
+ z=z,
1019
+ z_init=z_init,
1020
+ lm_z=lm_z,
1021
+ _msa_kwargs=_msa_kwargs,
1022
+ pair_mask=pair_mask,
1023
+ a=a,
1024
+ b_mat=b_mat,
1025
+ total_steps=total_steps,
1026
+ )
1027
+ del z_init, lm_z, _msa_kwargs, a, b_mat
1028
+
1029
+ z = self.parcae_readout(z)
1030
+ z = self.parcae_coda(z, pair_attention_mask=pair_mask)
1031
+
1032
+ z = z.float()
1033
+ distogram_logits = self.distogram_head(z + z.transpose(-2, -3))
1034
+
1035
+ structure_output = self.structure_head.sample(
1036
+ z_trunk=z,
1037
+ s_inputs=x_inputs,
1038
+ s_trunk=None,
1039
+ relative_position_encoding=relative_position_encoding,
1040
+ ref_pos=ref_pos,
1041
+ ref_charge=ref_charge,
1042
+ ref_mask=atm_mask,
1043
+ ref_element=ref_element_oh,
1044
+ ref_atom_name_chars=ref_atom_name_chars_oh,
1045
+ ref_space_uid=ref_space_uid,
1046
+ tok_idx=atom_to_token,
1047
+ asym_id=asym_id,
1048
+ residue_index=residue_index,
1049
+ entity_id=entity_id,
1050
+ token_index=token_index,
1051
+ sym_id=sym_id,
1052
+ token_attention_mask=tok_mask,
1053
+ num_diffusion_samples=n_samples,
1054
+ num_sampling_steps=num_sampling_steps,
1055
+ return_atom_repr=False,
1056
+ denoising_early_exit_rmsd=None,
1057
+ )
1058
+
1059
+ sample_coords = structure_output["sample_atom_coords"]
1060
+ assert sample_coords is not None
1061
+ output: dict[str, Tensor] = {"distogram_logits": distogram_logits}
1062
+ output["sample_atom_coords"] = sample_coords
1063
+
1064
+ confidence_output = self.confidence_head(
1065
+ s_inputs=x_inputs.detach(),
1066
+ z=z.detach().float(),
1067
+ x_pred=sample_coords.detach(),
1068
+ distogram_atom_idx=disto_idx,
1069
+ token_attention_mask=tok_mask,
1070
+ atom_to_token=atom_to_token,
1071
+ atom_attention_mask=atm_mask,
1072
+ asym_id=asym_id,
1073
+ mol_type=mol_type,
1074
+ num_diffusion_samples=n_samples,
1075
+ relative_position_encoding=relative_position_encoding.detach(),
1076
+ token_bonds_encoding=token_bonds_encoding.detach(),
1077
+ )
1078
+ output.update(confidence_output)
1079
+ output["atom_pad_mask"] = (
1080
+ atm_mask.unsqueeze(0) if atm_mask.dim() == 1 else atm_mask
1081
+ )
1082
+ output["residue_index"] = residue_index
1083
+ output["entity_id"] = entity_id
1084
+ return output
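
The masked MSA profile computed near the top of `forward` is a per-position frequency table: one-hot rows are zeroed where the mask is 0, summed over the MSA axis, and divided by the number of valid rows (clamped to at least 1). A plain-Python sketch for a single example, using nested lists in place of tensors; `msa_profile` is a hypothetical helper, not a function in the module.

```python
def msa_profile(msa: list[list[int]], mask: list[list[int]], num_classes: int) -> list[list[float]]:
    """msa and mask are [M][L]; returns [L][num_classes] frequencies."""
    n_rows, length = len(msa), len(msa[0])
    profile = []
    for pos in range(length):
        counts = [0.0] * num_classes
        valid = 0
        for row in range(n_rows):
            if mask[row][pos]:
                counts[msa[row][pos]] += 1.0
                valid += 1
        denom = max(valid, 1)  # same role as clamp(min=1)
        profile.append([c / denom for c in counts])
    return profile


example = msa_profile(
    msa=[[0, 1], [0, 2], [1, 2]],
    mask=[[1, 1], [1, 1], [1, 0]],
    num_classes=3,
)
```

A fully masked column yields an all-zero row rather than a division by zero, which is what the clamp guards against.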
1085
+
1086
+ @torch.no_grad()
1087
+ def infer_protein(self, seq: str, **forward_kwargs) -> dict:
1088
+ from .protein_utils import prepare_protein_features
1089
+
1090
+ features = prepare_protein_features(seq)
1091
+ features = {k: v.to(self.device) for k, v in features.items()}
1092
+ return self(**features, **forward_kwargs)
1093
+
1094
+ @property
1095
+ def input_builder(self):
1096
+ if self._esmfold2_input_builder is None:
1097
+ from .esmfold2_processor import ESMFold2InputBuilder
1098
+
1099
+ self._esmfold2_input_builder = ESMFold2InputBuilder()
1100
+ return self._esmfold2_input_builder
1101
+
1102
+ @property
1103
+ def input_types(self):
1104
+ from . import esmfold2_types
1105
+
1106
+ return esmfold2_types
1107
+
1108
+ def prepare_structure_input(self, input, seed: int | None = None):
1109
+ return self.input_builder.prepare_input(input, seed=seed, device=self.device)
1110
+
1111
+ def fold(
1112
+ self,
1113
+ input,
1114
+ *,
1115
+ num_loops: int = 3,
1116
+ num_sampling_steps: int = 50,
1117
+ num_diffusion_samples: int = 1,
1118
+ seed: int | None = None,
1119
+ noise_scale: float | None = None,
1120
+ step_scale: float | None = None,
1121
+ max_inference_sigma: int | None = None,
1122
+ early_exit: bool = False,
1123
+ complex_id: str = "pred",
1124
+ ):
1125
+ return self.input_builder.fold(
1126
+ self,
1127
+ input,
1128
+ num_loops=num_loops,
1129
+ num_sampling_steps=num_sampling_steps,
1130
+ num_diffusion_samples=num_diffusion_samples,
1131
+ seed=seed,
1132
+ noise_scale=noise_scale,
1133
+ step_scale=step_scale,
1134
+ max_inference_sigma=max_inference_sigma,
1135
+ early_exit=early_exit,
1136
+ complex_id=complex_id,
1137
+ )
1138
+
1139
+ def fold_protein(
1140
+ self,
1141
+ sequence: str,
1142
+ *,
1143
+ chain_id: str = "A",
1144
+ num_loops: int = 3,
1145
+ num_sampling_steps: int = 50,
1146
+ num_diffusion_samples: int = 1,
1147
+ seed: int | None = None,
1148
+ complex_id: str = "pred",
1149
+ ):
1150
+ from .esmfold2_types import ProteinInput, StructurePredictionInput
1151
+
1152
+ input = StructurePredictionInput(
1153
+ sequences=[ProteinInput(id=chain_id, sequence=sequence)]
1154
+ )
1155
+ return self.fold(
1156
+ input,
1157
+ num_loops=num_loops,
1158
+ num_sampling_steps=num_sampling_steps,
1159
+ num_diffusion_samples=num_diffusion_samples,
1160
+ seed=seed,
1161
+ complex_id=complex_id,
1162
+ )
1163
+
1164
+ @staticmethod
1165
+ def result_to_cif(result) -> str:
1166
+ assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
1167
+ return result.complex.to_mmcif()
1168
+
1169
+ @staticmethod
1170
+ def result_to_pdb(result) -> str:
1171
+ assert not isinstance(result, list), "Pass one MolecularComplexResult at a time."
1172
+ return result.complex.to_protein_complex().to_pdb_string()
1173
+
1174
+ def save_as_cif(self, result, output_path: str | Path) -> None:
1175
+ Path(output_path).write_text(self.result_to_cif(result))
1176
+
1177
+ def save_as_pdb(self, result, output_path: str | Path) -> None:
1178
+ Path(output_path).write_text(self.result_to_pdb(result))
1179
+
1180
+ def infer_protein_as_cif(self, seq: str, **forward_kwargs) -> str:
1181
+ return self.result_to_cif(self.fold_protein(seq, **forward_kwargs))
1182
+
1183
+ def infer_protein_as_pdb(self, seq: str, **forward_kwargs) -> str:
1184
+ return self.result_to_pdb(self.fold_protein(seq, **forward_kwargs))
1185
+
1186
+
1187
+ class MSAEncoderBlock(nn.Module):
1188
+ """One MSA encoder block: OPM into pair, MSA pair-weighted averaging, triangle update."""
1189
+
1190
+ def __init__(
1191
+ self,
1192
+ d_msa: int,
1193
+ d_pair: int,
1194
+ d_hidden: int,
1195
+ n_heads_msa: int,
1196
+ msa_head_width: int,
1197
+ is_final_block: bool = False,
1198
+ ) -> None:
1199
+ super().__init__()
1200
+ self.is_final_block = is_final_block
1201
+ self.outer_product_mean = OuterProductMean(d_msa, d_hidden, d_pair)
1202
+ if not is_final_block:
1203
+ self.msa_pair_weighted_averaging = MSAPairWeightedAveraging(
1204
+ d_msa, d_pair, n_heads_msa, msa_head_width
1205
+ )
1206
+ self.msa_transition = PairTransition(d_msa, expansion_ratio=4)
1207
+ self.tri_mul_out = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=True)
1208
+ self.tri_mul_in = TriangleMultiplicativeUpdate(dim=d_pair, _outgoing=False)
1209
+ self.pair_transition = PairTransition(d_pair, expansion_ratio=4)
1210
+
1211
+ def set_chunk_size(self, chunk_size: int | None) -> None:
1212
+ self.outer_product_mean.set_chunk_size(chunk_size)
1213
+ self.tri_mul_out.set_chunk_size(chunk_size)
1214
+ self.tri_mul_in.set_chunk_size(chunk_size)
1215
+ if not self.is_final_block:
1216
+ self.msa_transition.set_chunk_size(chunk_size)
1217
+ self.pair_transition.set_chunk_size(chunk_size)
1218
+
1219
+ def forward(
1220
+ self,
1221
+ m: Tensor,
1222
+ pair: Tensor,
1223
+ msa_attention_mask: Tensor,
1224
+ pair_attention_mask: Tensor,
1225
+ ) -> tuple[Tensor, Tensor]:
1226
+ pair = pair + self.outer_product_mean(m, msa_attention_mask)
1227
+ if not self.is_final_block:
1228
+ m = m + self.msa_pair_weighted_averaging(m, pair, pair_attention_mask)
1229
+ m = m + self.msa_transition(m)
1230
+ pair = pair + self.tri_mul_out(pair, mask=pair_attention_mask)
1231
+ pair = pair + self.tri_mul_in(pair, mask=pair_attention_mask)
1232
+ pair = pair + self.pair_transition(pair)
1233
+ return m, pair
1234
+
1235
+
1236
+ class MSAEncoder(nn.Module):
1237
+ """Stack of [`MSAEncoderBlock`] layers that conditions the pair on an MSA."""
1238
+
1239
+ def __init__(
1240
+ self,
1241
+ d_msa: int,
1242
+ d_pair: int,
1243
+ d_inputs: int,
1244
+ d_hidden: int = 32,
1245
+ n_layers: int = 4,
1246
+ n_heads_msa: int = 8,
1247
+ msa_head_width: int = 16,
1248
+ ) -> None:
1249
+ super().__init__()
1250
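+ # Input width 35: MSA one-hot classes plus has_deletion and deletion_value (concatenated in forward).
+ self.embed = nn.Linear(35, d_msa, bias=False)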
1251
+ self.project_inputs = nn.Linear(d_inputs, d_msa, bias=False)
1252
+ self.blocks = nn.ModuleList(
1253
+ [
1254
+ MSAEncoderBlock(
1255
+ d_msa=d_msa,
1256
+ d_pair=d_pair,
1257
+ d_hidden=d_hidden,
1258
+ n_heads_msa=n_heads_msa,
1259
+ msa_head_width=msa_head_width,
1260
+ is_final_block=(i == n_layers - 1),
1261
+ )
1262
+ for i in range(n_layers)
1263
+ ]
1264
+ )
1265
+
1266
+ def set_chunk_size(self, chunk_size: int | None) -> None:
1267
+ for block in self.blocks:
1268
+ cast(MSAEncoderBlock, block).set_chunk_size(chunk_size)
1269
+
1270
+ def forward(
1271
+ self,
1272
+ x_pair: Tensor,
1273
+ x_inputs: Tensor,
1274
+ msa_oh: Tensor,
1275
+ has_deletion: Tensor,
1276
+ deletion_value: Tensor,
1277
+ msa_attention_mask: Tensor,
1278
+ ) -> Tensor:
1279
+ # All inputs are pre-transposed to [B, L, M, ...] before calling.
1280
+ m_feat = torch.cat(
1281
+ [msa_oh, has_deletion.unsqueeze(-1), deletion_value.unsqueeze(-1)], dim=-1
1282
+ )
1283
+ m = self.embed(m_feat) + self.project_inputs(x_inputs).unsqueeze(2)
1284
+ # Token mask is read from the first MSA row (index 0 on the M axis).
+ tok_mask = msa_attention_mask[:, :, 0].bool()
1285
+ pair_attention_mask = tok_mask.unsqueeze(2) & tok_mask.unsqueeze(1)
1286
+ for block in self.blocks:
1287
+ m, x_pair = block(m, x_pair, msa_attention_mask, pair_attention_mask)
1288
+ return x_pair
modeling_esmfold2_common.py ADDED
The diff for this file is too large to render. See raw diff
 
protein_utils.py ADDED
@@ -0,0 +1,488 @@
1
+ # coding=utf-8
2
+ # Copyright 2026 Biohub. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Self-contained protein featurization for ESMFold2 inference.
16
+
17
+ Lets ``ESMFold2ExperimentalModel.infer_protein_as_pdb`` fold a protein sequence
18
+ ESMFold-style without the ``esm`` companion package. The featurization
19
+ mirrors ``ESMFold2InputBuilder.prepare_input`` for the protein-only path —
20
+ ``test_prepare_protein_features.py`` enforces tensor-exact parity.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import math
26
+
27
+ import torch
28
+ from torch import Tensor
29
+
30
+ MOL_TYPE_PROTEIN = 0
31
+ PROTEIN_UNK_RES_TYPE = 22
32
+ MSA_GAP_TOKEN_ID = 1
33
+
34
+ PROTEIN_RESIDUE_TO_RES_TYPE: dict[str, int] = {
35
+ "ALA": 2,
36
+ "ARG": 3,
37
+ "ASN": 4,
38
+ "ASP": 5,
39
+ "CYS": 6,
40
+ "GLN": 7,
41
+ "GLU": 8,
42
+ "GLY": 9,
43
+ "HIS": 10,
44
+ "ILE": 11,
45
+ "LEU": 12,
46
+ "LYS": 13,
47
+ "MET": 14,
48
+ "PHE": 15,
49
+ "PRO": 16,
50
+ "SER": 17,
51
+ "THR": 18,
52
+ "TRP": 19,
53
+ "TYR": 20,
54
+ "VAL": 21,
55
+ }
56
+
57
+ PROTEIN_1TO3: dict[str, str] = {
58
+ "A": "ALA",
59
+ "R": "ARG",
60
+ "N": "ASN",
61
+ "D": "ASP",
62
+ "C": "CYS",
63
+ "Q": "GLN",
64
+ "E": "GLU",
65
+ "G": "GLY",
66
+ "H": "HIS",
67
+ "I": "ILE",
68
+ "L": "LEU",
69
+ "K": "LYS",
70
+ "M": "MET",
71
+ "F": "PHE",
72
+ "P": "PRO",
73
+ "S": "SER",
74
+ "T": "THR",
75
+ "W": "TRP",
76
+ "Y": "TYR",
77
+ "V": "VAL",
78
+ "X": "UNK",
79
+ }
80
+
81
+ ESM_PROTEIN_VOCAB: dict[str, int] = {
82
+ "L": 4,
83
+ "A": 5,
84
+ "G": 6,
85
+ "V": 7,
86
+ "S": 8,
87
+ "E": 9,
88
+ "R": 10,
89
+ "T": 11,
90
+ "I": 12,
91
+ "D": 13,
92
+ "P": 14,
93
+ "K": 15,
94
+ "Q": 16,
95
+ "N": 17,
96
+ "F": 18,
97
+ "Y": 19,
98
+ "M": 20,
99
+ "H": 21,
100
+ "W": 22,
101
+ "C": 23,
102
+ "X": 3,
103
+ }
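
The three tables above define two separate integer vocabularies for the same sequence: structure-side residue types and ESM language-model token ids. The sketch below shows how a one-letter sequence might be encoded into both, using a small subset of the entries copied from the tables. The fallback behaviour for unrecognized letters and the absence of special tokens (BOS/EOS) are assumptions, since the real featurization code is not shown here; `encode` is a hypothetical name.

```python
PROTEIN_UNK_RES_TYPE = 22

# Subsets of the tables in this file, kept short for illustration.
ONE_TO_THREE = {"A": "ALA", "G": "GLY", "W": "TRP", "X": "UNK"}
RES_TYPE = {"ALA": 2, "GLY": 9, "TRP": 19}
ESM_VOCAB = {"A": 5, "G": 6, "W": 22, "X": 3}


def encode(seq: str) -> tuple[list[int], list[int]]:
    res_types, lm_ids = [], []
    for aa in seq.upper():
        three = ONE_TO_THREE.get(aa, "UNK")             # assumed fallback
        res_types.append(RES_TYPE.get(three, PROTEIN_UNK_RES_TYPE))
        lm_ids.append(ESM_VOCAB.get(aa, ESM_VOCAB["X"]))  # assumed fallback
    return res_types, lm_ids


res_types, lm_ids = encode("GAWB")
```

Note that the value 22 appears in both vocabularies with different meanings: it is the unknown residue type on the structure side and tryptophan in the ESM vocabulary, so the two id spaces should never be mixed.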
104
+
105
+ # Heavy atoms per canonical residue, in training-time order.
106
+ PROTEIN_HEAVY_ATOMS: dict[str, list[str]] = {
107
+ "ALA": ["N", "CA", "C", "O", "CB"],
108
+ "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
109
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
110
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
111
+ "CYS": ["N", "CA", "C", "O", "CB", "SG"],
112
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
113
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
114
+ "GLY": ["N", "CA", "C", "O"],
115
+ "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
116
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
117
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
118
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
119
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
120
+ "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
121
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
122
+ "SER": ["N", "CA", "C", "O", "CB", "OG"],
123
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
124
+ "TRP": [
125
+ "N",
126
+ "CA",
127
+ "C",
128
+ "O",
129
+ "CB",
130
+ "CG",
131
+ "CD1",
132
+ "CD2",
133
+ "NE1",
134
+ "CE2",
135
+ "CE3",
136
+ "CZ2",
137
+ "CZ3",
138
+ "CH2",
139
+ ],
140
+ "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
141
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
142
+ "UNK": ["N", "CA", "C", "O"],
143
+ }
144
+
145
+ PROTEIN_REF_POS: dict[str, dict[str, tuple[float, float, float]]] = {
146
+ "ALA": {
147
+ "N": (-0.01003183238208294, -1.2073018550872803, -1.0555061101913452),
148
+ "CA": (-0.04190138354897499, 0.17447763681411743, -0.5729365348815918),
149
+ "C": (1.2127548456192017, 0.4737588167190552, 0.19521640241146088),
150
+ "O": (1.9390329122543335, 1.4484562873840332, -0.13759790360927582),
151
+ "CB": (-1.276943325996399, 0.4288230538368225, 0.29937705397605896),
152
+ },
153
+ "ARG": {
154
+ "N": (-2.0170421600341797, 0.6717798113822937, -1.1794233322143555),
155
+ "CA": (-2.0503084659576416, -0.5735036730766296, -0.4097220301628113),
156
+ "C": (-3.469440460205078, -1.0612813234329224, -0.2755832374095917),
157
+ "O": (-3.8218462467193604, -2.1369943618774414, -0.8294969797134399),
158
+ "CB": (-1.4193516969680786, -0.3735991418361664, 0.9852858781814575),
159
+ "CG": (0.11878877878189087, -0.3112654983997345, 0.963895857334137),
160
+ "CD": (0.6643245816230774, 1.0068185329437256, 0.3963329493999481),
161
+ "NE": (2.1090238094329834, 1.0977025032043457, 0.6120952367782593),
162
+ "CZ": (3.098905324935913, 0.3215920031070709, -0.09047172218561172),
163
+ "NH1": (4.461230278015137, 0.3844667971134186, 0.34141138195991516),
164
+ "NH2": (2.7856509685516357, -0.4166366159915924, -1.1148239374160767),
165
+ },
166
+ "ASN": {
167
+ "N": (-0.7595629096031189, 0.7503494620323181, 1.1369825601577759),
168
+ "CA": (-0.76087886095047, 0.23876343667507172, -0.23573364317417145),
169
+ "C": (-1.9211044311523438, -0.6982439160346985, -0.42196929454803467),
170
+ "O": (-2.677666187286377, -0.5753439664840698, -1.4223182201385498),
171
+ "CB": (0.5504899024963379, -0.5078350305557251, -0.5390339493751526),
172
+ "CG": (1.7250099182128906, 0.4264017939567566, -0.5778228640556335),
173
+ "OD1": (1.9470350742340088, 1.1086392402648926, -1.613560438156128),
174
+ "ND2": (2.57365345954895, 0.5730618834495544, 0.5608599781990051),
175
+ },
176
+ "ASP": {
177
+ "N": (-1.8452696800231934, -1.2169504165649414, 0.19437327980995178),
178
+ "CA": (-0.6379959583282471, -0.41974392533302307, 0.41681644320487976),
179
+ "C": (-0.9431572556495667, 1.0356197357177734, 0.18555717170238495),
180
+ "O": (-1.5183608531951904, 1.4045922756195068, -0.8739855885505676),
181
+ "CB": (0.48594576120376587, -0.8970447778701782, -0.5209363698959351),
182
+ "CG": (1.780342936515808, -0.19918935000896454, -0.2310730367898941),
183
+ "OD1": (2.5202910900115967, -0.6044584512710571, 0.7049641013145447),
184
+ "OD2": (2.1454880237579346, 0.9208861589431763, -0.9712985157966614),
185
+ },
186
+ "CYS": {
187
+ "N": (0.0469963513314724, 1.190075159072876, -1.1607273817062378),
188
+ "CA": (0.11344368755817413, -0.09400428831577301, -0.45952197909355164),
189
+ "C": (-1.2652032375335693, -0.6832379698753357, -0.3594406247138977),
190
+ "O": (-1.4631439447402954, -1.8851220607757568, -0.6826791763305664),
191
+ "CB": (0.6919880509376526, 0.09034398198127747, 0.952482283115387),
192
+ "SG": (2.4619927406311035, 0.5235707759857178, 0.9020372629165649),
193
+ },
194
+ "GLN": {
195
+ "N": (-2.370004653930664, -0.9637529850006104, -0.7942749261856079),
196
+ "CA": (-1.370002269744873, -0.6000258922576904, 0.2103111445903778),
197
+ "C": (-1.7545503377914429, 0.7091967463493347, 0.8433493971824646),
198
+ "O": (-1.8520662784576416, 0.7999289631843567, 2.0964975357055664),
199
+ "CB": (0.02040259726345539, -0.5004461407661438, -0.44764479994773865),
200
+ "CG": (1.1377512216567993, -0.28680720925331116, 0.582992434501648),
201
+ "CD": (2.4745187759399414, -0.24800164997577667, -0.09364881366491318),
202
+ "OE1": (3.1685523986816406, -1.2966246604919434, -0.1717153936624527),
203
+ "NE2": (2.947425603866577, 0.9601329565048218, -0.6888364553451538),
204
+ },
205
+ "GLU": {
206
+ "N": (-1.5850872993469238, -1.337684154510498, 0.9490851163864136),
207
+ "CA": (-1.0560977458953857, 0.027459044009447098, 1.0306966304779053),
208
+ "C": (-1.7741456031799316, 0.9664392471313477, 0.09259600937366486),
209
+ "O": (-1.9012441635131836, 2.181349992752075, 0.402479350566864),
210
+ "CB": (0.4706551432609558, 0.048803869634866714, 0.8114414811134338),
211
+ "CG": (0.9133604764938354, -0.4219329059123993, -0.5830985307693481),
212
+ "CD": (2.398822069168091, -0.3097084164619446, -0.7210537791252136),
213
+ "OE1": (3.1389315128326416, -1.274524450302124, -0.39029765129089355),
214
+ "OE2": (2.9647817611694336, 0.8781346082687378, -1.1732689142227173),
215
+ },
216
+ "GLY": {
217
+ "N": (-1.3942985534667969, -0.39875128865242004, -0.3370324671268463),
218
+ "CA": (-0.39974430203437805, 0.5488945245742798, 0.15242962539196014),
219
+ "C": (0.9440054893493652, -0.10314033925533295, 0.19859643280506134),
220
+ "O": (1.3352899551391602, -0.669218122959137, 1.2541258335113525),
221
+ },
222
+ "HIS": {
223
+ "N": (-1.4532867670059204, -1.0689626932144165, 0.881072461605072),
224
+ "CA": (-1.3396095037460327, 0.24797579646110535, 0.24960045516490936),
225
+ "C": (-2.675257921218872, 0.6571555733680725, -0.30441102385520935),
226
+ "O": (-3.1311378479003906, 1.8079776763916016, -0.06785715371370316),
227
+ "CB": (-0.3041955828666687, 0.21721023321151733, -0.8885309100151062),
228
+ "CG": (1.0887513160705566, 0.028941065073013306, -0.36419469118118286),
229
+ "ND1": (1.840459942817688, 1.0411773920059204, 0.29804590344429016),
230
+ "CD2": (1.780855417251587, -1.1011489629745483, -0.3814258575439453),
231
+ "CE1": (2.9566943645477295, 0.4924798905849457, 0.6477115750312805),
232
+ "NE2": (3.0280203819274902, -0.8751969337463379, 0.26084381341934204),
233
+ },
234
+ "ILE": {
235
+ "N": (-0.7167549729347229, -1.5426139831542969, -0.9983330368995667),
236
+ "CA": (-1.0636085271835327, -0.35169270634651184, -0.21393552422523499),
237
+ "C": (-1.3896740674972534, 0.8142145276069641, -1.1164065599441528),
238
+ "O": (-1.2377792596817017, 0.7302915453910828, -2.3656840324401855),
239
+ "CB": (0.061667006462812424, 0.01599610224366188, 0.8057394623756409),
240
+ "CG1": (1.502519965171814, -0.08899776637554169, 0.24154816567897797),
241
+ "CG2": (-0.053174979984760284, -0.8521055579185486, 2.0702083110809326),
242
+ "CD1": (1.7929610013961792, 0.899773120880127, -0.8863027691841125),
243
+ },
244
+ "LEU": {
245
+ "N": (1.9657520055770874, -1.9763224124908447, -0.18391533195972443),
246
+ "CA": (1.3077669143676758, -0.6677430868148804, -0.19492436945438385),
247
+ "C": (1.9905058145523071, 0.24182087182998657, 0.7879968285560608),
248
+ "O": (2.06896710395813, -0.07880014181137085, 2.0048046112060547),
249
+ "CB": (-0.20306941866874695, -0.8093230128288269, 0.11243502795696259),
250
+ "CG": (-0.9916267395019531, 0.5234957337379456, 0.06723011285066605),
251
+ "CD1": (-2.4228057861328125, 0.29949337244033813, 0.573042094707489),
252
+ "CD2": (-1.0282856225967407, 1.1250264644622803, -1.346014380455017),
253
+ },
254
+ "LYS": {
255
+ "N": (2.4221372604370117, -0.6473312377929688, 0.6370573043823242),
256
+ "CA": (2.0314927101135254, 0.2786507308483124, -0.4298512041568756),
257
+ "C": (2.7168593406677246, 1.595757246017456, -0.20924785733222961),
258
+ "O": (3.397681713104248, 2.116427421569824, -1.1332510709762573),
259
+ "CB": (0.5018402934074402, 0.4873858690261841, -0.49062973260879517),
260
+ "CG": (-0.25062066316604614, -0.7894009947776794, -0.9055535793304443),
261
+ "CD": (-1.769762635231018, -0.5552700161933899, -1.040329933166504),
262
+ "CE": (-2.576533555984497, -1.0221366882324219, 0.18493641912937164),
263
+ "NZ": (-2.269151210784912, -0.24293844401836395, 1.3849012851715088),
264
+ },
265
+ "MET": {
266
+ "N": (1.8903918266296387, -1.5252995491027832, -0.42638593912124634),
267
+ "CA": (1.2630571126937866, -0.24417810142040253, -0.7626462578773499),
268
+ "C": (2.30391001701355, 0.8367712497711182, -0.7254616618156433),
269
+ "O": (2.465414524078369, 1.5928632020950317, -1.7207728624343872),
270
+ "CB": (0.10567972809076309, 0.10861825942993164, 0.19741646945476532),
271
+ "CG": (-1.0658042430877686, -0.8736631274223328, 0.08811883628368378),
272
+ "SD": (-2.4557132720947266, -0.3332225978374481, 1.1461700201034546),
273
+ "CE": (-3.265165090560913, 0.7033554911613464, -0.11588376015424728),
274
+ },
275
+ "PHE": {
276
+ "N": (-2.8484435081481934, -1.525790810585022, 0.01789816841483116),
277
+ "CA": (-1.591969609260559, -0.8545162677764893, 0.35214468836784363),
278
+ "C": (-1.8900631666183472, 0.45833414793014526, 1.0232222080230713),
279
+ "O": (-1.3424992561340332, 0.74432373046875, 2.121629476547241),
280
+ "CB": (-0.760358452796936, -0.6342853307723999, -0.9257160425186157),
281
+ "CG": (0.604112982749939, -0.07200468331575394, -0.6148118376731873),
282
+ "CD1": (0.8468314409255981, 1.2480632066726685, -0.7146694660186768),
283
+ "CD2": (1.6827683448791504, -0.9758077263832092, -0.1423054188489914),
284
+ "CE1": (2.1801748275756836, 1.7875733375549316, -0.3744623064994812),
285
+ "CE2": (2.888307809829712, -0.48277512192726135, 0.16804970800876617),
286
+ "CZ": (3.149812936782837, 0.9656873941421509, 0.04440271109342575),
287
+ },
288
+ "PRO": {
289
+ "N": (-0.836250364780426, -0.9899801015853882, 0.5561304688453674),
290
+ "CA": (0.32722190022468567, -0.6164458394050598, -0.25072571635246277),
291
+ "C": (1.6121541261672974, -1.1711241006851196, 0.31082412600517273),
292
+ "O": (1.6127740144729614, -2.2771971225738525, 0.9156193733215332),
293
+ "CB": (0.3248198926448822, 0.9028244018554688, -0.33368146419525146),
294
+ "CG": (-1.1425083875656128, 1.2730128765106201, -0.2590600252151489),
295
+ "CD": (-1.8495968580245972, 0.026575811207294464, 0.2681289613246918),
296
+ },
297
+ "SER": {
298
+ "N": (0.674650251865387, 1.5018702745437622, -0.5367295145988464),
299
+ "CA": (0.00013792862591799349, 0.4966467022895813, 0.28510504961013794),
300
+ "C": (0.9941009879112244, -0.5374617576599121, 0.73505038022995),
301
+ "O": (1.0545241832733154, -0.8683545589447021, 1.9495396614074707),
302
+ "CB": (-1.1279288530349731, -0.1659376323223114, -0.5160963535308838),
303
+ "OG": (-1.8135979175567627, -1.085249662399292, 0.28947514295578003),
304
+ },
305
+ "THR": {
306
+ "N": (-1.325830340385437, -1.3728225231170654, 0.6882233023643494),
307
+ "CA": (-0.5433306097984314, -0.16364754736423492, 0.41697052121162415),
308
+ "C": (-1.294381856918335, 0.7077372074127197, -0.5549946427345276),
309
+ "O": (-1.6939635276794434, 0.23654410243034363, -1.6540418863296509),
310
+ "CB": (0.853203296661377, -0.5363803505897522, -0.14109353721141815),
311
+ "OG1": (1.5220820903778076, -1.379003643989563, 0.7635167837142944),
312
+ "CG2": (1.7225933074951172, 0.7054727077484131, -0.3651331067085266),
313
+ },
314
+ "TRP": {
315
+ "N": (3.686030864715576, 0.7599999904632568, 0.496155709028244),
316
+ "CA": (2.384092092514038, 0.09079249948263168, 0.5325262546539307),
317
+ "C": (2.1113572120666504, -0.6121063232421875, -0.7733646035194397),
318
+ "O": (1.796526312828064, -1.8323148488998413, -0.7775964140892029),
319
+ "CB": (1.281521201133728, 1.1139036417007446, 0.8559791445732117),
320
+ "CG": (-0.04292375594377518, 0.44645074009895325, 1.0942792892456055),
321
+ "CD1": (-0.42329534888267517, -0.15470874309539795, 2.2227554321289062),
322
+ "CD2": (-1.1023900508880615, 0.2158389836549759, 0.11529432237148285),
323
+ "NE1": (-1.7030320167541504, -0.7665823101997375, 2.0595016479492188),
324
+ "CE2": (-2.045644998550415, -0.4881173074245453, 0.710669219493866),
325
+ "CE3": (-1.2173502445220947, 0.6102271676063538, -1.300106406211853),
326
+ "CZ2": (-3.256009340286255, -0.9164394736289978, -0.00984987337142229),
327
+ "CZ3": (-2.315925121307373, 0.2306906282901764, -1.9776310920715332),
328
+ "CH2": (-3.3817875385284424, -0.5677337646484375, -1.3032053709030151),
329
+ },
330
+ "TYR": {
331
+ "N": (-1.7900604009628296, -0.8409399390220642, 1.3180142641067505),
332
+ "CA": (-1.913882851600647, 0.23552845418453217, 0.330669641494751),
333
+ "C": (-3.347280740737915, 0.3588399887084961, -0.09830684959888458),
334
+ "O": (-3.967811346054077, -0.6449354290962219, -0.5423302054405212),
335
+ "CB": (-1.0093992948532104, 0.0004731413209810853, -0.8981552124023438),
336
+ "CG": (0.4520410895347595, 0.021162061020731926, -0.5305932760238647),
337
+ "CD1": (1.0992432832717896, 1.1877919435501099, -0.3579142987728119),
338
+ "CD2": (1.1803174018859863, -1.253401279449463, -0.31122180819511414),
339
+ "CE1": (2.5253450870513916, 1.1990256309509277, 0.029804613441228867),
340
+ "CE2": (2.471151113510132, -1.240687608718872, 0.043534230440855026),
341
+ "CZ": (3.180687665939331, 0.04672492295503616, 0.2214856892824173),
342
+ "OH": (4.523719787597656, 0.0671030730009079, 0.5877485871315002),
343
+ },
344
+ "VAL": {
345
+ "N": (0.5987519025802612, -1.569443702697754, -0.7379124760627747),
346
+ "CA": (0.6014357209205627, -0.10503966361284256, -0.6336286664009094),
347
+ "C": (1.8391697406768799, 0.4067850410938263, 0.06351757049560547),
348
+ "O": (2.3952062129974365, -0.2666190266609192, 0.9731166958808899),
349
+ "CB": (-0.694736897945404, 0.4259096384048462, 0.03581475466489792),
350
+ "CG1": (-1.9276031255722046, 0.09515828639268875, -0.8172357082366943),
351
+ "CG2": (-0.8938426971435547, -0.08640842139720917, 1.472349762916565),
352
+ },
353
+ "UNK": {
354
+ "N": (0.0, 0.0, 0.0),
355
+ "CA": (0.0, 0.0, 0.0),
356
+ "C": (0.0, 0.0, 0.0),
357
+ "O": (0.0, 0.0, 0.0),
358
+ },
359
+ }
360
+
361
+ # Protonated nitrogens at physiological pH (matches CHARGED_ATOMS in the
362
+ # open-source constants for the protein subset).
363
+ PROTEIN_CHARGED_ATOMS: dict[tuple[str, str], int] = {
364
+ ("LYS", "NZ"): 1,
365
+ ("ARG", "NH2"): 1,
366
+ ("HIS", "ND1"): 1,
367
+ }
368
+
369
+ # Only the elements that appear in canonical protein heavy atoms.
370
+ _PROTEIN_ELEMENT_TO_ATOMIC_NUM: dict[str, int] = {"C": 6, "N": 7, "O": 8, "S": 16}
371
+
372
+
+ def _encode_atom_name(name: str) -> list[int]:
+     padded = name.ljust(4)[:4]
+     return [ord(c) - 32 if c != " " else 0 for c in padded]
+
+
+ def prepare_protein_features(sequence: str) -> dict[str, Tensor]:
+     """Featurize a single protein sequence for ESMFold2ExperimentalModel.forward.
+
+     Returns the same keys with the same dtypes/shapes as
+     ``ESMFold2InputBuilder.prepare_input(StructurePredictionInput(...))``
+     restricted to a single-chain protein with no MSA, modifications,
+     distogram conditioning, or covalent bonds. All tensors have a
+     leading batch dim of 1; the caller is responsible for moving them
+     to the model device.
+     """
+     if not sequence:
+         raise ValueError("sequence must be non-empty")
+
+     res_3letter = [PROTEIN_1TO3.get(c, "UNK") for c in sequence]
+     L = len(sequence)
+
+     token_atom_starts: list[int] = []
+     atom_records: list[tuple[int, str, str, int, tuple[float, float, float]]] = []
+     res_type_vals: list[int] = []
+     input_id_vals: list[int] = []
+     distogram_rep_atom_idx: list[int] = []
+
+     atom_cursor = 0
+     for t_idx, (letter, res_3) in enumerate(zip(sequence, res_3letter)):
+         atom_names = PROTEIN_HEAVY_ATOMS[res_3]
+         res_type = PROTEIN_RESIDUE_TO_RES_TYPE.get(res_3, PROTEIN_UNK_RES_TYPE)
+         input_id = ESM_PROTEIN_VOCAB.get(letter, ESM_PROTEIN_VOCAB["X"])
+
+         token_atom_starts.append(atom_cursor)
+         for name in atom_names:
+             charge = PROTEIN_CHARGED_ATOMS.get((res_3, name), 0)
+             element = name[0]  # protein heavy atoms are all single-letter C/N/O/S
+             ref_pos = PROTEIN_REF_POS[res_3][name]
+             atom_records.append((t_idx, name, element, charge, ref_pos))
+             atom_cursor += 1
+
+         rep_name = "CB" if "CB" in atom_names else "CA"
+         distogram_rep_atom_idx.append(
+             token_atom_starts[t_idx] + atom_names.index(rep_name)
+         )
+
+         res_type_vals.append(res_type)
+         input_id_vals.append(input_id)
+
+     n_real_atoms = len(atom_records)
+     n_atoms = math.ceil(n_real_atoms / 32) * 32 if n_real_atoms > 0 else 32
+
+     ref_pos = torch.zeros(n_atoms, 3, dtype=torch.float32)
+     ref_element = torch.zeros(n_atoms, dtype=torch.int64)
+     ref_charge = torch.zeros(n_atoms, dtype=torch.int8)
+     ref_atom_name_chars = torch.zeros(n_atoms, 4, dtype=torch.int64)
+     ref_space_uid = torch.zeros(n_atoms, dtype=torch.int64)
+     atom_attention_mask = torch.zeros(n_atoms, dtype=torch.bool)
+     atom_to_token = torch.zeros(n_atoms, dtype=torch.int64)
+
+     for i, (t_idx, name, element, charge, pos) in enumerate(atom_records):
+         ref_pos[i] = torch.tensor(pos, dtype=torch.float32)
+         ref_element[i] = _PROTEIN_ELEMENT_TO_ATOMIC_NUM[element]
+         ref_charge[i] = charge
+         ref_atom_name_chars[i] = torch.tensor(
+             _encode_atom_name(name), dtype=torch.int64
+         )
+         ref_space_uid[i] = t_idx
+         atom_attention_mask[i] = True
+         atom_to_token[i] = t_idx
+
+     token_index = torch.arange(L, dtype=torch.int64)
+     residue_index = torch.arange(L, dtype=torch.int64)
+     asym_id = torch.zeros(L, dtype=torch.int64)
+     sym_id = torch.zeros(L, dtype=torch.int64)
+     entity_id = torch.ones(L, dtype=torch.int64)
+     mol_type = torch.full((L,), MOL_TYPE_PROTEIN, dtype=torch.int64)
+     res_type = torch.tensor(res_type_vals, dtype=torch.int64)
+     input_ids = torch.tensor(input_id_vals, dtype=torch.int64)
+     token_bonds = torch.zeros(L, L, 1, dtype=torch.float32)
+     token_attention_mask = torch.ones(L, dtype=torch.bool)
+     distogram_atom_idx = torch.tensor(distogram_rep_atom_idx, dtype=torch.int64)
+
+     # Single-sequence MSA: depth 1, row 0 is the sequence itself.
+     msa = res_type.unsqueeze(0)
+     msa_attention_mask = torch.ones(1, L, dtype=torch.bool)
+     has_deletion = torch.zeros(1, L, dtype=torch.bool)
+     deletion_value = torch.zeros(1, L, dtype=torch.float32)
+     deletion_mean = torch.zeros(L, dtype=torch.float32)
+
+     features = {
+         "token_index": token_index,
+         "residue_index": residue_index,
+         "asym_id": asym_id,
+         "sym_id": sym_id,
+         "entity_id": entity_id,
+         "mol_type": mol_type,
+         "res_type": res_type,
+         "input_ids": input_ids,
+         "token_bonds": token_bonds,
+         "token_attention_mask": token_attention_mask,
+         "ref_pos": ref_pos,
+         "ref_element": ref_element,
+         "ref_charge": ref_charge,
+         "ref_atom_name_chars": ref_atom_name_chars,
+         "ref_space_uid": ref_space_uid,
+         "atom_attention_mask": atom_attention_mask,
+         "atom_to_token": atom_to_token,
+         "distogram_atom_idx": distogram_atom_idx,
+         "msa": msa,
+         "msa_attention_mask": msa_attention_mask,
+         "has_deletion": has_deletion,
+         "deletion_value": deletion_value,
+         "deletion_mean": deletion_mean,
+     }
+     return {k: v.unsqueeze(0) for k, v in features.items()}
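
The atom-level bookkeeping in `prepare_protein_features` can be hard to follow from the tensor code alone, so here is a minimal pure-Python sketch of three of its steps: atom-name encoding, per-token atom offsets with the distogram representative atom, and padding the atom count to a multiple of 32. The names `HEAVY_ATOMS`, `encode_atom_name`, `padded_atom_count`, and `atom_layout` are hypothetical and exist only for this illustration; the two-residue table is a hand-copied subset of `PROTEIN_HEAVY_ATOMS`, and nothing here is part of the uploaded module.

```python
import math

# Hypothetical two-residue subset of PROTEIN_HEAVY_ATOMS, copied for illustration.
HEAVY_ATOMS = {
    "GLY": ["N", "CA", "C", "O"],
    "ALA": ["N", "CA", "C", "O", "CB"],
}


def encode_atom_name(name: str) -> list[int]:
    """Pad or truncate to 4 characters, then map each character to ord(c) - 32."""
    padded = name.ljust(4)[:4]
    # A space is ASCII 32, so padding positions come out as 0.
    return [ord(c) - 32 for c in padded]


def padded_atom_count(n_real_atoms: int, multiple: int = 32) -> int:
    """Round the atom count up to the next multiple (at least one full block)."""
    if n_real_atoms <= 0:
        return multiple
    return math.ceil(n_real_atoms / multiple) * multiple


def atom_layout(residues: list[str]) -> tuple[list[int], list[int], int]:
    """Return per-token first-atom offsets, representative-atom indices, and atom total."""
    starts, rep_idx, cursor = [], [], 0
    for res in residues:
        names = HEAVY_ATOMS[res]
        starts.append(cursor)
        rep = "CB" if "CB" in names else "CA"  # glycine has no CB, so fall back to CA
        rep_idx.append(cursor + names.index(rep))
        cursor += len(names)
    return starts, rep_idx, cursor


starts, rep_idx, n_real = atom_layout(["GLY", "ALA", "GLY"])
print(starts)                              # [0, 4, 9]
print(rep_idx)                             # [1, 8, 10]
print(n_real, padded_atom_count(n_real))   # 13 32
print(encode_atom_name("CA"))              # [35, 33, 0, 0]
```

Because a space already maps to `ord(" ") - 32 == 0`, the explicit `if c != " " else 0` branch in `_encode_atom_name` produces the same values as the shorter form used in this sketch. The padded atom slots keep `atom_attention_mask` at `False`, which is how downstream code can tell them apart from real atoms.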