nikraf commited on
Commit
714cf46
·
verified ·
1 Parent(s): 9b113fb

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. README.md +65 -0
  2. config.json +116 -0
  3. model.safetensors +3 -0
  4. packaged_probe_model.py +215 -0
  5. protify/FastPLMs/__init__.py +0 -0
  6. protify/FastPLMs/boltz/scripts/eval/aggregate_evals.py +753 -0
  7. protify/FastPLMs/boltz/scripts/eval/physcialsim_metrics.py +304 -0
  8. protify/FastPLMs/boltz/scripts/eval/run_evals.py +167 -0
  9. protify/FastPLMs/boltz/scripts/process/ccd.py +295 -0
  10. protify/FastPLMs/boltz/scripts/process/cluster.py +111 -0
  11. protify/FastPLMs/boltz/scripts/process/mmcif.py +1123 -0
  12. protify/FastPLMs/boltz/scripts/process/msa.py +130 -0
  13. protify/FastPLMs/boltz/scripts/process/rcsb.py +359 -0
  14. protify/FastPLMs/boltz/scripts/train/train.py +241 -0
  15. protify/FastPLMs/boltz/src/boltz/__init__.py +7 -0
  16. protify/FastPLMs/boltz/src/boltz/data/__init__.py +0 -0
  17. protify/FastPLMs/boltz/src/boltz/data/const.py +1184 -0
  18. protify/FastPLMs/boltz/src/boltz/data/crop/__init__.py +0 -0
  19. protify/FastPLMs/boltz/src/boltz/data/crop/affinity.py +164 -0
  20. protify/FastPLMs/boltz/src/boltz/data/crop/boltz.py +296 -0
  21. protify/FastPLMs/boltz/src/boltz/data/crop/cropper.py +45 -0
  22. protify/FastPLMs/boltz/src/boltz/data/feature/__init__.py +0 -0
  23. protify/FastPLMs/boltz/src/boltz/data/feature/featurizer.py +1225 -0
  24. protify/FastPLMs/boltz/src/boltz/data/feature/featurizerv2.py +2354 -0
  25. protify/FastPLMs/boltz/src/boltz/data/feature/symmetry.py +602 -0
  26. protify/FastPLMs/boltz/src/boltz/data/filter/__init__.py +0 -0
  27. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/__init__.py +0 -0
  28. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/date.py +76 -0
  29. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/filter.py +24 -0
  30. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/max_residues.py +37 -0
  31. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/resolution.py +34 -0
  32. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/size.py +38 -0
  33. protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/subset.py +42 -0
  34. protify/FastPLMs/boltz/src/boltz/data/filter/static/__init__.py +0 -0
  35. protify/FastPLMs/boltz/src/boltz/data/filter/static/filter.py +26 -0
  36. protify/FastPLMs/boltz/src/boltz/data/filter/static/ligand.py +37 -0
  37. protify/FastPLMs/boltz/src/boltz/data/filter/static/polymer.py +299 -0
  38. protify/FastPLMs/boltz/src/boltz/data/module/__init__.py +0 -0
  39. protify/FastPLMs/boltz/src/boltz/data/module/inference.py +310 -0
  40. protify/FastPLMs/boltz/src/boltz/data/module/inferencev2.py +433 -0
  41. protify/FastPLMs/boltz/src/boltz/data/module/training.py +687 -0
  42. protify/FastPLMs/boltz/src/boltz/data/module/trainingv2.py +660 -0
  43. protify/FastPLMs/boltz/src/boltz/data/mol.py +900 -0
  44. protify/FastPLMs/boltz/src/boltz/data/msa/__init__.py +0 -0
  45. protify/FastPLMs/boltz/src/boltz/data/msa/mmseqs2.py +286 -0
  46. protify/FastPLMs/boltz/src/boltz/data/pad.py +84 -0
  47. protify/FastPLMs/boltz/src/boltz/data/parse/__init__.py +0 -0
  48. protify/FastPLMs/boltz/src/boltz/data/parse/a3m.py +134 -0
  49. protify/FastPLMs/boltz/src/boltz/data/parse/csv.py +100 -0
  50. protify/FastPLMs/boltz/src/boltz/data/parse/fasta.py +138 -0
README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # nikraf/OmniPath_2class_clustered-30_ESMC-600_2026-03-11-15-46_NQRV
7
+
8
+ Fine-tuned with Protify.
9
+
10
+ ## About Protify
11
+
12
+ Protify is an open source platform designed to simplify and democratize workflows for chemical language models. With Protify, deep learning models can be trained to predict chemical properties without requiring extensive coding knowledge or computational resources.
13
+
14
+ ### Why Protify?
15
+
16
+ - Benchmark multiple models efficiently.
17
+ - Flexible for all skill levels.
18
+ - Accessible computing with support for precomputed embeddings.
19
+ - Cost-effective workflows for training and evaluation.
20
+
21
+ ## Training Run
22
+
23
+ - `dataset`: OmniPath_2class_clustered-30
24
+ - `model`: ESMC-600
25
+ - `run_id`: 2026-03-11-15-46_NQRV
26
+ - `task_type`: singlelabel
27
+ - `num_runs`: 1
28
+
29
+ ## Dataset Statistics
30
+
31
+ - `train_size`: 102872
32
+ - `valid_size`: 18102
33
+ - `test_size`: 18074
34
+
35
+ ## Validation Metrics
36
+
37
+ - `epoch`: 5.000000
38
+ - `eval_accuracy`: 0.789750
39
+ - `eval_f1`: 0.789330
40
+ - `eval_loss`: 0.445219
41
+ - `eval_mcc`: 0.581780
42
+ - `eval_model_preparation_time`: 0.000300
43
+ - `eval_pr_auc`: 0.884610
44
+ - `eval_precision`: 0.792040
45
+ - `eval_recall`: 0.789750
46
+ - `eval_roc_auc`: 0.880010
47
+ - `eval_runtime`: 21.260300
48
+ - `eval_samples_per_second`: 851.444000
49
+ - `eval_steps_per_second`: 13.311000
50
+
51
+ ## Test Metrics
52
+
53
+ - `test_accuracy`: 0.779350
54
+ - `test_f1`: 0.778210
55
+ - `test_loss`: 0.455012
56
+ - `test_mcc`: 0.564560
57
+ - `test_model_preparation_time`: 0.000300
58
+ - `test_pr_auc`: 0.884200
59
+ - `test_precision`: 0.785240
60
+ - `test_recall`: 0.779350
61
+ - `test_roc_auc`: 0.874270
62
+ - `test_runtime`: 21.119900
63
+ - `test_samples_per_second`: 855.780000
64
+ - `test_steps_per_second`: 13.400000
65
+ - `training_time_seconds`: 1235.285100
config.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_token_ids": false,
3
+ "architectures": [
4
+ "PackagedProbeModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
8
+ "AutoModel": "packaged_probe_model.PackagedProbeModel"
9
+ },
10
+ "base_model_name": "ESMC-600",
11
+ "dtype": "float32",
12
+ "matrix_embed": true,
13
+ "model_type": "packaged_probe",
14
+ "pooling_types": [
15
+ "mean",
16
+ "var"
17
+ ],
18
+ "ppi": true,
19
+ "probe_config": {
20
+ "_name_or_path": "",
21
+ "add_cross_attention": false,
22
+ "add_token_ids": false,
23
+ "architectures": [
24
+ "TransformerForSequenceClassification"
25
+ ],
26
+ "bad_words_ids": null,
27
+ "begin_suppress_tokens": null,
28
+ "bos_token_id": null,
29
+ "chunk_size_feed_forward": 0,
30
+ "classifier_dropout": 0.2,
31
+ "classifier_size": 4096,
32
+ "cross_attention_hidden_size": null,
33
+ "decoder_start_token_id": null,
34
+ "diversity_penalty": 0.0,
35
+ "do_sample": false,
36
+ "dropout": 0.2,
37
+ "dtype": "float32",
38
+ "early_stopping": false,
39
+ "encoder_no_repeat_ngram_size": 0,
40
+ "eos_token_id": null,
41
+ "exponential_decay_length_penalty": null,
42
+ "finetuning_task": null,
43
+ "forced_bos_token_id": null,
44
+ "forced_eos_token_id": null,
45
+ "hidden_size": 512,
46
+ "id2label": {
47
+ "0": "LABEL_0",
48
+ "1": "LABEL_1"
49
+ },
50
+ "input_size": 1152,
51
+ "is_decoder": false,
52
+ "is_encoder_decoder": false,
53
+ "label2id": {
54
+ "LABEL_0": 0,
55
+ "LABEL_1": 1
56
+ },
57
+ "length_penalty": 1.0,
58
+ "lora": false,
59
+ "lora_alpha": 32.0,
60
+ "lora_dropout": 0.01,
61
+ "lora_r": 8,
62
+ "max_length": 20,
63
+ "min_length": 0,
64
+ "model_type": "probe",
65
+ "n_heads": 4,
66
+ "n_layers": 1,
67
+ "no_repeat_ngram_size": 0,
68
+ "num_beam_groups": 1,
69
+ "num_beams": 1,
70
+ "num_return_sequences": 1,
71
+ "output_attentions": false,
72
+ "output_hidden_states": false,
73
+ "output_scores": false,
74
+ "pad_token_id": null,
75
+ "pooling_types": [
76
+ "mean",
77
+ "cls"
78
+ ],
79
+ "pre_ln": true,
80
+ "prefix": null,
81
+ "probe_type": "transformer",
82
+ "problem_type": null,
83
+ "pruned_heads": {},
84
+ "remove_invalid_values": false,
85
+ "repetition_penalty": 1.0,
86
+ "return_dict": true,
87
+ "return_dict_in_generate": false,
88
+ "rotary": true,
89
+ "sep_token_id": null,
90
+ "sim_type": "dot",
91
+ "suppress_tokens": null,
92
+ "task_specific_params": null,
93
+ "task_type": "singlelabel",
94
+ "temperature": 1.0,
95
+ "tf_legacy_loss": false,
96
+ "tie_encoder_decoder": false,
97
+ "tie_word_embeddings": true,
98
+ "token_attention": false,
99
+ "tokenizer_class": null,
100
+ "tokenwise": false,
101
+ "top_k": 50,
102
+ "top_p": 1.0,
103
+ "torchscript": false,
104
+ "transformer_dropout": 0.1,
105
+ "transformer_hidden_size": 512,
106
+ "transformers_version": "4.57.6",
107
+ "typical_p": 1.0,
108
+ "use_bfloat16": false,
109
+ "use_bias": false
110
+ },
111
+ "probe_type": "transformer",
112
+ "sep_token_id": 2,
113
+ "task_type": "singlelabel",
114
+ "tokenwise": false,
115
+ "transformers_version": "4.57.6"
116
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec8ea16612d2975dad1abb1da8977591cbba6ff2b0566374755120e6e950bded
3
+ size 2331568712
packaged_probe_model.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from torch import nn
7
+ from transformers import AutoModel, PreTrainedModel, PretrainedConfig
8
+ from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
9
+
10
+
11
+ try:
12
+ from protify.base_models.supported_models import all_presets_with_paths
13
+ from protify.pooler import Pooler
14
+ from protify.probes.get_probe import rebuild_probe_from_saved_config
15
+ except ImportError:
16
+ current_dir = os.path.dirname(os.path.abspath(__file__))
17
+ candidate_paths = [
18
+ current_dir,
19
+ os.path.dirname(current_dir),
20
+ os.path.dirname(os.path.dirname(current_dir)),
21
+ os.path.join(current_dir, "src"),
22
+ ]
23
+ for candidate in candidate_paths:
24
+ if os.path.isdir(candidate) and candidate not in sys.path:
25
+ sys.path.insert(0, candidate)
26
+ from protify.base_models.supported_models import all_presets_with_paths
27
+ from protify.pooler import Pooler
28
+ from protify.probes.get_probe import rebuild_probe_from_saved_config
29
+
30
+
31
+ class PackagedProbeConfig(PretrainedConfig):
32
+ model_type = "packaged_probe"
33
+
34
+ def __init__(
35
+ self,
36
+ base_model_name: str = "",
37
+ probe_type: str = "linear",
38
+ probe_config: Optional[Dict[str, Any]] = None,
39
+ tokenwise: bool = False,
40
+ matrix_embed: bool = False,
41
+ pooling_types: Optional[list[str]] = None,
42
+ task_type: str = "singlelabel",
43
+ num_labels: int = 2,
44
+ ppi: bool = False,
45
+ add_token_ids: bool = False,
46
+ sep_token_id: Optional[int] = None,
47
+ **kwargs,
48
+ ):
49
+ super().__init__(**kwargs)
50
+ self.base_model_name = base_model_name
51
+ self.probe_type = probe_type
52
+ self.probe_config = {} if probe_config is None else probe_config
53
+ self.tokenwise = tokenwise
54
+ self.matrix_embed = matrix_embed
55
+ self.pooling_types = ["mean"] if pooling_types is None else pooling_types
56
+ self.task_type = task_type
57
+ self.num_labels = num_labels
58
+ self.ppi = ppi
59
+ self.add_token_ids = add_token_ids
60
+ self.sep_token_id = sep_token_id
61
+
62
+
63
+ class PackagedProbeModel(PreTrainedModel):
64
+ config_class = PackagedProbeConfig
65
+ base_model_prefix = "backbone"
66
+ all_tied_weights_keys = {}
67
+
68
+ def __init__(
69
+ self,
70
+ config: PackagedProbeConfig,
71
+ base_model: Optional[nn.Module] = None,
72
+ probe: Optional[nn.Module] = None,
73
+ ):
74
+ super().__init__(config)
75
+ self.config = config
76
+ self.backbone = self._load_base_model() if base_model is None else base_model
77
+ self.probe = self._load_probe() if probe is None else probe
78
+ self.pooler = Pooler(self.config.pooling_types)
79
+
80
+ def _load_base_model(self) -> nn.Module:
81
+ if self.config.base_model_name in all_presets_with_paths:
82
+ model_path = all_presets_with_paths[self.config.base_model_name]
83
+ else:
84
+ model_path = self.config.base_model_name
85
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
86
+ model.eval()
87
+ return model
88
+
89
+ def _load_probe(self) -> nn.Module:
90
+ return rebuild_probe_from_saved_config(
91
+ probe_type=self.config.probe_type,
92
+ tokenwise=self.config.tokenwise,
93
+ probe_config=self.config.probe_config,
94
+ )
95
+
96
+ @staticmethod
97
+ def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
98
+ if isinstance(backbone_output, tuple):
99
+ return backbone_output[0]
100
+ if hasattr(backbone_output, "last_hidden_state"):
101
+ return backbone_output.last_hidden_state
102
+ if isinstance(backbone_output, torch.Tensor):
103
+ return backbone_output
104
+ raise ValueError("Unsupported backbone output format for packaged probe model")
105
+
106
+ @staticmethod
107
+ def _extract_attentions(backbone_output: Any) -> Optional[torch.Tensor]:
108
+ if hasattr(backbone_output, "attentions"):
109
+ return backbone_output.attentions
110
+ return None
111
+
112
+ def _build_ppi_segment_masks(
113
+ self,
114
+ input_ids: torch.Tensor,
115
+ attention_mask: torch.Tensor,
116
+ token_type_ids: Optional[torch.Tensor],
117
+ ) -> tuple[torch.Tensor, torch.Tensor]:
118
+ if token_type_ids is not None and torch.any(token_type_ids == 1):
119
+ mask_a = ((token_type_ids == 0) & (attention_mask == 1)).long()
120
+ mask_b = ((token_type_ids == 1) & (attention_mask == 1)).long()
121
+ assert torch.all(mask_a.sum(dim=1) > 0), "PPI token_type_ids produced empty segment A"
122
+ assert torch.all(mask_b.sum(dim=1) > 0), "PPI token_type_ids produced empty segment B"
123
+ return mask_a, mask_b
124
+
125
+ assert self.config.sep_token_id is not None, "sep_token_id is required for PPI fallback segmentation"
126
+ batch_size, seq_len = input_ids.shape
127
+ mask_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
128
+ mask_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
129
+
130
+ for batch_idx in range(batch_size):
131
+ valid_positions = torch.where(attention_mask[batch_idx] == 1)[0]
132
+ sep_positions = torch.where((input_ids[batch_idx] == self.config.sep_token_id) & (attention_mask[batch_idx] == 1))[0]
133
+ if len(valid_positions) == 0:
134
+ continue
135
+
136
+ if len(sep_positions) >= 2:
137
+ first_sep = int(sep_positions[0].item())
138
+ second_sep = int(sep_positions[1].item())
139
+ mask_a[batch_idx, :first_sep + 1] = 1
140
+ mask_b[batch_idx, first_sep + 1:second_sep + 1] = 1
141
+ elif len(sep_positions) == 1:
142
+ first_sep = int(sep_positions[0].item())
143
+ mask_a[batch_idx, :first_sep + 1] = 1
144
+ mask_b[batch_idx, first_sep + 1: int(valid_positions[-1].item()) + 1] = 1
145
+ else:
146
+ midpoint = len(valid_positions) // 2
147
+ mask_a[batch_idx, valid_positions[:midpoint]] = 1
148
+ mask_b[batch_idx, valid_positions[midpoint:]] = 1
149
+
150
+ assert torch.all(mask_a.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment A"
151
+ assert torch.all(mask_b.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment B"
152
+ return mask_a, mask_b
153
+
154
+ def _build_probe_inputs(
155
+ self,
156
+ hidden_states: torch.Tensor,
157
+ input_ids: torch.Tensor,
158
+ attention_mask: torch.Tensor,
159
+ token_type_ids: Optional[torch.Tensor],
160
+ attentions: Optional[torch.Tensor],
161
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
162
+ if self.config.ppi and (not self.config.matrix_embed) and (not self.config.tokenwise):
163
+ mask_a, mask_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
164
+ vec_a = self.pooler(hidden_states, attention_mask=mask_a, attentions=attentions)
165
+ vec_b = self.pooler(hidden_states, attention_mask=mask_b, attentions=attentions)
166
+ return torch.cat((vec_a, vec_b), dim=-1), None
167
+
168
+ if self.config.matrix_embed or self.config.tokenwise:
169
+ return hidden_states, attention_mask
170
+
171
+ pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
172
+ return pooled, None
173
+
174
+ def forward(
175
+ self,
176
+ input_ids: torch.Tensor,
177
+ attention_mask: Optional[torch.Tensor] = None,
178
+ token_type_ids: Optional[torch.Tensor] = None,
179
+ labels: Optional[torch.Tensor] = None,
180
+ ) -> SequenceClassifierOutput | TokenClassifierOutput:
181
+ if attention_mask is None:
182
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
183
+
184
+ requires_attentions = "parti" in self.config.pooling_types and (not self.config.matrix_embed) and (not self.config.tokenwise)
185
+ backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
186
+ if requires_attentions:
187
+ backbone_kwargs["output_attentions"] = True
188
+ backbone_output = self.backbone(**backbone_kwargs)
189
+ hidden_states = self._extract_hidden_states(backbone_output)
190
+ attentions = self._extract_attentions(backbone_output)
191
+ if requires_attentions:
192
+ assert attentions is not None, "parti pooling requires base model attentions"
193
+ probe_embeddings, probe_attention_mask = self._build_probe_inputs(
194
+ hidden_states=hidden_states,
195
+ input_ids=input_ids,
196
+ attention_mask=attention_mask,
197
+ token_type_ids=token_type_ids,
198
+ attentions=attentions,
199
+ )
200
+
201
+ if self.config.probe_type == "linear":
202
+ return self.probe(embeddings=probe_embeddings, labels=labels)
203
+
204
+ if self.config.probe_type == "transformer":
205
+ forward_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
206
+ if probe_attention_mask is not None:
207
+ forward_kwargs["attention_mask"] = probe_attention_mask
208
+ if self.config.add_token_ids and token_type_ids is not None and probe_attention_mask is not None:
209
+ forward_kwargs["token_type_ids"] = token_type_ids
210
+ return self.probe(**forward_kwargs)
211
+
212
+ if self.config.probe_type in ["retrievalnet", "lyra"]:
213
+ return self.probe(embeddings=probe_embeddings, attention_mask=probe_attention_mask, labels=labels)
214
+
215
+ raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")
protify/FastPLMs/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/scripts/eval/aggregate_evals.py ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+
9
+ METRICS = ["lddt", "bb_lddt", "tm_score", "rmsd"]
10
+
11
+
12
+ def compute_af3_metrics(preds, evals, name):
13
+ metrics = {}
14
+
15
+ top_model = None
16
+ top_confidence = -1000
17
+ for model_id in range(5):
18
+ # Load confidence file
19
+ confidence_file = (
20
+ Path(preds) / f"seed-1_sample-{model_id}" / "summary_confidences.json"
21
+ )
22
+ with confidence_file.open("r") as f:
23
+ confidence_data = json.load(f)
24
+ confidence = confidence_data["ranking_score"]
25
+ if confidence > top_confidence:
26
+ top_model = model_id
27
+ top_confidence = confidence
28
+
29
+ # Load eval file
30
+ eval_file = Path(evals) / f"{name}_model_{model_id}.json"
31
+ with eval_file.open("r") as f:
32
+ eval_data = json.load(f)
33
+ for metric_name in METRICS:
34
+ if metric_name in eval_data:
35
+ metrics.setdefault(metric_name, []).append(eval_data[metric_name])
36
+
37
+ if "dockq" in eval_data and eval_data["dockq"] is not None:
38
+ metrics.setdefault("dockq_>0.23", []).append(
39
+ np.mean(
40
+ [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
41
+ )
42
+ )
43
+ metrics.setdefault("dockq_>0.49", []).append(
44
+ np.mean(
45
+ [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
46
+ )
47
+ )
48
+ metrics.setdefault("len_dockq_", []).append(
49
+ len([v for v in eval_data["dockq"] if v is not None])
50
+ )
51
+
52
+ eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
53
+ with eval_file.open("r") as f:
54
+ eval_data = json.load(f)
55
+ if "lddt_pli" in eval_data:
56
+ lddt_plis = [
57
+ x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
58
+ ]
59
+ for _ in eval_data["lddt_pli"][
60
+ "model_ligand_unassigned_reason"
61
+ ].items():
62
+ lddt_plis.append(0)
63
+ if not lddt_plis:
64
+ continue
65
+ lddt_pli = np.mean([x for x in lddt_plis])
66
+ metrics.setdefault("lddt_pli", []).append(lddt_pli)
67
+ metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))
68
+
69
+ if "rmsd" in eval_data:
70
+ rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
71
+ for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
72
+ rmsds.append(100)
73
+ if not rmsds:
74
+ continue
75
+ rmsd2 = np.mean([x < 2.0 for x in rmsds])
76
+ rmsd5 = np.mean([x < 5.0 for x in rmsds])
77
+ metrics.setdefault("rmsd<2", []).append(rmsd2)
78
+ metrics.setdefault("rmsd<5", []).append(rmsd5)
79
+ metrics.setdefault("len_rmsd", []).append(len(rmsds))
80
+
81
+ # Get oracle
82
+ oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
83
+ avg = {k: sum(v) / len(v) for k, v in metrics.items()}
84
+ top1 = {k: v[top_model] for k, v in metrics.items()}
85
+
86
+ results = {}
87
+ for metric_name in metrics:
88
+ if metric_name.startswith("len_"):
89
+ continue
90
+ if metric_name == "lddt_pli":
91
+ l = metrics["len_lddt_pli"][0]
92
+ elif metric_name == "rmsd<2" or metric_name == "rmsd<5":
93
+ l = metrics["len_rmsd"][0]
94
+ elif metric_name == "dockq_>0.23" or metric_name == "dockq_>0.49":
95
+ l = metrics["len_dockq_"][0]
96
+ else:
97
+ l = 1
98
+ results[metric_name] = {
99
+ "oracle": oracle[metric_name],
100
+ "average": avg[metric_name],
101
+ "top1": top1[metric_name],
102
+ "len": l,
103
+ }
104
+
105
+ return results
106
+
107
+
108
+ def compute_chai_metrics(preds, evals, name):
109
+ metrics = {}
110
+
111
+ top_model = None
112
+ top_confidence = 0
113
+ for model_id in range(5):
114
+ # Load confidence file
115
+ confidence_file = Path(preds) / f"scores.model_idx_{model_id}.npz"
116
+ confidence_data = np.load(confidence_file)
117
+ confidence = confidence_data["aggregate_score"].item()
118
+ if confidence > top_confidence:
119
+ top_model = model_id
120
+ top_confidence = confidence
121
+
122
+ # Load eval file
123
+ eval_file = Path(evals) / f"{name}_model_{model_id}.json"
124
+ with eval_file.open("r") as f:
125
+ eval_data = json.load(f)
126
+ for metric_name in METRICS:
127
+ if metric_name in eval_data:
128
+ metrics.setdefault(metric_name, []).append(eval_data[metric_name])
129
+
130
+ if "dockq" in eval_data and eval_data["dockq"] is not None:
131
+ metrics.setdefault("dockq_>0.23", []).append(
132
+ np.mean(
133
+ [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
134
+ )
135
+ )
136
+ metrics.setdefault("dockq_>0.49", []).append(
137
+ np.mean(
138
+ [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
139
+ )
140
+ )
141
+ metrics.setdefault("len_dockq_", []).append(
142
+ len([v for v in eval_data["dockq"] if v is not None])
143
+ )
144
+
145
+ eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
146
+ with eval_file.open("r") as f:
147
+ eval_data = json.load(f)
148
+ if "lddt_pli" in eval_data:
149
+ lddt_plis = [
150
+ x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
151
+ ]
152
+ for _ in eval_data["lddt_pli"][
153
+ "model_ligand_unassigned_reason"
154
+ ].items():
155
+ lddt_plis.append(0)
156
+ if not lddt_plis:
157
+ continue
158
+ lddt_pli = np.mean([x for x in lddt_plis])
159
+ metrics.setdefault("lddt_pli", []).append(lddt_pli)
160
+ metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))
161
+
162
+ if "rmsd" in eval_data:
163
+ rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
164
+ for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
165
+ rmsds.append(100)
166
+ if not rmsds:
167
+ continue
168
+ rmsd2 = np.mean([x < 2.0 for x in rmsds])
169
+ rmsd5 = np.mean([x < 5.0 for x in rmsds])
170
+ metrics.setdefault("rmsd<2", []).append(rmsd2)
171
+ metrics.setdefault("rmsd<5", []).append(rmsd5)
172
+ metrics.setdefault("len_rmsd", []).append(len(rmsds))
173
+
174
+ # Get oracle
175
+ oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
176
+ avg = {k: sum(v) / len(v) for k, v in metrics.items()}
177
+ top1 = {k: v[top_model] for k, v in metrics.items()}
178
+
179
+ results = {}
180
+ for metric_name in metrics:
181
+ if metric_name.startswith("len_"):
182
+ continue
183
+ if metric_name == "lddt_pli":
184
+ l = metrics["len_lddt_pli"][0]
185
+ elif metric_name == "rmsd<2" or metric_name == "rmsd<5":
186
+ l = metrics["len_rmsd"][0]
187
+ elif metric_name == "dockq_>0.23" or metric_name == "dockq_>0.49":
188
+ l = metrics["len_dockq_"][0]
189
+ else:
190
+ l = 1
191
+ results[metric_name] = {
192
+ "oracle": oracle[metric_name],
193
+ "average": avg[metric_name],
194
+ "top1": top1[metric_name],
195
+ "len": l,
196
+ }
197
+
198
+ return results
199
+
200
+
201
+ def compute_boltz_metrics(preds, evals, name):
202
+ metrics = {}
203
+
204
+ top_model = None
205
+ top_confidence = 0
206
+ for model_id in range(5):
207
+ # Load confidence file
208
+ confidence_file = (
209
+ Path(preds) / f"confidence_{Path(preds).name}_model_{model_id}.json"
210
+ )
211
+ with confidence_file.open("r") as f:
212
+ confidence_data = json.load(f)
213
+ confidence = confidence_data["confidence_score"]
214
+ if confidence > top_confidence:
215
+ top_model = model_id
216
+ top_confidence = confidence
217
+
218
+ # Load eval file
219
+ eval_file = Path(evals) / f"{name}_model_{model_id}.json"
220
+ with eval_file.open("r") as f:
221
+ eval_data = json.load(f)
222
+ for metric_name in METRICS:
223
+ if metric_name in eval_data:
224
+ metrics.setdefault(metric_name, []).append(eval_data[metric_name])
225
+
226
+ if "dockq" in eval_data and eval_data["dockq"] is not None:
227
+ metrics.setdefault("dockq_>0.23", []).append(
228
+ np.mean(
229
+ [float(v > 0.23) for v in eval_data["dockq"] if v is not None]
230
+ )
231
+ )
232
+ metrics.setdefault("dockq_>0.49", []).append(
233
+ np.mean(
234
+ [float(v > 0.49) for v in eval_data["dockq"] if v is not None]
235
+ )
236
+ )
237
+ metrics.setdefault("len_dockq_", []).append(
238
+ len([v for v in eval_data["dockq"] if v is not None])
239
+ )
240
+
241
+ eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
242
+ with eval_file.open("r") as f:
243
+ eval_data = json.load(f)
244
+ if "lddt_pli" in eval_data:
245
+ lddt_plis = [
246
+ x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
247
+ ]
248
+ for _ in eval_data["lddt_pli"][
249
+ "model_ligand_unassigned_reason"
250
+ ].items():
251
+ lddt_plis.append(0)
252
+ if not lddt_plis:
253
+ continue
254
+ lddt_pli = np.mean([x for x in lddt_plis])
255
+ metrics.setdefault("lddt_pli", []).append(lddt_pli)
256
+ metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))
257
+
258
+ if "rmsd" in eval_data:
259
+ rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
260
+ for _ in eval_data["rmsd"]["model_ligand_unassigned_reason"].items():
261
+ rmsds.append(100)
262
+ if not rmsds:
263
+ continue
264
+ rmsd2 = np.mean([x < 2.0 for x in rmsds])
265
+ rmsd5 = np.mean([x < 5.0 for x in rmsds])
266
+ metrics.setdefault("rmsd<2", []).append(rmsd2)
267
+ metrics.setdefault("rmsd<5", []).append(rmsd5)
268
+ metrics.setdefault("len_rmsd", []).append(len(rmsds))
269
+
270
+ # Get oracle
271
+ oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
272
+ avg = {k: sum(v) / len(v) for k, v in metrics.items()}
273
+ top1 = {k: v[top_model] for k, v in metrics.items()}
274
+
275
+ results = {}
276
+ for metric_name in metrics:
277
+ if metric_name.startswith("len_"):
278
+ continue
279
+ if metric_name == "lddt_pli":
280
+ l = metrics["len_lddt_pli"][0]
281
+ elif metric_name == "rmsd<2" or metric_name == "rmsd<5":
282
+ l = metrics["len_rmsd"][0]
283
+ elif metric_name == "dockq_>0.23" or metric_name == "dockq_>0.49":
284
+ l = metrics["len_dockq_"][0]
285
+ else:
286
+ l = 1
287
+ results[metric_name] = {
288
+ "oracle": oracle[metric_name],
289
+ "average": avg[metric_name],
290
+ "top1": top1[metric_name],
291
+ "len": l,
292
+ }
293
+
294
+ return results
295
+
296
+
297
def eval_models(
    chai_preds,
    chai_evals,
    af3_preds,
    af3_evals,
    boltz_preds,
    boltz_evals,
    boltz_preds_x,
    boltz_evals_x,
):
    """Aggregate per-target metrics for AF3, Chai-1, Boltz-1 and Boltz-1x.

    Only targets predicted by all four tools are kept, and a metric is
    retained only when every tool reports it over the same number of
    entities (so the averages are comparable).

    Returns
    -------
    pd.DataFrame
        Long-format results with columns: tool, target, metric, value.
    """

    def _index_preds(preds_dir):
        # Map lowercase target name -> prediction folder, skipping hidden entries.
        return {
            entry.name.lower(): entry
            for entry in Path(preds_dir).iterdir()
            if not entry.name.lower().startswith(".")
        }

    chai_names = _index_preds(chai_preds)
    af3_names = _index_preds(af3_preds)
    boltz_names = _index_preds(boltz_preds)
    boltz_names_x = _index_preds(boltz_preds_x)

    print("Chai preds", len(chai_names))
    print("Af3 preds", len(af3_names))
    print("Boltz preds", len(boltz_names))
    print("Boltzx preds", len(boltz_names_x))

    common = (
        set(chai_names) & set(af3_names) & set(boltz_names) & set(boltz_names_x)
    )

    # Remove examples in the validation set
    common -= {"t1133", "h1134", "r1134s1", "t1134s2", "t1121", "t1123", "t1159"}
    print("Common", len(common))

    def _safe_compute(compute_fn, pred_path, evals, name, label):
        # Compute metrics for one tool; log the traceback and return None on failure.
        try:
            return compute_fn(pred_path, evals, name)
        except Exception as e:
            import traceback

            traceback.print_exc()
            print(f"Error evaluating {label} {name}: {e}")
            return None

    # Create a dataframe with the following schema:
    # tool, name, metric, oracle, average, top1
    results = []
    for name in tqdm(common):
        # (results label, compute fn, prediction folder, eval folder, log label)
        specs = [
            ("AF3", compute_af3_metrics, af3_names[name], af3_evals, "AF3"),
            ("Chai-1", compute_chai_metrics, chai_names[name], chai_evals, "Chai"),
            ("Boltz-1", compute_boltz_metrics, boltz_names[name], boltz_evals, "Boltz"),
            ("Boltz-1x", compute_boltz_metrics, boltz_names_x[name], boltz_evals_x, "Boltzx"),
        ]
        per_tool = {}
        for tool_label, compute_fn, pred_path, evals, log_label in specs:
            metrics = _safe_compute(compute_fn, pred_path, evals, name, log_label)
            if metrics is None:
                break
            per_tool[tool_label] = metrics
        if len(per_tool) != len(specs):
            # At least one tool failed for this target; skip it entirely.
            continue

        af3_results = per_tool["AF3"]
        chai_results = per_tool["Chai-1"]
        boltz_results = per_tool["Boltz-1"]
        boltz_results_x = per_tool["Boltz-1x"]

        for metric_name in af3_results:
            # NOTE: the Boltz-1x membership check is included here; without it
            # a metric missing only from Boltz-1x raised an uncaught KeyError.
            if not (
                metric_name in chai_results
                and metric_name in boltz_results
                and metric_name in boltz_results_x
            ):
                print(
                    "Missing metric",
                    name,
                    metric_name,
                    metric_name in chai_results,
                    metric_name in boltz_results,
                    metric_name in boltz_results_x,
                )
                continue

            ref_len = af3_results[metric_name]["len"]
            if not all(
                per_tool[tool][metric_name]["len"] == ref_len for tool in per_tool
            ):
                print(
                    "Different lengths",
                    name,
                    metric_name,
                    af3_results[metric_name]["len"],
                    chai_results[metric_name]["len"],
                    boltz_results[metric_name]["len"],
                    boltz_results_x[metric_name]["len"],
                )
                continue

            # Emit oracle and top-1 rows for every tool, in a fixed order.
            for tool_label, metrics in per_tool.items():
                for suffix, key in (("oracle", "oracle"), ("top-1", "top1")):
                    results.append(
                        {
                            "tool": f"{tool_label} {suffix}",
                            "target": name,
                            "metric": metric_name,
                            "value": metrics[metric_name][key],
                        }
                    )

    # Only targets & metrics where all tools succeeded reach this point.
    return pd.DataFrame(results)
506
+
507
+
508
def eval_validity_checks(df):
    """Convert per-model physical-validity flags into tool/target/metric rows.

    Produces one "physical validity" row per (tool, target) for the top-1
    model (model_idx == 0) and one for the oracle (best over all models).
    """

    def _format(frame, label_by_tool):
        # Normalize a (tool, pdb_id, valid) frame into the shared results schema.
        frame = frame.rename(columns={"pdb_id": "target", "valid": "value"})
        frame["tool"] = frame["tool"].apply(lambda t: label_by_tool[t])
        frame["metric"] = "physical validity"
        frame["target"] = frame["target"].apply(lambda t: t.lower())
        return frame[["tool", "target", "metric", "value"]]

    top1_labels = {
        "af3": "AF3 top-1",
        "chai": "Chai-1 top-1",
        "boltz1": "Boltz-1 top-1",
        "boltz1x": "Boltz-1x top-1",
    }
    top1 = df.loc[df["model_idx"] == 0, ["tool", "pdb_id", "valid"]].copy()
    top1 = _format(top1, top1_labels)

    oracle_labels = {
        "af3": "AF3 oracle",
        "chai": "Chai-1 oracle",
        "boltz1": "Boltz-1 oracle",
        "boltz1x": "Boltz-1x oracle",
    }
    # Oracle = a target is valid if any sampled model passes the checks.
    oracle = df.groupby(["tool", "pdb_id"])["valid"].max().reset_index()
    oracle = _format(oracle, oracle_labels)

    return pd.concat([top1, oracle])
541
+
542
+
543
def bootstrap_ci(series, n_boot=1000, alpha=0.05):
    """Mean of `series` with a (1 - alpha) bootstrap confidence interval.

    Returns
    -------
    tuple
        (mean, lower percentile bound, upper percentile bound).
    """
    size = len(series)
    # Resample with replacement n_boot times and record each sample mean.
    resampled_means = np.array(
        [series.sample(size, replace=True).mean() for _ in range(n_boot)]
    )
    point_estimate = np.mean(series)
    lower_bound = np.percentile(resampled_means, 100 * alpha / 2)
    upper_bound = np.percentile(resampled_means, 100 * (1 - alpha / 2))
    return point_estimate, lower_bound, upper_bound
559
+
560
+
561
def plot_data(desired_tools, desired_metrics, df, dataset, filename):
    """Grouped bar plot of mean performance per tool/metric with bootstrap CIs."""
    subset = df[df["tool"].isin(desired_tools) & df["metric"].isin(desired_metrics)]

    # Bootstrap mean and CI bounds for every (tool, metric) pair.
    stats = subset.groupby(["tool", "metric"])["value"].apply(bootstrap_ci)
    stats = stats.apply(pd.Series)  # expand the (mean, lower, upper) tuples
    stats.columns = ["mean", "lower", "upper"]

    # Pivot to metric-by-tool tables in the requested metric order.
    means = stats["mean"].unstack("tool").reindex(desired_metrics)
    lowers = stats["lower"].unstack("tool").reindex(desired_metrics)
    uppers = stats["upper"].unstack("tool").reindex(desired_metrics)

    # Fixed tool order so colors stay stable across figures.
    tool_order = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    means = means[tool_order]
    lowers = lowers[tool_order]
    uppers = uppers[tool_order]

    # Human-readable metric labels.
    renaming = {
        "lddt_pli": "Mean LDDT-PLI",
        "rmsd<2": "L-RMSD < 2A",
        "lddt": "Mean LDDT",
        "dockq_>0.23": "DockQ > 0.23",
        "physical validity": "Physical Validity",
    }
    means = means.rename(index=renaming)
    lowers = lowers.rename(index=renaming)
    uppers = uppers.rename(index=renaming)
    mean_vals = means.values

    tool_colors = [
        "#994C00",  # AF3 oracle
        "#FFB55A",  # AF3 top-1
        "#931652",  # Chai-1 oracle
        "#FC8AD9",  # Chai-1 top-1
        "#188F52",  # Boltz-1 oracle
        "#86E935",  # Boltz-1 top-1
        "#004D80",  # Boltz-1x oracle
        "#55C2FF",  # Boltz-1x top-1
    ]

    fig, ax = plt.subplots(figsize=(10, 5))

    positions = np.arange(len(means.index))
    bar_spacing = 0.015
    total_width = 0.7
    # Shrink bar width so the group (bars + gaps) spans total_width.
    width = (total_width - (len(tool_order) - 1) * bar_spacing) / len(tool_order)

    for idx, tool in enumerate(tool_order):
        # Each subsequent bar shifts right by width + bar_spacing.
        offsets = positions - (total_width - width) / 2 + idx * (width + bar_spacing)
        heights = means[tool].values
        # Asymmetric error bars from the bootstrap percentile bounds.
        err_low = mean_vals[:, idx] - lowers.values[:, idx]
        err_high = uppers.values[:, idx] - mean_vals[:, idx]
        ax.bar(
            offsets,
            heights,
            width=width,
            color=tool_colors[idx],
            label=tool,
            yerr=np.vstack([err_low, err_high]),
            capsize=2,
            error_kw={"elinewidth": 0.75},
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(means.index, rotation=0)
    ax.set_ylabel("Value")
    ax.set_title(f"Performances on {dataset} with 95% CI (Bootstrap)")

    plt.tight_layout()
    ax.legend(loc="lower center", bbox_to_anchor=(0.5, 0.85), ncols=4, frameon=False)

    plt.savefig(filename)
    plt.show()
662
+
663
+
664
def main():
    """Aggregate evaluations for the PDB test set and CASP15.

    For each dataset: load the physical-validity checks, aggregate the
    per-tool structure metrics, write the combined CSV, and render the
    comparison plot.
    """
    eval_folder = "../../boltz_results_final/"
    output_folder = "../../boltz_results_final/"

    desired_tools = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2", "physical validity"]

    def _run_dataset(subset, validity_name, csv_name, dataset_label, plot_name):
        # One end-to-end aggregation pass for a single benchmark subset.
        df_validity_checks = pd.read_csv(eval_folder + validity_name)
        df_validity_checks = eval_validity_checks(df_validity_checks)

        df = eval_models(
            eval_folder + f"outputs/{subset}/chai",
            eval_folder + f"evals/{subset}/chai",
            eval_folder + f"outputs/{subset}/af3",
            eval_folder + f"evals/{subset}/af3",
            eval_folder + f"outputs/{subset}/boltz/predictions",
            eval_folder + f"evals/{subset}/boltz",
            eval_folder + f"outputs/{subset}/boltzx/predictions",
            eval_folder + f"evals/{subset}/boltzx",
        )

        df = pd.concat([df, df_validity_checks]).reset_index(drop=True)
        df.to_csv(output_folder + csv_name, index=False)
        plot_data(
            desired_tools, desired_metrics, df, dataset_label, output_folder + plot_name
        )

    # Eval the test set
    _run_dataset(
        "test", "physical_checks_test.csv", "results_test.csv", "PDB Test", "plot_test.pdf"
    )

    # Eval CASP
    _run_dataset(
        "casp15", "physical_checks_casp.csv", "results_casp.csv", "CASP15", "plot_casp.pdf"
    )
750
+
751
+
752
+ if __name__ == "__main__":
753
+ main()
protify/FastPLMs/boltz/scripts/eval/physcialsim_metrics.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+
4
+ import numpy as np
5
+ import torch
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import pandas as pd
9
+ from boltz.data.mol import load_molecules
10
+ from boltz.data import const
11
+ from boltz.data.parse.mmcif_with_constraints import parse_mmcif
12
+ from multiprocessing import Pool
13
+
14
+
15
+ def compute_torsion_angles(coords, torsion_index):
16
+ r_ij = coords[..., torsion_index[0], :] - coords[..., torsion_index[1], :]
17
+ r_kj = coords[..., torsion_index[2], :] - coords[..., torsion_index[1], :]
18
+ r_kl = coords[..., torsion_index[2], :] - coords[..., torsion_index[3], :]
19
+ n_ijk = np.cross(r_ij, r_kj, axis=-1)
20
+ n_jkl = np.cross(r_kj, r_kl, axis=-1)
21
+ r_kj_norm = np.linalg.norm(r_kj, axis=-1)
22
+ n_ijk_norm = np.linalg.norm(n_ijk, axis=-1)
23
+ n_jkl_norm = np.linalg.norm(n_jkl, axis=-1)
24
+ sign_phi = np.sign(
25
+ r_kj[..., None, :] @ np.cross(n_ijk, n_jkl, axis=-1)[..., None]
26
+ ).squeeze(axis=(-1, -2))
27
+ phi = sign_phi * np.arccos(
28
+ np.clip(
29
+ (n_ijk[..., None, :] @ n_jkl[..., None]).squeeze(axis=(-1, -2))
30
+ / (n_ijk_norm * n_jkl_norm),
31
+ -1 + 1e-8,
32
+ 1 - 1e-8,
33
+ )
34
+ )
35
+ return phi
36
+
37
+
38
def check_ligand_distance_geometry(
    structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
    """Count RDKit distance-geometry violations for ligand atom pairs.

    Compares observed pairwise distances against the RDKit lower/upper
    bounds, with a relative tolerance per category: bonded pairs, angle
    (1-3) pairs, and all remaining pairs (internal clashes).
    """
    coords = structure.coords["coords"]
    bounds = constraints.rdkit_bounds_constraints
    pairs = bounds["atom_idxs"].copy().astype(np.int64).T
    is_bond = bounds["is_bond"].copy().astype(bool)
    is_angle = bounds["is_angle"].copy().astype(bool)
    upper = bounds["upper_bound"].copy().astype(np.float32)
    lower = bounds["lower_bound"].copy().astype(np.float32)

    dists = np.linalg.norm(coords[pairs[0]] - coords[pairs[1]], axis=-1)

    # A pair violates its bounds when it falls outside the buffered interval.
    bond_violations = (dists[is_bond] <= lower[is_bond] * (1.0 - bond_buffer)) | (
        dists[is_bond] >= upper[is_bond] * (1.0 + bond_buffer)
    )
    angle_violations = (dists[is_angle] <= lower[is_angle] * (1.0 - angle_buffer)) | (
        dists[is_angle] >= upper[is_angle] * (1.0 + angle_buffer)
    )
    neither = ~is_bond & ~is_angle
    clash_violations = dists[neither] <= lower[neither] * (1.0 - clash_buffer)

    ligand_count = sum(
        int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
        for chain in structure.chains
    )
    return {
        "num_ligands": ligand_count,
        "num_bond_length_violations": bond_violations.sum(),
        "num_bonds": is_bond.sum(),
        "num_bond_angle_violations": angle_violations.sum(),
        "num_angles": is_angle.sum(),
        "num_internal_clash_violations": clash_violations.sum(),
        "num_non_neighbors": neither.sum(),
    }
73
+
74
+
75
def check_ligand_stereochemistry(structure, constraints):
    """Count chiral-center and stereo-bond violations against the references.

    Chirality is predicted from the sign of the improper torsion; double-bond
    E/Z is predicted from whether the torsion magnitude exceeds pi/2.
    """
    coords = structure.coords["coords"]

    chiral = constraints.chiral_atom_constraints
    chiral_ref_mask = chiral["is_reference"]
    chiral_index = chiral["atom_idxs"].T[:, chiral_ref_mask]
    chiral_truth = chiral["is_r"][chiral_ref_mask]
    # Positive improper torsion corresponds to an R configuration.
    chiral_pred = compute_torsion_angles(coords, chiral_index) > 0
    chiral_mismatch = chiral_pred != chiral_truth

    stereo = constraints.stereo_bond_constraints
    stereo_ref_mask = stereo["is_reference"]
    stereo_index = stereo["atom_idxs"].T[:, stereo_ref_mask]
    stereo_truth = stereo["is_e"][stereo_ref_mask]
    # Torsion magnitude above pi/2 corresponds to an E (trans) bond.
    stereo_pred = np.abs(compute_torsion_angles(coords, stereo_index)) > np.pi / 2
    stereo_mismatch = stereo_pred != stereo_truth

    return {
        "num_chiral_atom_violations": chiral_mismatch.sum(),
        "num_chiral_atoms": chiral_index.shape[1],
        "num_stereo_bond_violations": stereo_mismatch.sum(),
        "num_stereo_bonds": stereo_index.shape[1],
    }
110
+
111
+
112
def check_ligand_flatness(structure, constraints, buffer=0.25):
    """Count planarity violations for 5-rings, 6-rings and double bonds.

    A group violates planarity when any of its atoms lies further than
    `buffer` Angstroms from the group's best-fit (least-squares) plane.

    Note: the original 5-ring check used ``np.all(dists <= buffer)``, which
    counted *flat* rings as violations — inverted relative to the 6-ring and
    double-bond checks. All three now share the same criterion.
    """
    coords = structure.coords["coords"]

    def _planarity_violations(atom_idxs):
        # Distance of each atom to the SVD best-fit plane of its group; the
        # last right-singular vector is the plane normal.
        group_coords = coords[atom_idxs, :]
        centered = group_coords - group_coords.mean(axis=-2, keepdims=True)
        normals = np.linalg.svd(centered)[2][..., -1, :, None]
        dists = np.abs((centered @ normals).squeeze(axis=-1))
        return np.any(dists >= buffer, axis=-1)

    ring_5_violations = _planarity_violations(
        constraints.planar_ring_5_constraints["atom_idxs"]
    )
    ring_6_violations = _planarity_violations(
        constraints.planar_ring_6_constraints["atom_idxs"]
    )
    bond_violations = _planarity_violations(
        constraints.planar_bond_constraints["atom_idxs"]
    )

    return {
        "num_planar_5_ring_violations": ring_5_violations.sum(),
        "num_planar_5_rings": ring_5_violations.shape[0],
        "num_planar_6_ring_violations": ring_6_violations.sum(),
        "num_planar_6_rings": ring_6_violations.shape[0],
        "num_planar_double_bond_violations": bond_violations.sum(),
        "num_planar_double_bonds": bond_violations.shape[0],
    }
144
+
145
+
146
def check_steric_clash(structure, molecules, buffer=0.25):
    """Count inter-chain steric clashes using van der Waals radii.

    Two chains clash when any cross-chain atom pair is closer than the sum of
    the atoms' vdW radii scaled by (1 - buffer). Single-atom chains and
    covalently bonded chain pairs are skipped. Counts are bucketed by chain
    type, separately for symmetric pairs (same entity and atom count) and
    asymmetric pairs. Removes the original's stray dead
    ``np.array([a.GetAtomicNum() ...])`` expression after the radii loop.
    """
    result = {}
    # Pre-populate every counter so all type combinations appear in the output.
    for type_i in const.chain_types:
        out_type_i = type_i.lower()
        out_type_i = out_type_i if out_type_i != "nonpolymer" else "ligand"
        result[f"num_chain_pairs_sym_{out_type_i}"] = 0
        result[f"num_chain_clashes_sym_{out_type_i}"] = 0
        for type_j in const.chain_types:
            out_type_j = type_j.lower()
            out_type_j = out_type_j if out_type_j != "nonpolymer" else "ligand"
            result[f"num_chain_pairs_asym_{out_type_i}_{out_type_j}"] = 0
            result[f"num_chain_clashes_asym_{out_type_i}_{out_type_j}"] = 0

    # Chain pairs joined by an explicit inter-chain bond are not clash candidates.
    connected_chains = set()
    for bond in structure.bonds:
        if bond["chain_1"] != bond["chain_2"]:
            connected_chains.add(tuple(sorted((bond["chain_1"], bond["chain_2"]))))

    # Per-atom vdW radii, looked up from each residue's reference molecule.
    vdw_radii = []
    for res in structure.residues:
        mol = molecules[res["name"]]
        token_atoms = structure.atoms[
            res["atom_idx"] : res["atom_idx"] + res["atom_num"]
        ]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
        token_atoms_ref = [atom_name_to_ref[a["name"]] for a in token_atoms]
        vdw_radii.extend(
            [const.vdw_radii[a.GetAtomicNum() - 1] for a in token_atoms_ref]
        )
    vdw_radii = np.array(vdw_radii, dtype=np.float32)

    for i, chain_i in enumerate(structure.chains):
        for j, chain_j in enumerate(structure.chains):
            # Skip single-atom chains, the lower triangle / diagonal, and
            # covalently connected chain pairs.
            if (
                chain_i["atom_num"] == 1
                or chain_j["atom_num"] == 1
                or j <= i
                or (i, j) in connected_chains
            ):
                continue
            coords_i = structure.coords["coords"][
                chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
            ]
            coords_j = structure.coords["coords"][
                chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
            ]
            dists = np.linalg.norm(coords_i[:, None, :] - coords_j[None, :, :], axis=-1)
            radii_i = vdw_radii[
                chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
            ]
            radii_j = vdw_radii[
                chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
            ]
            radii_sum = radii_i[:, None] + radii_j[None, :]
            is_clashing = np.any(dists < radii_sum * (1.00 - buffer))

            type_i = const.chain_types[chain_i["mol_type"]].lower()
            type_j = const.chain_types[chain_j["mol_type"]].lower()
            type_i = type_i if type_i != "nonpolymer" else "ligand"
            type_j = type_j if type_j != "nonpolymer" else "ligand"
            is_symmetric = (
                chain_i["entity_id"] == chain_j["entity_id"]
                and chain_i["atom_num"] == chain_j["atom_num"]
            )
            key = "sym_" + type_i if is_symmetric else "asym_" + type_i + "_" + type_j
            result["num_chain_pairs_" + key] += 1
            result["num_chain_clashes_" + key] += int(is_clashing)
    return result
217
+
218
+
219
# Shared CCD cache used to interpret residues in the predicted structures.
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)

# Prediction folders for each tool on the PDB test set.
boltz1_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
    "/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)

# Only evaluate targets for which every tool produced predictions.
boltz1_pdb_ids = set(os.listdir(boltz1_dir))
boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
chai_pdb_ids = set(os.listdir(chai_dir))
af3_pdb_ids = set(os.listdir(af3_dir))
common_pdb_ids = boltz1_pdb_ids & boltz1x_pdb_ids & chai_pdb_ids & af3_pdb_ids

tools = ["boltz1", "boltz1x", "chai", "af3"]
# Number of sampled models per target to check.
num_samples = 5
246
+
247
+
248
def process_fn(key):
    """Parse one predicted mmCIF and compute its physical-validity metrics.

    Parameters
    ----------
    key : tuple
        (tool name, PDB target id, model index); the tool determines the
        on-disk layout of the prediction file.

    Returns
    -------
    dict
        Tool/target/model identifiers plus all violation counters.

    Raises
    ------
    ValueError
        If the tool name is not one of the supported layouts (the original
        left ``cif_path`` unbound and failed later with a NameError).
    """
    tool, pdb_id, model_idx = key
    if tool == "boltz1":
        cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "boltz1x":
        cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "chai":
        cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
    elif tool == "af3":
        cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
    else:
        msg = f"Unknown tool: {tool}"
        raise ValueError(msg)

    parsed_structure = parse_mmcif(
        cif_path,
        ccd,
        moldir,
    )
    structure = parsed_structure.data
    constraints = parsed_structure.residue_constraints

    record = {
        "tool": tool,
        "pdb_id": pdb_id,
        "model_idx": model_idx,
    }
    record.update(check_ligand_distance_geometry(structure, constraints))
    record.update(check_ligand_stereochemistry(structure, constraints))
    record.update(check_ligand_flatness(structure, constraints))
    record.update(check_steric_clash(structure, molecules=ccd))
    return record
277
+
278
+
279
# Run all checks in parallel and aggregate the results. Guarded by
# __main__ so worker processes (and plain imports) do not re-execute the
# driver — required for the spawn multiprocessing start method.
if __name__ == "__main__":
    # Enumerate every (tool, target, model) job.
    keys = []
    for tool in tools:
        for pdb_id in common_pdb_ids:
            for model_idx in range(num_samples):
                keys.append((tool, pdb_id, model_idx))

    # Smoke-test one job serially so obvious failures surface before the pool.
    if keys:
        process_fn(keys[0])

    records = []
    with Pool(48) as p:
        with tqdm(total=len(keys)) as pbar:
            for record in p.imap_unordered(process_fn, keys):
                records.append(record)
                pbar.update(1)
    df = pd.DataFrame.from_records(records)

    # Aggregate per-model validity flags from the individual counters.
    df["num_chain_clashes_all"] = df[
        [key for key in df.columns if "chain_clash" in key]
    ].sum(axis=1)
    df["num_pairs_all"] = df[
        [key for key in df.columns if "chain_pair" in key]
    ].sum(axis=1)
    df["clash_free"] = df["num_chain_clashes_all"] == 0
    df["valid_ligand"] = (
        df[[key for key in df.columns if "violation" in key]].sum(axis=1) == 0
    )
    # A model is valid only if it is clash-free AND ligand-valid.
    df["valid"] = (df["clash_free"]) & (df["valid_ligand"])

    df.to_csv("physical_checks_test.csv")
protify/FastPLMs/boltz/scripts/eval/run_evals.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import concurrent.futures
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ from tqdm import tqdm
7
+
8
+ OST_COMPARE_STRUCTURE = r"""
9
+ #!/bin/bash
10
+ # https://openstructure.org/docs/2.7/actions/#ost-compare-structures
11
+
12
+ IMAGE_NAME=openstructure-0.2.8
13
+
14
+ command="compare-structures \
15
+ -m {model_file} \
16
+ -r {reference_file} \
17
+ --fault-tolerant \
18
+ --min-pep-length 4 \
19
+ --min-nuc-length 4 \
20
+ -o {output_path} \
21
+ --lddt --bb-lddt --qs-score --dockq \
22
+ --ics --ips --rigid-scores --patch-scores --tm-score"
23
+
24
+ sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
25
+ """
26
+
27
+
28
+ OST_COMPARE_LIGAND = r"""
29
+ #!/bin/bash
30
+ # https://openstructure.org/docs/2.7/actions/#ost-compare-structures
31
+
32
+ IMAGE_NAME=openstructure-0.2.8
33
+
34
+ command="compare-ligand-structures \
35
+ -m {model_file} \
36
+ -r {reference_file} \
37
+ --fault-tolerant \
38
+ --lddt-pli --rmsd \
39
+ --substructure-match \
40
+ -o {output_path}"
41
+
42
+ sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
43
+ """
44
+
45
+
46
def evaluate_structure(
    name: str,
    pred: Path,
    reference: Path,
    outdir: str,
    mount: str,
    executable: str = "/bin/bash",
) -> None:
    """Run OpenStructure polymer and ligand comparisons for one prediction.

    Results are written to ``{outdir}/{name}.json`` (polymer metrics) and
    ``{outdir}/{name}_ligand.json`` (ligand metrics). Existing output files
    are not recomputed, so interrupted runs can be resumed.

    Parameters
    ----------
    name : str
        Identifier used for the output file names.
    pred : Path
        Predicted structure file.
    reference : Path
        Reference structure file.
    outdir : str
        Directory receiving the JSON outputs.
    mount : str
        Host path mounted into the docker container.
    executable : str
        Shell used to run the docker command.
    """

    def _run_comparison(template: str, out_path: Path, kind: str) -> None:
        # Shared runner for both comparison flavors; skips existing outputs.
        if out_path.exists():
            print(  # noqa: T201
                f"Skipping recomputation of {name} as {kind} json file already exists"
            )
            return
        subprocess.run(
            template.format(
                model_file=str(pred),
                reference_file=str(reference),
                output_path=str(out_path),
                mount=mount,
            ),
            shell=True,  # noqa: S602
            check=False,
            executable=executable,
            capture_output=True,
        )

    # Evaluate polymer metrics
    _run_comparison(OST_COMPARE_STRUCTURE, Path(outdir) / f"{name}.json", "protein")
    # Evaluate ligand metrics
    _run_comparison(OST_COMPARE_LIGAND, Path(outdir) / f"{name}_ligand.json", "ligand")
93
+
94
+
95
def main(args):
    """Evaluate every prediction (5 models per target) against its reference.

    The first job runs synchronously so the docker image is pulled once
    before the thread pool starts; the rest run in parallel.
    """
    # Aggregate the predictions and references
    files = list(args.data.iterdir())
    names = {f.stem.lower(): f for f in files}

    # Create the output directory
    args.outdir.mkdir(parents=True, exist_ok=True)

    first_item = True
    with concurrent.futures.ThreadPoolExecutor(args.max_workers) as executor:
        futures = []
        for name, folder in names.items():
            for model_id in range(5):
                # Locate the model file according to each tool's layout.
                if args.format == "af3":
                    pred_path = folder / f"seed-1_sample-{model_id}" / "model.cif"
                elif args.format == "chai":
                    pred_path = folder / f"pred.model_idx_{model_id}.cif"
                elif args.format == "boltz":
                    # CASP target names are capitalized; PDB test names lowercased.
                    name_file = (
                        f"{name[0].upper()}{name[1:]}"
                        if args.testset == "casp"
                        else name.lower()
                    )
                    pred_path = folder / f"{name_file}_model_{model_id}.cif"
                else:
                    # Fail loudly instead of leaving pred_path unbound.
                    msg = f"Unknown format: {args.format}"
                    raise ValueError(msg)

                if args.testset == "casp":
                    ref_path = args.pdb / f"{name[0].upper()}{name[1:]}.cif"
                elif args.testset == "test":
                    ref_path = args.pdb / f"{name.lower()}.cif.gz"
                else:
                    msg = f"Unknown testset: {args.testset}"
                    raise ValueError(msg)

                if first_item:
                    # Evaluate the first item synchronously so the docker
                    # image is downloaded before the parallel workers start.
                    evaluate_structure(
                        name=f"{name}_model_{model_id}",
                        pred=str(pred_path),
                        reference=str(ref_path),
                        outdir=str(args.outdir),
                        mount=args.mount,
                        executable=args.executable,
                    )
                    first_item = False
                else:
                    futures.append(
                        executor.submit(
                            evaluate_structure,
                            name=f"{name}_model_{model_id}",
                            pred=str(pred_path),
                            reference=str(ref_path),
                            outdir=str(args.outdir),
                            mount=args.mount,
                            executable=args.executable,
                        )
                    )

        # Wait for all tasks to complete
        with tqdm(total=len(futures)) as pbar:
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)
155
+
156
+ if __name__ == "__main__":
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument("data", type=Path)
159
+ parser.add_argument("pdb", type=Path)
160
+ parser.add_argument("outdir", type=Path)
161
+ parser.add_argument("--format", type=str, default="af3")
162
+ parser.add_argument("--testset", type=str, default="casp")
163
+ parser.add_argument("--mount", type=str)
164
+ parser.add_argument("--executable", type=str, default="/bin/bash")
165
+ parser.add_argument("--max-workers", type=int, default=32)
166
+ args = parser.parse_args()
167
+ main(args)
protify/FastPLMs/boltz/scripts/process/ccd.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compute conformers and symmetries for all the CCD molecules."""
2
+
3
+ import argparse
4
+ import multiprocessing
5
+ import pickle
6
+ import sys
7
+ from functools import partial
8
+ from pathlib import Path
9
+
10
+ import pandas as pd
11
+ import rdkit
12
+ from p_tqdm import p_uimap
13
+ from pdbeccdutils.core import ccd_reader
14
+ from pdbeccdutils.core.component import ConformerType
15
+ from rdkit import rdBase
16
+ from rdkit.Chem import AllChem
17
+ from rdkit.Chem.rdchem import Conformer, Mol
18
+ from tqdm import tqdm
19
+
20
+
21
def load_molecules(components: str) -> list[Mol]:
    """Load the CCD components file.

    Parameters
    ----------
    components : str
        Path to the CCD components file.

    Returns
    -------
    list[Mol]
        The parsed component molecules, each tagged with its PDB name
        via the "PDB_NAME" property.

    """
    parsed: dict[str, ccd_reader.CCDReaderResult]
    parsed = ccd_reader.read_pdb_components_file(components)

    molecules = []
    for pdb_name, entry in parsed.items():
        rdkit_mol = entry.component.mol
        rdkit_mol.SetProp("PDB_NAME", pdb_name)
        molecules.append(rdkit_mol)

    return molecules
44
+
45
+
46
def compute_3d(mol: Mol, version: str = "v3") -> bool:
    """Generate 3D coordinates using the ETKDG method.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to process
    version: str, optional
        The ETKDG version, defaults to v3

    Returns
    -------
    bool
        Whether computation was successful.

    """
    # Any version other than "v3" falls back to ETKDGv2,
    # matching the original behavior for unknown version strings.
    if version == "v3":
        params = rdkit.Chem.AllChem.ETKDGv3()
    else:
        params = rdkit.Chem.AllChem.ETKDGv2()

    params.clearConfs = False

    conformer_id = -1
    try:
        conformer_id = rdkit.Chem.AllChem.EmbedMolecule(mol, params)
        rdkit.Chem.AllChem.UFFOptimizeMolecule(
            mol, confId=conformer_id, maxIters=1000
        )
    except RuntimeError:
        pass  # Force field issue here
    except ValueError:
        pass  # sanitization issue here

    if conformer_id == -1:
        return False

    # Tag the new conformer so get_conformer() can find it later.
    conformer = mol.GetConformer(conformer_id)
    conformer.SetProp("name", ConformerType.Computed.name)
    conformer.SetProp("coord_generation", f"ETKDG{version}")
    return True
91
+
92
+
93
def get_conformer(mol: Mol, c_type: ConformerType) -> Conformer:
    """Retrieve an rdkit object for a deemed conformer.

    Taken from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The molecule to process.
    c_type: ConformerType
        The conformer type to extract.

    Returns
    -------
    Conformer
        The desired conformer, if any.

    Raises
    ------
    ValueError
        If there are no conformers of the given type.

    """
    wanted = c_type.name
    for conformer in mol.GetConformers():
        # Conformers without a "name" property raise KeyError; skip them.
        try:
            found = conformer.GetProp("name") == wanted
        except KeyError:
            continue
        if found:
            return conformer

    raise ValueError(f"Conformer {wanted} does not exist.")
125
+
126
+
127
def compute_symmetries(mol: Mol) -> list[list[int]]:
    """Compute the symmetries of a molecule.

    Uses substructure self-matching to enumerate atom permutations
    that map the heavy-atom (non-leaving) skeleton onto itself.

    Parameters
    ----------
    mol : Mol
        The molecule to process

    Returns
    -------
    list[list[int]]
        The symmetries as a list of index permutations

    """
    mol = AllChem.RemoveHs(mol)
    # Map original atom indices to compact indices over non-leaving atoms.
    idx_map = {}
    atom_idx = 0
    for i, atom in enumerate(mol.GetAtoms()):
        # Skip if leaving atoms
        if int(atom.GetProp("leaving_atom")):
            continue
        idx_map[i] = atom_idx
        atom_idx += 1

    # Calculate self permutations
    permutations = []
    raw_permutations = mol.GetSubstructMatches(mol, uniquify=False)
    for raw_permutation in raw_permutations:
        # Filter out permutations with leaving atoms
        try:
            # Keep only permutations that map the non-leaving atom set
            # onto itself, and re-index them through idx_map.
            if {raw_permutation[idx] for idx in idx_map} == set(idx_map.keys()):
                permutation = [
                    idx_map[idx] for idx in raw_permutation if idx in idx_map
                ]
                permutations.append(permutation)
        except Exception:  # noqa: S110, PERF203, BLE001
            pass
    # NOTE(review): the "symmetries" property is set on the H-stripped copy
    # returned by RemoveHs, not on the caller's molecule — confirm intended.
    serialized_permutations = pickle.dumps(permutations)
    mol.SetProp("symmetries", serialized_permutations.hex())
    return permutations
167
+
168
+
169
def process(mol: Mol, output: str) -> tuple[str, str]:
    """Process a CCD component.

    Parameters
    ----------
    mol : Mol
        The molecule to process
    output : str
        The directory to save the molecules

    Returns
    -------
    str
        The name of the component
    str
        The result of the conformer generation

    """
    name = mol.GetProp("PDB_NAME")

    if mol.GetNumAtoms() == 1:
        # Single-atom components need no conformer generation.
        result = "single"
    else:
        try:
            # Prefer a freshly computed 3D conformer; fall back to the
            # ideal coordinates shipped with the CCD entry.
            if compute_3d(mol, version="v3"):
                _ = get_conformer(mol, ConformerType.Computed)
                result = "computed"
            else:
                _ = get_conformer(mol, ConformerType.Ideal)
                result = "ideal"
        except ValueError:
            result = "failed"

    # Persist the molecule for later aggregation.
    with (Path(output) / f"{name}.pkl").open("wb") as handle:
        pickle.dump(mol, handle)

    return name, result
216
+
217
+
218
def main(args: argparse.Namespace) -> None:
    """Process conformers.

    Loads the CCD components file, generates a conformer for each
    component (in parallel when possible), writes per-component pickles,
    a results CSV, and an aggregate ``ccd.pkl`` of all successful mols.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments: components, outdir, num_processes.

    """
    # Set property saving
    rdkit.Chem.SetDefaultPickleProperties(rdkit.Chem.PropertyPickleOptions.AllProps)

    # Load components
    print("Loading components")  # noqa: T201
    molecules = load_molecules(args.components)

    # Reset stdout and stderr, as pdbccdutils messes with them
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__

    # Disable rdkit warnings (blocker must stay alive for the duration)
    blocker = rdBase.BlockLogs()  # noqa: F841

    # Setup processing function
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    mol_output = outdir / "mols"
    mol_output.mkdir(parents=True, exist_ok=True)
    process_fn = partial(process, output=str(mol_output))

    # Process the files in parallel
    print("Processing components")  # noqa: T201
    metadata = []

    # Check if we can run in parallel
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(molecules)))
    parallel = num_processes > 1

    if parallel:
        # p_uimap yields results in completion order, not input order
        for name, result in p_uimap(
            process_fn,
            molecules,
            num_cpus=num_processes,
        ):
            metadata.append({"name": name, "result": result})
    else:
        for mol in tqdm(molecules):
            name, result = process_fn(mol)
            metadata.append({"name": name, "result": result})

    # Load and group outputs, skipping failed conformer generations
    molecules = {}
    for item in metadata:
        if item["result"] == "failed":
            continue

        # Load the mol file
        path = mol_output / f"{item['name']}.pkl"
        with path.open("rb") as f:
            mol = pickle.load(f)  # noqa: S301
        molecules[item["name"]] = mol

    # Dump metadata
    path = outdir / "results.csv"
    metadata = pd.DataFrame(metadata)
    metadata.to_csv(path)

    # Dump the components
    path = outdir / "ccd.pkl"
    with path.open("wb") as f:
        pickle.dump(molecules, f)
283
+
284
+
285
if __name__ == "__main__":
    # Command-line entry point.
    cli = argparse.ArgumentParser()
    cli.add_argument("--components", type=str)
    cli.add_argument("--outdir", type=str)
    cli.add_argument(
        "--num_processes",
        type=int,
        default=multiprocessing.cpu_count(),
    )
    main(cli.parse_args())
protify/FastPLMs/boltz/scripts/process/cluster.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Create a mapping from structure and chain ID to MSA indices."""
2
+
3
+ import argparse
4
+ import hashlib
5
+ import json
6
+ import pickle
7
+ import subprocess
8
+ from pathlib import Path
9
+
10
+ import pandas as pd
11
+ from Bio import SeqIO
12
+
13
+
14
def hash_sequence(seq: str) -> str:
    """Return the SHA-256 hex digest of a sequence string."""
    digest = hashlib.sha256()
    digest.update(seq.encode())
    return digest.hexdigest()
17
+
18
+
19
def main(args: argparse.Namespace) -> None:
    """Create clustering.

    Groups protein sequences with ``mmseqs easy-cluster``; each short
    sequence, nucleotide sequence, and ligand CCD code becomes its own
    singleton cluster. Writes ``clustering.json`` mapping item id to
    cluster id.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments: sequences, ccd, outdir, mmseqs.

    """
    # Set output directory
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Split the sequences into proteins and nucleotides
    with Path(args.sequences).open("r") as f:
        data = list(SeqIO.parse(f, "fasta"))

    proteins = set()
    shorts = set()
    nucleotides = set()

    # Separate the sequences into proteins, nucleotides and short sequences
    # Short sequences cause a bug in the clustering, so they are separated
    for seq in data:
        seq_str = str(seq.seq).strip()
        if set(str(seq.seq)).issubset({"A", "C", "G", "T", "U", "N"}):
            nucleotides.add(seq_str)
        elif len(seq_str) < 10:  # noqa: PLR2004
            shorts.add(seq_str)
        else:
            proteins.add(seq_str)

    # Run mmseqs on the protein data
    proteins = [f">{hash_sequence(seq)}\n{seq}" for seq in proteins]
    with (outdir / "proteins.fasta").open("w") as f:
        f.write("\n".join(proteins))

    # Use an argument list (no shell) so paths containing spaces are
    # handled safely and shell injection is impossible.
    subprocess.run(
        [
            args.mmseqs,
            "easy-cluster",
            str(outdir / "proteins.fasta"),
            str(outdir / "clust_prot"),
            str(outdir / "tmp"),
            "--min-seq-id",
            "0.4",
        ],
        check=True,
    )

    # Load protein clusters: column 0 is the cluster representative,
    # column 1 is the member id.
    clustering_path = outdir / "clust_prot_cluster.tsv"
    protein_data = pd.read_csv(clustering_path, sep="\t", header=None)
    clustering = dict(zip(protein_data[1], protein_data[0]))

    # Each short sequence is given its own singleton cluster id
    for short in shorts:
        short_id = hash_sequence(short)
        clustering[short_id] = short_id

    # Each unique nucleotide sequence is given its own singleton cluster id
    for nucl in nucleotides:
        nucl_id = hash_sequence(nucl)
        clustering[nucl_id] = nucl_id

    # Load ligand data
    with Path(args.ccd).open("rb") as handle:
        ligand_data = pickle.load(handle)  # noqa: S301

    # Each unique ligand CCD code is given its own singleton cluster id
    for ccd_code in ligand_data:
        clustering[ccd_code] = ccd_code

    # Save clustering
    with (outdir / "clustering.json").open("w") as handle:
        json.dump(clustering, handle)
82
+
83
+
84
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sequences",
        type=str,
        help="Path to the input sequences fasta file.",
        required=True,
    )
    parser.add_argument(
        "--ccd",
        type=str,
        # Fixed help text: this is the CCD pickle, not an RNA fasta.
        help="Path to the processed CCD components pickle.",
        required=True,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Output directory.",
        required=True,
    )
    parser.add_argument(
        "--mmseqs",
        type=str,
        help="Path to mmseqs program.",
        default="mmseqs",
    )
    args = parser.parse_args()
    main(args)
protify/FastPLMs/boltz/scripts/process/mmcif.py ADDED
@@ -0,0 +1,1123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from dataclasses import dataclass, replace
3
+ from typing import Optional
4
+
5
+ import gemmi
6
+ import numpy as np
7
+ from rdkit import rdBase
8
+ from rdkit.Chem import AllChem
9
+ from rdkit.Chem.rdchem import Conformer, Mol
10
+ from sklearn.neighbors import KDTree
11
+
12
+ from boltz.data import const
13
+ from boltz.data.types import (
14
+ Atom,
15
+ Bond,
16
+ Chain,
17
+ Connection,
18
+ Interface,
19
+ Residue,
20
+ Structure,
21
+ StructureInfo,
22
+ )
23
+
24
+ ####################################################################################################
25
+ # DATACLASSES
26
+ ####################################################################################################
27
+
28
+
29
@dataclass(frozen=True, slots=True)
class ParsedAtom:
    """A parsed atom object."""

    name: str  # Atom name, e.g. "CA"
    element: int  # Atomic number
    charge: int  # Formal charge
    coords: tuple[float, float, float]  # Structure coordinates, (0, 0, 0) if absent
    conformer: tuple[float, float, float]  # Reference conformer coordinates
    is_present: bool  # Whether the atom is resolved in the structure
    chirality: int  # Chirality type id (from const.chirality_type_ids)
40
+
41
+
42
@dataclass(frozen=True, slots=True)
class ParsedBond:
    """A parsed bond object."""

    atom_1: int  # Index of the first atom, local to the residue
    atom_2: int  # Index of the second atom, local to the residue
    type: int  # Bond type id (from const.bond_type_ids)
49
+
50
+
51
@dataclass(frozen=True, slots=True)
class ParsedResidue:
    """A parsed residue object."""

    name: str  # Residue / CCD component name
    type: int  # Residue token id
    idx: int  # Residue index within the parsed chain
    atoms: list[ParsedAtom]  # Atoms belonging to this residue
    bonds: list[ParsedBond]  # Intra-residue bonds
    # NOTE(review): parse_ccd_residue stores a string id (num + icode)
    # here, despite the Optional[int] annotation — confirm intended type.
    orig_idx: Optional[int]  # Original author residue id, None if unresolved
    atom_center: int  # Center atom index (0 placeholder for ligands)
    atom_disto: int  # Distogram atom index (0 placeholder for ligands)
    is_standard: bool  # Whether this is a standard token residue
    is_present: bool  # Whether the residue is resolved in the structure
65
+
66
+
67
@dataclass(frozen=True, slots=True)
class ParsedChain:
    """A parsed chain object."""

    name: str  # Chain identifier
    entity: str  # Entity name
    type: str  # Chain type label
    residues: list[ParsedResidue]  # Residues, in sequence order
    sequence: list[str]  # Full entity sequence (one residue name per position)
76
+
77
+
78
@dataclass(frozen=True, slots=True)
class ParsedConnection:
    """A parsed connection (inter-residue link) object."""

    chain_1: str  # Name of the first chain
    chain_2: str  # Name of the second chain
    residue_index_1: int  # Residue index in the first chain
    residue_index_2: int  # Residue index in the second chain
    # NOTE(review): the two fields below are annotated str although named
    # "index" — confirm whether they hold atom names or indices.
    atom_index_1: str
    atom_index_2: str
88
+
89
+
90
@dataclass(frozen=True, slots=True)
class ParsedStructure:
    """A parsed structure object."""

    data: Structure  # The parsed structure arrays
    info: StructureInfo  # Structure-level metadata
    # NOTE(review): element semantics of covalents are not visible in this
    # file — confirm what the ints identify before relying on them.
    covalents: list[int]
97
+
98
+
99
+ ####################################################################################################
100
+ # HELPERS
101
+ ####################################################################################################
102
+
103
+
104
def get_dates(block: gemmi.cif.Block) -> tuple[str, str, str]:
    """Get the deposited, released, and last revision dates.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The deposited date.
    str
        The released date.
    str
        The last revision date.

    """
    deposit_key = "_pdbx_database_status.recvd_initial_deposition_date"
    revision_key = "_pdbx_audit_revision_history.revision_date"

    deposit_date = ""
    release_date = ""
    revision_date = ""
    # A single suppress: if any lookup fails, the remaining dates stay "".
    with contextlib.suppress(Exception):
        deposit_date = block.find([deposit_key])[0][0]
        # First revision entry is the release, the last is the latest.
        release_date = block.find([revision_key])[0][0]
        revision_date = block.find([revision_key])[-1][0]

    return deposit_date, release_date, revision_date
131
+
132
+
133
def get_resolution(block: gemmi.cif.Block) -> float:
    """Get the resolution from a gemmi structure.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    float
        The resolution, or 0.0 if no known key is present.

    """
    # Try each known resolution key in priority order.
    candidate_keys = (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    )
    for key in candidate_keys:
        try:
            return float(block.find([key])[0].str(0))
        except Exception:  # noqa: BLE001
            continue
    return 0.0
157
+
158
+
159
def get_method(block: gemmi.cif.Block) -> str:
    """Get the method from a gemmi structure.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The experimental method(s), lower-cased and comma-separated.

    """
    result = ""
    with contextlib.suppress(Exception):
        rows = block.find(["_exptl.method"])
        result = ",".join(row.str(0).lower() for row in rows)
    return result
180
+
181
+
182
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Convert an atom name to a standard format.

    Each character becomes its ASCII code minus 32, and the result is
    zero-padded to exactly four entries.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The converted atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    padding = [0] * (4 - len(codes))
    return tuple(codes + padding)
200
+
201
+
202
def get_unk_token(dtype: gemmi.PolymerType) -> str:
    """Get the unknown token for a given polymer type.

    Parameters
    ----------
    dtype : gemmi.PolymerType
        The polymer type.

    Returns
    -------
    str
        The unknown token.

    Raises
    ------
    ValueError
        If the polymer type is not protein, DNA, or RNA.

    """
    if dtype == gemmi.PolymerType.PeptideL:
        return const.unk_token["PROTEIN"]
    if dtype == gemmi.PolymerType.Dna:
        return const.unk_token["DNA"]
    if dtype == gemmi.PolymerType.Rna:
        return const.unk_token["RNA"]
    raise ValueError(f"Unknown polymer type: {dtype}")
227
+
228
+
229
def get_conformer(mol: Mol) -> Conformer:
    """Retrieve an rdkit object for a deemed conformer.

    Inspired by `pdbeccdutils.core.component.Component`.

    Preference order: a "Computed" conformer first, then an "Ideal" one.

    Parameters
    ----------
    mol: Mol
        The molecule to process.

    Returns
    -------
    Conformer
        The desired conformer, if any.

    Raises
    ------
    ValueError
        If no "Computed" or "Ideal" conformer exists.

    """
    for wanted in ("Computed", "Ideal"):
        for conformer in mol.GetConformers():
            # Conformers lacking a "name" property raise KeyError; skip.
            try:
                label = conformer.GetProp("name")
            except KeyError:
                continue
            if label == wanted:
                return conformer

    raise ValueError("Conformer does not exist.")
266
+
267
+
268
def compute_covalent_ligands(
    connections: list[gemmi.Connection],
    subchain_map: dict[tuple[str, int], str],
    entities: dict[str, gemmi.Entity],
) -> set[str]:
    """Compute the covalent ligands from a list of connections.

    Only connections of type "Covale" are considered; a subchain is
    flagged when its entity type is NonPolymer or Branched.

    Parameters
    ----------
    connections: List[gemmi.Connection]
        The connections to process.
    subchain_map: dict[tuple[str, int], str]
        The mapping from chain, residue index to subchain name.
    entities: dict[str, gemmi.Entity]
        The entities in the structure.

    Returns
    -------
    set
        The covalent ligand subchains.

    """
    # Get covalent chain ids
    covalent_chain_ids = set()
    for connection in connections:
        if connection.type.name != "Covale":
            continue

        # Map to correct subchain
        chain_1_name = connection.partner1.chain_name
        chain_2_name = connection.partner2.chain_name

        # Residue ids are built as "<num><icode>" strings; this must match
        # how subchain_map keys are built by the caller.
        res_1_id = connection.partner1.res_id.seqid
        res_1_id = str(res_1_id.num) + str(res_1_id.icode).strip()

        res_2_id = connection.partner2.res_id.seqid
        res_2_id = str(res_2_id.num) + str(res_2_id.icode).strip()

        subchain_1 = subchain_map[(chain_1_name, res_1_id)]
        subchain_2 = subchain_map[(chain_2_name, res_2_id)]

        # If non-polymer or branched, add to set
        entity_1 = entities[subchain_1].entity_type.name
        entity_2 = entities[subchain_2].entity_type.name

        if entity_1 in {"NonPolymer", "Branched"}:
            covalent_chain_ids.add(subchain_1)
        if entity_2 in {"NonPolymer", "Branched"}:
            covalent_chain_ids.add(subchain_2)

    return covalent_chain_ids
319
+
320
+
321
def compute_interfaces(atom_data: np.ndarray, chain_data: np.ndarray) -> np.ndarray:
    """Compute the chain-chain interfaces from a gemmi structure.

    Two chains form an interface when any pair of their present atoms
    lies within `const.atom_interface_cutoff` of each other.

    Parameters
    ----------
    atom_data : np.ndarray
        Structured atom array with "coords" and "is_present" fields.
    chain_data : np.ndarray
        Structured chain array with an "atom_num" field.

    Returns
    -------
    np.ndarray
        The interfaces, as a structured array of unique chain-index pairs.

    """
    # Compute chain_id per atom
    chain_ids = []
    for idx, chain in enumerate(chain_data):
        chain_ids.extend([idx] * chain["atom_num"])
    chain_ids = np.array(chain_ids)

    # Filter to present atoms
    coords = atom_data["coords"]
    mask = atom_data["is_present"]

    coords = coords[mask]
    chain_ids = chain_ids[mask]

    # Radius query over all atoms within the interface cutoff
    tree = KDTree(coords, metric="euclidean")
    query = tree.query_radius(coords, const.atom_interface_cutoff)

    # Collect chain pairs in contact (excluding self-contacts)
    interfaces = set()
    for c1, pairs in zip(chain_ids, query):
        chains = np.unique(chain_ids[pairs])
        chains = chains[chains != c1]
        interfaces.update((c1, c2) for c2 in chains)

    # Deduplicate as ordered (min, max) pairs
    interfaces = [(min(i, j), max(i, j)) for i, j in interfaces]
    interfaces = list({(int(i), int(j)) for i, j in interfaces})
    interfaces = np.array(interfaces, dtype=Interface)
    return interfaces
366
+
367
+
368
+ ####################################################################################################
369
+ # PARSING
370
+ ####################################################################################################
371
+
372
+
373
def parse_ccd_residue(  # noqa: PLR0915, C901
    name: str,
    components: dict[str, Mol],
    res_idx: int,
    gemmi_mol: Optional[gemmi.Residue] = None,
    is_covalent: bool = False,
) -> Optional[ParsedResidue]:
    """Parse an MMCIF ligand.

    First tries to get the SMILES string from the RCSB.
    Then, tries to infer atom ordering using RDKit.

    Parameters
    ----------
    name: str
        The name of the molecule to parse.
    components : dict
        The preprocessed PDB components dictionary.
    res_idx : int
        The residue index.
    gemmi_mol : Optional[gemmi.Residue]
        The PDB molecule, as a gemmi Residue object, if any.
    is_covalent : bool
        Whether the residue is covalently bound; if so, leaving atoms
        absent from the PDB entry are skipped.

    Returns
    -------
    ParsedResidue, optional
        The output ParsedResidue, if successful.

    """
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
    # Check if we have a PDB structure for this residue,
    # it could be a missing residue from the sequence
    is_present = gemmi_mol is not None

    # Save original index (required for parsing connections)
    if is_present:
        orig_idx = gemmi_mol.seqid
        orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
    else:
        orig_idx = None

    # Get reference component
    ref_mol = components[name]

    # Remove hydrogens
    ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)

    # Check if this is a single atom CCD residue
    if ref_mol.GetNumAtoms() == 1:
        pos = (0, 0, 0)
        if is_present:
            pos = (
                gemmi_mol[0].pos.x,
                gemmi_mol[0].pos.y,
                gemmi_mol[0].pos.z,
            )
        ref_atom = ref_mol.GetAtoms()[0]
        chirality_type = const.chirality_type_ids.get(
            str(ref_atom.GetChiralTag()), unk_chirality
        )
        atom = ParsedAtom(
            name=ref_atom.GetProp("name"),
            element=ref_atom.GetAtomicNum(),
            charge=ref_atom.GetFormalCharge(),
            coords=pos,
            conformer=(0, 0, 0),
            is_present=is_present,
            chirality=chirality_type,
        )
        unk_prot_id = const.unk_token_ids["PROTEIN"]
        residue = ParsedResidue(
            name=name,
            type=unk_prot_id,
            atoms=[atom],
            bonds=[],
            idx=res_idx,
            orig_idx=orig_idx,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=is_present,
        )
        return residue

    # If multi-atom, start by getting the PDB coordinates
    pdb_pos = {}
    if is_present:
        # Match atoms based on names
        for atom in gemmi_mol:
            atom: gemmi.Atom
            pos = (atom.pos.x, atom.pos.y, atom.pos.z)
            pdb_pos[atom.name] = pos

    # Get reference conformer coordinates
    conformer = get_conformer(ref_mol)

    # Parse each atom in order of the reference mol
    atoms = []
    atom_idx = 0
    idx_map = {}  # Used for bonds later

    for i, atom in enumerate(ref_mol.GetAtoms()):
        # Get atom name, charge, element and reference coordinates
        atom_name = atom.GetProp("name")
        charge = atom.GetFormalCharge()
        element = atom.GetAtomicNum()
        ref_coords = conformer.GetAtomPosition(atom.GetIdx())
        ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
        chirality_type = const.chirality_type_ids.get(
            str(atom.GetChiralTag()), unk_chirality
        )

        # If the atom is a leaving atom, skip if not in the PDB and is_covalent
        if (
            int(atom.GetProp("leaving_atom")) == 1
            and is_covalent
            and (atom_name not in pdb_pos)
        ):
            continue

        # Get PDB coordinates, if any
        coords = pdb_pos.get(atom_name)
        if coords is None:
            atom_is_present = False
            coords = (0, 0, 0)
        else:
            atom_is_present = True

        # Add atom to list
        atoms.append(
            ParsedAtom(
                name=atom_name,
                element=element,
                charge=charge,
                coords=coords,
                conformer=ref_coords,
                is_present=atom_is_present,
                chirality=chirality_type,
            )
        )
        idx_map[i] = atom_idx
        atom_idx += 1

    # Load bonds, re-indexed to the kept (non-skipped) atoms
    bonds = []
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    for bond in ref_mol.GetBonds():
        idx_1 = bond.GetBeginAtomIdx()
        idx_2 = bond.GetEndAtomIdx()

        # Skip bonds with atoms ignored
        if (idx_1 not in idx_map) or (idx_2 not in idx_map):
            continue

        idx_1 = idx_map[idx_1]
        idx_2 = idx_map[idx_2]
        start = min(idx_1, idx_2)
        end = max(idx_1, idx_2)
        bond_type = bond.GetBondType().name
        bond_type = const.bond_type_ids.get(bond_type, unk_bond)
        bonds.append(ParsedBond(start, end, bond_type))

    unk_prot_id = const.unk_token_ids["PROTEIN"]
    return ParsedResidue(
        name=name,
        type=unk_prot_id,
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=orig_idx,
        is_standard=False,
        is_present=is_present,
    )
548
+
549
+
550
+ def parse_polymer( # noqa: C901, PLR0915, PLR0912
551
+ polymer: gemmi.ResidueSpan,
552
+ polymer_type: gemmi.PolymerType,
553
+ sequence: list[str],
554
+ chain_id: str,
555
+ entity: str,
556
+ components: dict[str, Mol],
557
+ ) -> Optional[ParsedChain]:
558
+ """Process a gemmi Polymer into a chain object.
559
+
560
+ Performs alignment of the full sequence to the polymer
561
+ residues. Loads coordinates and masks for the atoms in
562
+ the polymer, following the ordering in const.atom_order.
563
+
564
+ Parameters
565
+ ----------
566
+ polymer : gemmi.ResidueSpan
567
+ The polymer to process.
568
+ polymer_type : gemmi.PolymerType
569
+ The polymer type.
570
+ sequence : str
571
+ The full sequence of the polymer.
572
+ chain_id : str
573
+ The chain identifier.
574
+ entity : str
575
+ The entity name.
576
+ components : dict[str, Mol]
577
+ The preprocessed PDB components dictionary.
578
+
579
+ Returns
580
+ -------
581
+ ParsedChain, optional
582
+ The output chain, if successful.
583
+
584
+ Raises
585
+ ------
586
+ ValueError
587
+ If the alignment fails.
588
+
589
+ """
590
+ # Get unknown chirality token
591
+ unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
592
+
593
+ # Ignore microheterogenities (pick first)
594
+ sequence = [gemmi.Entity.first_mon(item) for item in sequence]
595
+
596
+ # Align full sequence to polymer residues
597
+ # This is a simple way to handle all the different numbering schemes
598
+ result = gemmi.align_sequence_to_polymer(
599
+ sequence,
600
+ polymer,
601
+ polymer_type,
602
+ gemmi.AlignmentScoring(),
603
+ )
604
+
605
+ # Get coordinates and masks
606
+ i = 0
607
+ ref_res = set(const.tokens)
608
+ parsed = []
609
+ for j, match in enumerate(result.match_string):
610
+ # Get residue name from sequence
611
+ res_name = sequence[j]
612
+
613
+ # Check if we have a match in the structure
614
+ res = None
615
+ name_to_atom = {}
616
+
617
+ if match == "|":
618
+ # Get pdb residue
619
+ res = polymer[i]
620
+ name_to_atom = {a.name.upper(): a for a in res}
621
+
622
+ # Double check the match
623
+ if res.name != res_name:
624
+ msg = "Alignment mismatch!"
625
+ raise ValueError(msg)
626
+
627
+ # Increment polymer index
628
+ i += 1
629
+
630
+ # Map MSE to MET, put the selenium atom in the sulphur column
631
+ if res_name == "MSE":
632
+ res_name = "MET"
633
+ if "SE" in name_to_atom:
634
+ name_to_atom["SD"] = name_to_atom["SE"]
635
+
636
+ # Handle non-standard residues
637
+ elif res_name not in ref_res:
638
+ residue = parse_ccd_residue(
639
+ name=res_name,
640
+ components=components,
641
+ res_idx=j,
642
+ gemmi_mol=res,
643
+ is_covalent=True,
644
+ )
645
+ parsed.append(residue)
646
+ continue
647
+
648
+ # Load regular residues
649
+ ref_mol = components[res_name]
650
+ ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
651
+ ref_conformer = get_conformer(ref_mol)
652
+
653
+ # Only use reference atoms set in constants
654
+ ref_name_to_atom = {a.GetProp("name"): a for a in ref_mol.GetAtoms()}
655
+ ref_atoms = [ref_name_to_atom[a] for a in const.ref_atoms[res_name]]
656
+
657
+ # Iterate, always in the same order
658
+ atoms: list[ParsedAtom] = []
659
+
660
+ for ref_atom in ref_atoms:
661
+ # Get atom name
662
+ atom_name = ref_atom.GetProp("name")
663
+ idx = ref_atom.GetIdx()
664
+
665
+ # Get conformer coordinates
666
+ ref_coords = ref_conformer.GetAtomPosition(idx)
667
+ ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
668
+
669
+ # Get coordinated from PDB
670
+ if atom_name in name_to_atom:
671
+ atom = name_to_atom[atom_name]
672
+ atom_is_present = True
673
+ coords = (atom.pos.x, atom.pos.y, atom.pos.z)
674
+ else:
675
+ atom_is_present = False
676
+ coords = (0, 0, 0)
677
+
678
+ # Add atom to list
679
+ atoms.append(
680
+ ParsedAtom(
681
+ name=atom_name,
682
+ element=ref_atom.GetAtomicNum(),
683
+ charge=ref_atom.GetFormalCharge(),
684
+ coords=coords,
685
+ conformer=ref_coords,
686
+ is_present=atom_is_present,
687
+ chirality=const.chirality_type_ids.get(
688
+ str(ref_atom.GetChiralTag()), unk_chirality
689
+ ),
690
+ )
691
+ )
692
+
693
+ # Fix naming errors in arginine residues where NH2 is
694
+ # incorrectly assigned to be closer to CD than NH1
695
+ if (res is not None) and (res_name == "ARG"):
696
+ ref_atoms: list[str] = const.ref_atoms["ARG"]
697
+ cd = atoms[ref_atoms.index("CD")]
698
+ nh1 = atoms[ref_atoms.index("NH1")]
699
+ nh2 = atoms[ref_atoms.index("NH2")]
700
+
701
+ cd_coords = np.array(cd.coords)
702
+ nh1_coords = np.array(nh1.coords)
703
+ nh2_coords = np.array(nh2.coords)
704
+
705
+ if all(atom.is_present for atom in (cd, nh1, nh2)) and (
706
+ np.linalg.norm(nh1_coords - cd_coords)
707
+ > np.linalg.norm(nh2_coords - cd_coords)
708
+ ):
709
+ atoms[ref_atoms.index("NH1")] = replace(nh1, coords=nh2.coords)
710
+ atoms[ref_atoms.index("NH2")] = replace(nh2, coords=nh1.coords)
711
+
712
+ # Add residue to parsed list
713
+ if res is not None:
714
+ orig_idx = res.seqid
715
+ orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
716
+ else:
717
+ orig_idx = None
718
+
719
+ atom_center = const.res_to_center_atom_id[res_name]
720
+ atom_disto = const.res_to_disto_atom_id[res_name]
721
+ parsed.append(
722
+ ParsedResidue(
723
+ name=res_name,
724
+ type=const.token_ids[res_name],
725
+ atoms=atoms,
726
+ bonds=[],
727
+ idx=j,
728
+ atom_center=atom_center,
729
+ atom_disto=atom_disto,
730
+ is_standard=True,
731
+ is_present=res is not None,
732
+ orig_idx=orig_idx,
733
+ )
734
+ )
735
+
736
+ # Get polymer class
737
+ if polymer_type == gemmi.PolymerType.PeptideL:
738
+ chain_type = const.chain_type_ids["PROTEIN"]
739
+ elif polymer_type == gemmi.PolymerType.Dna:
740
+ chain_type = const.chain_type_ids["DNA"]
741
+ elif polymer_type == gemmi.PolymerType.Rna:
742
+ chain_type = const.chain_type_ids["RNA"]
743
+
744
+ # Return polymer object
745
+ return ParsedChain(
746
+ name=chain_id,
747
+ entity=entity,
748
+ residues=parsed,
749
+ type=chain_type,
750
+ sequence=gemmi.one_letter_code(sequence),
751
+ )
752
+
753
+
754
def parse_connection(
    connection: gemmi.Connection,
    chains: list[ParsedChain],
    subchain_map: dict[tuple[str, int], str],
) -> ParsedConnection:
    """Parse a (covalent) connection from a gemmi Connection.

    Parameters
    ----------
    connection : gemmi.Connection
        The connection to parse.
    chains : list[ParsedChain]
        The parsed chains.
    subchain_map : dict[tuple[str, int], str]
        The mapping from (chain name, residue id) to subchain name.

    Returns
    -------
    ParsedConnection
        The parsed connection.

    Raises
    ------
    StopIteration
        If a referenced chain, residue or atom cannot be found.

    """

    def seqid_str(seqid: gemmi.SeqId) -> str:
        # Combine residue number with the stripped insertion code,
        # matching the keys used when building subchain_map.
        return str(seqid.num) + str(seqid.icode).strip()

    # Map the author chain/residue pairs to the correct subchains
    res_1_id = seqid_str(connection.partner1.res_id.seqid)
    res_2_id = seqid_str(connection.partner2.res_id.seqid)
    subchain_1 = subchain_map[(connection.partner1.chain_name, res_1_id)]
    subchain_2 = subchain_map[(connection.partner2.chain_name, res_2_id)]

    # Get the parsed chains by subchain name
    chain_1 = next(chain for chain in chains if chain.name == subchain_1)
    chain_2 = next(chain for chain in chains if chain.name == subchain_2)

    # Locate the residues by their original (author) index
    res_1_idx, res_1 = next(
        (idx, res)
        for idx, res in enumerate(chain_1.residues)
        if res.orig_idx == res_1_id
    )
    res_2_idx, res_2 = next(
        (idx, res)
        for idx, res in enumerate(chain_2.residues)
        if res.orig_idx == res_2_id
    )

    # Locate the bonded atoms within each residue
    atom_index_1 = next(
        idx
        for idx, atom in enumerate(res_1.atoms)
        if atom.name == connection.partner1.atom_name
    )
    atom_index_2 = next(
        idx
        for idx, atom in enumerate(res_2.atoms)
        if atom.name == connection.partner2.atom_name
    )

    return ParsedConnection(
        chain_1=subchain_1,
        chain_2=subchain_2,
        residue_index_1=res_1_idx,
        residue_index_2=res_2_idx,
        atom_index_1=atom_index_1,
        atom_index_2=atom_index_2,
    )
827
+
828
+
829
def parse_mmcif(  # noqa: C901, PLR0915, PLR0912
    path: str,
    components: dict[str, Mol],
    use_assembly: bool = True,
) -> ParsedStructure:
    """Parse a structure in MMCIF format.

    Parameters
    ----------
    path : str
        Path to the MMCIF file.
    components : dict[str, Mol]
        The preprocessed PDB components dictionary.
    use_assembly : bool
        Whether to expand the first assembly.

    Returns
    -------
    ParsedStructure
        The parsed structure.

    Raises
    ------
    ValueError
        If no chains could be parsed.

    """
    # Disable rdkit warnings for the duration of this call
    blocker = rdBase.BlockLogs()  # noqa: F841

    # Parse MMCIF input file
    block = gemmi.cif.read(str(path))[0]

    # Extract metadata
    deposit_date, release_date, revision_date = get_dates(block)
    resolution = get_resolution(block)
    method = get_method(block)

    # Load structure object
    structure = gemmi.make_structure_from_block(block)

    # Clean up the structure
    structure.merge_chain_parts()
    structure.remove_waters()
    structure.remove_hydrogens()
    structure.remove_alternative_conformations()
    structure.remove_empty_chains()

    # Expand assembly 1 (copied chains get a number appended to their name)
    if use_assembly and structure.assemblies:
        how = gemmi.HowToNameCopiedChain.AddNumber
        assembly_name = structure.assemblies[0].name
        structure.transform_to_assembly(assembly_name, how=how)

    # Parse entities
    # Create mapping from subchain id to entity, skipping waters
    entities: dict[str, gemmi.Entity] = {}
    entity_ids: dict[str, int] = {}
    for entity_id, entity in enumerate(structure.entities):
        entity: gemmi.Entity
        if entity.entity_type.name == "Water":
            continue
        for subchain_id in entity.subchains:
            entities[subchain_id] = entity
            entity_ids[subchain_id] = entity_id

    # Create mapping from (chain, residue id) to subchains,
    # since a Connection references chains and not subchains
    subchain_map = {}
    for chain in structure[0]:
        for residue in chain:
            seq_id = residue.seqid
            seq_id = str(seq_id.num) + str(seq_id.icode).strip()
            subchain_map[(chain.name, seq_id)] = residue.subchain

    # Find covalent ligands
    covalent_chain_ids = compute_covalent_ligands(
        connections=structure.connections,
        subchain_map=subchain_map,
        entities=entities,
    )

    # Parse chains
    chains: list[ParsedChain] = []
    chain_seqs = []  # collected sequences (not used further in this function)
    for raw_chain in structure[0].subchains():
        # Check chain type
        subchain_id = raw_chain.subchain_id()
        entity: gemmi.Entity = entities[subchain_id]
        entity_type = entity.entity_type.name

        # Parse a polymer
        if entity_type == "Polymer":
            # Skip PeptideD, DnaRnaHybrid, Pna, Other
            if entity.polymer_type.name not in {
                "PeptideL",
                "Dna",
                "Rna",
            }:
                continue

            # Add polymer if successful
            parsed_polymer = parse_polymer(
                polymer=raw_chain,
                polymer_type=entity.polymer_type,
                sequence=entity.full_sequence,
                chain_id=subchain_id,
                entity=entity.name,
                components=components,
            )
            if parsed_polymer is not None:
                chains.append(parsed_polymer)
                chain_seqs.append(parsed_polymer.sequence)

        # Parse a non-polymer
        elif entity_type in {"NonPolymer", "Branched"}:
            # Skip UNL or other ligands missing from the components dict
            if any(components.get(lig.name) is None for lig in raw_chain):
                continue

            residues = []
            for lig_idx, ligand in enumerate(raw_chain):
                # Check if ligand is covalent; branched entities always are
                if entity_type == "Branched":
                    is_covalent = True
                else:
                    is_covalent = subchain_id in covalent_chain_ids

                ligand: gemmi.Residue
                residue = parse_ccd_residue(
                    name=ligand.name,
                    components=components,
                    res_idx=lig_idx,
                    gemmi_mol=ligand,
                    is_covalent=is_covalent,
                )
                residues.append(residue)

            if residues:
                chains.append(
                    ParsedChain(
                        name=subchain_id,
                        entity=entity.name,
                        residues=residues,
                        type=const.chain_type_ids["NONPOLYMER"],
                        sequence=None,
                    )
                )

    # If no chains parsed, fail
    if not chains:
        msg = "No chains parsed!"
        raise ValueError(msg)

    # Parse covalent connections
    connections: list[ParsedConnection] = []
    for connection in structure.connections:
        # Skip non-covalent connections
        connection: gemmi.Connection
        if connection.type.name != "Covale":
            continue

        parsed_connection = parse_connection(
            connection=connection,
            chains=chains,
            subchain_map=subchain_map,
        )
        connections.append(parsed_connection)

    # Create flat tables from the parsed chains
    atom_data = []
    bond_data = []
    res_data = []
    chain_data = []
    connection_data = []

    # Convert parsed chains to tables, assigning global indices
    atom_idx = 0
    res_idx = 0
    asym_id = 0
    sym_count = {}
    chain_to_idx = {}
    res_to_idx = {}

    for asym_id, chain in enumerate(chains):
        # Compute number of atoms and residues
        res_num = len(chain.residues)
        atom_num = sum(len(res.atoms) for res in chain.residues)

        # Track how many copies of this entity exist in the assembly
        entity_id = entity_ids[chain.name]
        sym_id = sym_count.get(entity_id, 0)
        chain_data.append(
            (
                chain.name,
                chain.type,
                entity_id,
                sym_id,
                asym_id,
                atom_idx,
                atom_num,
                res_idx,
                res_num,
            )
        )
        chain_to_idx[chain.name] = asym_id
        sym_count[entity_id] = sym_id + 1

        # Add residue, atom, bond data
        for i, res in enumerate(chain.residues):
            # Center/disto atom ids are offsets within the residue
            atom_center = atom_idx + res.atom_center
            atom_disto = atom_idx + res.atom_disto
            res_data.append(
                (
                    res.name,
                    res.type,
                    res.idx,
                    atom_idx,
                    len(res.atoms),
                    atom_center,
                    atom_disto,
                    res.is_standard,
                    res.is_present,
                )
            )
            res_to_idx[(chain.name, i)] = (res_idx, atom_idx)

            for bond in res.bonds:
                atom_1 = atom_idx + bond.atom_1
                atom_2 = atom_idx + bond.atom_2
                bond_data.append((atom_1, atom_2, bond.type))

            for atom in res.atoms:
                atom_data.append(
                    (
                        convert_atom_name(atom.name),
                        atom.element,
                        atom.charge,
                        atom.coords,
                        atom.conformer,
                        atom.is_present,
                        atom.chirality,
                    )
                )
                atom_idx += 1

            res_idx += 1

    # Convert connections to tables (global residue/atom indices)
    for conn in connections:
        chain_1_idx = chain_to_idx[conn.chain_1]
        chain_2_idx = chain_to_idx[conn.chain_2]
        res_1_idx, atom_1_offset = res_to_idx[(conn.chain_1, conn.residue_index_1)]
        res_2_idx, atom_2_offset = res_to_idx[(conn.chain_2, conn.residue_index_2)]
        atom_1_idx = atom_1_offset + conn.atom_index_1
        atom_2_idx = atom_2_offset + conn.atom_index_2
        connection_data.append(
            (
                chain_1_idx,
                chain_2_idx,
                res_1_idx,
                res_2_idx,
                atom_1_idx,
                atom_2_idx,
            )
        )

    # Convert into structured numpy arrays
    atoms = np.array(atom_data, dtype=Atom)
    bonds = np.array(bond_data, dtype=Bond)
    residues = np.array(res_data, dtype=Residue)
    chains = np.array(chain_data, dtype=Chain)
    connections = np.array(connection_data, dtype=Connection)
    mask = np.ones(len(chain_data), dtype=bool)

    # Compute interface chains (find chains with a heavy atom within 5A)
    interfaces = compute_interfaces(atoms, chains)

    # Return parsed structure
    info = StructureInfo(
        deposited=deposit_date,
        revised=revision_date,
        released=release_date,
        resolution=resolution,
        method=method,
        num_chains=len(chains),
        num_interfaces=len(interfaces),
    )

    data = Structure(
        atoms=atoms,
        bonds=bonds,
        residues=residues,
        chains=chains,
        connections=connections,
        interfaces=interfaces,
        mask=mask,
    )

    return ParsedStructure(data=data, info=info, covalents=[])
protify/FastPLMs/boltz/scripts/process/msa.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import multiprocessing
3
+ from dataclasses import asdict
4
+ from functools import partial
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ from p_tqdm import p_umap
10
+ from redis import Redis
11
+ from tqdm import tqdm
12
+
13
+ from boltz.data.parse.a3m import parse_a3m
14
+
15
+
16
class Resource:
    """A shared, Redis-backed resource for processing."""

    def __init__(self, host: str, port: int) -> None:
        """Connect to the Redis database at *host*:*port*."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Fetch *key* from the Redis database, or None when absent."""
        return self._redis.get(key)

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Fetch *key*, raising KeyError when absent."""
        value = self.get(key)
        if value is None:
            raise KeyError(key)
        return value
33
+
34
+
35
def process_msa(
    path: Path,
    outdir: str,
    max_seqs: int,
    resource: Resource,
) -> None:
    """Parse one a3m MSA file and save it as a compressed npz.

    Skips the file when its output already exists, so re-runs are cheap.
    """
    out_path = Path(outdir) / f"{path.stem}.npz"
    if out_path.exists():
        return
    msa = parse_a3m(path, resource, max_seqs)
    np.savez_compressed(out_path, **asdict(msa))
47
+
48
+
49
def process(args) -> None:
    """Run the data processing task."""
    # Ensure the output directory exists
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Connect to the shared Redis resource
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Collect the MSA files to process
    print("Fetching data...")
    data = list(args.msadir.rglob("*.a3m*"))
    print(f"Found {len(data)} MSA's.")

    # Decide how many workers to use; never more than CPUs or files
    cpu_limit = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, cpu_limit, len(data)))

    if num_processes > 1:
        # Process the files in parallel
        worker = partial(
            process_msa,
            outdir=args.outdir,
            max_seqs=args.max_seqs,
            resource=resource,
        )
        p_umap(worker, data, num_cpus=num_processes)
    else:
        # Fall back to a serial loop
        for path in tqdm(data):
            process_msa(
                path,
                outdir=args.outdir,
                max_seqs=args.max_seqs,
                resource=resource,
            )
89
+
90
+
91
if __name__ == "__main__":
    # Command-line entry point for the MSA processing pipeline.
    parser = argparse.ArgumentParser(description="Process MSA data.")
    parser.add_argument(
        "--msadir",
        type=Path,
        required=True,
        help="The MSA data directory.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--max-seqs",
        type=int,
        default=16384,
        help="The maximum number of sequences.",
    )
    args = parser.parse_args()
    process(args)
protify/FastPLMs/boltz/scripts/process/rcsb.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import multiprocessing
4
+ import pickle
5
+ import traceback
6
+ from dataclasses import asdict, dataclass, replace
7
+ from functools import partial
8
+ from pathlib import Path
9
+ from typing import Any, Optional
10
+
11
+ import numpy as np
12
+ import rdkit
13
+ from mmcif import parse_mmcif
14
+ from p_tqdm import p_umap
15
+ from redis import Redis
16
+ from tqdm import tqdm
17
+
18
+ from boltz.data.filter.static.filter import StaticFilter
19
+ from boltz.data.filter.static.ligand import ExcludedLigands
20
+ from boltz.data.filter.static.polymer import (
21
+ ClashingChainsFilter,
22
+ ConsecutiveCA,
23
+ MinimumLengthFilter,
24
+ UnknownFilter,
25
+ )
26
+ from boltz.data.types import ChainInfo, InterfaceInfo, Record, Target
27
+
28
+
29
@dataclass(frozen=True, slots=True)
class PDB:
    """A raw MMCIF PDB file.

    Attributes
    ----------
    id : str
        The lowercase PDB identifier (derived from the file stem).
    path : str
        The path to the MMCIF file on disk.

    """

    id: str
    path: str
35
+
36
+
37
class Resource:
    """A shared resource for processing, backed by Redis."""

    def __init__(self, host: str, port: int) -> None:
        """Connect to the Redis database at *host*:*port*."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Fetch and unpickle *key* from Redis, or None when absent."""
        raw = self._redis.get(key)
        if raw is None:
            return None
        # Data in Redis is written by our own pipeline, hence trusted.
        return pickle.loads(raw)  # noqa: S301

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Fetch *key*, raising KeyError when absent."""
        value = self.get(key)
        if value is None:
            raise KeyError(key)
        return value
57
+
58
+
59
def fetch(datadir: Path, max_file_size: Optional[int] = None) -> list[PDB]:
    """Collect the MMCIF files under *datadir* as PDB entries.

    Files larger than *max_file_size* bytes (when given) are skipped.
    """
    entries = []
    excluded = 0
    for file in datadir.rglob("*.cif*"):
        # The clustering file is annotated by pdb_entity id
        pdb_id = str(file.stem).lower()

        # Skip files exceeding the size limit, if one was provided
        if max_file_size is not None and file.stat().st_size > max_file_size:
            excluded += 1
            continue

        entries.append(PDB(id=pdb_id, path=str(file)))

    print(f"Excluded {excluded} files due to size.")  # noqa: T201
    return entries
78
+
79
+
80
def finalize(outdir: Path) -> None:
    """Run post-processing in main thread.

    Gathers all per-entry record JSON files under ``outdir / "records"``
    into a single ``manifest.json``.

    Parameters
    ----------
    outdir : Path
        The output directory.

    """
    records_dir = outdir / "records"

    records = []
    failed_count = 0
    for record in records_dir.iterdir():
        try:
            with record.open("r") as f:
                records.append(json.load(f))
        except:  # noqa: E722
            # Best-effort: count and report, but keep going
            failed_count += 1
            print(f"Failed to parse {record}")  # noqa: T201

    if failed_count > 0:
        print(f"Failed to parse {failed_count} entries.")  # noqa: T201
    else:
        print("All entries parsed successfully.")

    # Save manifest
    with (outdir / "manifest.json").open("w") as f:
        json.dump(records, f)
111
+
112
+
113
def parse(data: PDB, resource: Resource, clusters: dict) -> Target:
    """Process a structure.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource.
    clusters : dict
        Mapping from "<pdb_id>_<entity_id>" keys to cluster ids;
        chains without an entry get cluster_id -1.

    Returns
    -------
    Target
        The processed data.

    """
    # Get the PDB id
    pdb_id = data.id.lower()

    # Parse structure
    parsed = parse_mmcif(data.path, resource)
    structure = parsed.data
    structure_info = parsed.info

    # Create chain metadata
    chain_info = []
    for i, chain in enumerate(structure.chains):
        key = f"{pdb_id}_{chain['entity_id']}"
        chain_info.append(
            ChainInfo(
                chain_id=i,
                chain_name=chain["name"],
                msa_id="",  # FIX
                mol_type=int(chain["mol_type"]),
                cluster_id=clusters.get(key, -1),
                num_residues=int(chain["res_num"]),
            )
        )

    # Get interface metadata
    interface_info = []
    for interface in structure.interfaces:
        chain_1 = int(interface["chain_1"])
        chain_2 = int(interface["chain_2"])
        interface_info.append(
            InterfaceInfo(
                chain_1=chain_1,
                chain_2=chain_2,
            )
        )

    # Create record
    record = Record(
        id=data.id,
        structure=structure_info,
        chains=chain_info,
        interfaces=interface_info,
    )

    return Target(structure=structure, record=record)
173
+
174
+
175
def process_structure(
    data: PDB,
    resource: Resource,
    outdir: Path,
    filters: list[StaticFilter],
    clusters: dict,
) -> None:
    """Process a target.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource : Resource
        The shared resource.
    outdir : Path
        The output directory.
    filters : list[StaticFilter]
        Static filters used to mark failing chains invalid.
    clusters : dict
        The chain clustering map passed through to ``parse``.

    """
    # Check if we need to process: skip entries already fully processed
    struct_path = outdir / "structures" / f"{data.id}.npz"
    record_path = outdir / "records" / f"{data.id}.json"

    if struct_path.exists() and record_path.exists():
        return

    try:
        # Parse the target
        target: Target = parse(data, resource, clusters)
        structure = target.structure

        # Apply the filters, AND-ing their per-chain masks together
        mask = structure.mask
        if filters is not None:
            for f in filters:
                filter_mask = f.filter(structure)
                mask = mask & filter_mask
    except Exception:  # noqa: BLE001
        # Best-effort pipeline: log the failure and skip this entry
        traceback.print_exc()
        print(f"Failed to parse {data.id}")
        return

    # Replace chains and interfaces with validity flags from the mask
    chains = []
    for i, chain in enumerate(target.record.chains):
        chains.append(replace(chain, valid=bool(mask[i])))

    # An interface is valid only when both of its chains are valid
    interfaces = []
    for interface in target.record.interfaces:
        chain_1 = bool(mask[interface.chain_1])
        chain_2 = bool(mask[interface.chain_2])
        interfaces.append(replace(interface, valid=(chain_1 and chain_2)))

    # Replace structure and record
    structure = replace(structure, mask=mask)
    record = replace(target.record, chains=chains, interfaces=interfaces)
    target = replace(target, structure=structure, record=record)

    # Dump structure
    np.savez_compressed(struct_path, **asdict(structure))

    # Dump record
    with record_path.open("w") as f:
        json.dump(asdict(record), f)
239
+
240
+
241
def process(args) -> None:
    """Run the data processing task.

    Parameters
    ----------
    args : argparse.Namespace
        The parsed command line arguments.

    """
    # Create output directory
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Create output directories
    records_dir = args.outdir / "records"
    records_dir.mkdir(parents=True, exist_ok=True)

    structure_dir = args.outdir / "structures"
    structure_dir.mkdir(parents=True, exist_ok=True)

    # Load clusters, normalizing keys and values to lowercase
    with Path(args.clusters).open("r") as f:
        clusters: dict[str, str] = json.load(f)
    clusters = {k.lower(): v.lower() for k, v in clusters.items()}

    # Load filters
    filters = [
        ExcludedLigands(),
        MinimumLengthFilter(min_len=4, max_len=5000),
        UnknownFilter(),
        ConsecutiveCA(max_dist=10.0),
        ClashingChainsFilter(freq=0.3, dist=1.7),
    ]

    # Set default pickle properties so molecule properties survive pickling
    pickle_option = rdkit.Chem.PropertyPickleOptions.AllProps
    rdkit.Chem.SetDefaultPickleProperties(pickle_option)

    # Load shared data from redis
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Get data points. BUGFIX: forward --max-file-size, which was
    # previously parsed but silently ignored (fetch was called without it).
    print("Fetching data...")
    data = fetch(args.datadir, max_file_size=args.max_file_size)

    # NOTE(review): args.use_assembly is parsed but never used here;
    # confirm whether it should be forwarded to the parsing step.

    # Check if we can run in parallel
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))
    parallel = num_processes > 1

    # Run processing
    print("Processing data...")
    if parallel:
        # Create processing function
        fn = partial(
            process_structure,
            resource=resource,
            outdir=args.outdir,
            clusters=clusters,
            filters=filters,
        )
        # Run processing in parallel
        p_umap(fn, data, num_cpus=num_processes)
    else:
        for item in tqdm(data):
            process_structure(
                item,
                resource=resource,
                outdir=args.outdir,
                clusters=clusters,
                filters=filters,
            )

    # Aggregate per-entry records into a manifest
    finalize(args.outdir)
308
+
309
+
310
if __name__ == "__main__":
    # Command-line entry point for processing RCSB MMCIF structures.
    # The description previously said "Process MSA data." (copy-paste
    # from the MSA script) — corrected here.
    parser = argparse.ArgumentParser(description="Process RCSB data.")
    parser.add_argument(
        "--datadir",
        type=Path,
        required=True,
        help="The directory containing the MMCIF files.",
    )
    parser.add_argument(
        "--clusters",
        type=Path,
        required=True,
        help="Path to the cluster file.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--use-assembly",
        action="store_true",
        help="Whether to use assembly 1.",
    )
    parser.add_argument(
        "--max-file-size",
        type=int,
        default=None,
        help="Skip MMCIF files larger than this many bytes.",
    )
    args = parser.parse_args()
    process(args)
protify/FastPLMs/boltz/scripts/train/train.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import string
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import hydra
10
+ import omegaconf
11
+ import pytorch_lightning as pl
12
+ import torch
13
+ import torch.multiprocessing
14
+ from omegaconf import OmegaConf, listconfig
15
+ from pytorch_lightning import LightningModule
16
+ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
17
+ from pytorch_lightning.loggers import WandbLogger
18
+ from pytorch_lightning.strategies import DDPStrategy
19
+ from pytorch_lightning.utilities import rank_zero_only
20
+
21
+ from boltz.data.module.training import BoltzTrainingDataModule, DataConfig
22
+
23
+
24
@dataclass
class TrainConfig:
    """Train configuration.

    Attributes
    ----------
    data : DataConfig
        The data configuration.
    model : LightningModule
        The model module to train (instantiated by hydra).
    output : str
        The output directory.
    trainer : Optional[dict]
        The trainer configuration.
    resume : Optional[str]
        The resume checkpoint.
    pretrained : Optional[str]
        The pretrained model.
    wandb : Optional[dict]
        The wandb configuration.
    disable_checkpoint : bool
        Disable checkpoint.
    matmul_precision : Optional[str]
        The matmul precision.
    find_unused_parameters : Optional[bool]
        Find unused parameters.
    save_top_k : Optional[int]
        Save top k checkpoints.
    validation_only : bool
        Run validation only.
    debug : bool
        Debug mode.
    strict_loading : bool
        Fail on mismatched checkpoint weights.
    load_confidence_from_trunk : Optional[bool]
        Load pre-trained confidence weights from trunk.

    """

    data: DataConfig
    model: LightningModule
    output: str
    trainer: Optional[dict] = None
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    wandb: Optional[dict] = None
    disable_checkpoint: bool = False
    matmul_precision: Optional[str] = None
    find_unused_parameters: Optional[bool] = False
    save_top_k: Optional[int] = 1
    validation_only: bool = False
    debug: bool = False
    strict_loading: bool = True
    load_confidence_from_trunk: Optional[bool] = False
78
+
79
+
80
def train(raw_config: str, args: list[str]) -> None:  # noqa: C901, PLR0912, PLR0915
    """Run training.

    Parameters
    ----------
    raw_config : str
        The input yaml configuration.
    args : list[str]
        Any command line overrides (OmegaConf dotlist syntax).

    """
    # Load the configuration
    raw_config = omegaconf.OmegaConf.load(raw_config)

    # Apply input arguments on top of the file configuration
    args = omegaconf.OmegaConf.from_dotlist(args)
    raw_config = omegaconf.OmegaConf.merge(raw_config, args)

    # Instantiate the task
    cfg = hydra.utils.instantiate(raw_config)
    cfg = TrainConfig(**cfg)

    # Set matmul precision
    if cfg.matmul_precision is not None:
        torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Create trainer dict
    trainer = cfg.trainer
    if trainer is None:
        trainer = {}

    # Flip some arguments in debug mode: single device, no workers, no wandb
    devices = trainer.get("devices", 1)

    wandb = cfg.wandb
    if cfg.debug:
        if isinstance(devices, int):
            devices = 1
        elif isinstance(devices, (list, listconfig.ListConfig)):
            devices = [devices[0]]
        trainer["devices"] = devices
        cfg.data.num_workers = 0
        if wandb:
            wandb = None

    # Create objects
    data_config = DataConfig(**cfg.data)
    data_module = BoltzTrainingDataModule(data_config)
    model_module = cfg.model

    if cfg.pretrained and not cfg.resume:
        # Load the pretrained weights into the confidence module
        if cfg.load_confidence_from_trunk:
            checkpoint = torch.load(cfg.pretrained, map_location="cpu")

            # Copy trunk weights under a "confidence_module." prefix, then
            # merge the original keys back so both trunk and confidence
            # module receive the pretrained weights.
            new_state_dict = {}
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("structure_module") and not key.startswith(
                    "distogram_module"
                ):
                    new_key = "confidence_module." + key
                    new_state_dict[new_key] = value
            new_state_dict.update(checkpoint["state_dict"])

            # Update the checkpoint with the new state_dict
            checkpoint["state_dict"] = new_state_dict

            # Save the modified checkpoint under a temporary random name
            random_string = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=10)
            )
            file_path = os.path.dirname(cfg.pretrained) + "/" + random_string + ".ckpt"
            print(
                f"Saving modified checkpoint to {file_path} created by broadcasting trunk of {cfg.pretrained} to confidence module."
            )
            torch.save(checkpoint, file_path)
        else:
            file_path = cfg.pretrained

        print(f"Loading model from {file_path}")
        model_module = type(model_module).load_from_checkpoint(
            file_path, map_location="cpu", strict=False, **(model_module.hparams)
        )

        # Remove the temporary modified checkpoint once loaded
        if cfg.load_confidence_from_trunk:
            os.remove(file_path)

    # Create checkpoint callback
    callbacks = []
    dirpath = cfg.output
    if not cfg.disable_checkpoint:
        mc = ModelCheckpoint(
            monitor="val/lddt",
            save_top_k=cfg.save_top_k,
            save_last=True,
            mode="max",
            every_n_epochs=1,
        )
        callbacks = [mc]

    # Create wandb logger
    loggers = []
    if wandb:
        wdb_logger = WandbLogger(
            name=wandb["name"],
            group=wandb["name"],
            save_dir=cfg.output,
            project=wandb["project"],
            entity=wandb["entity"],
            log_model=False,
        )
        loggers.append(wdb_logger)
        # Save the config to wandb (only on rank zero)

        @rank_zero_only
        def save_config_to_wandb() -> None:
            config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
            with Path.open(config_out, "w") as f:
                OmegaConf.save(raw_config, f)
            wdb_logger.experiment.save(str(config_out))

        save_config_to_wandb()

    # Set up trainer: use DDP only when more than one device is requested
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
    ):
        strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)

    trainer = pl.Trainer(
        default_root_dir=str(dirpath),
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=not cfg.disable_checkpoint,
        reload_dataloaders_every_n_epochs=1,
        **trainer,
    )

    if not cfg.strict_loading:
        model_module.strict_loading = False

    if cfg.validation_only:
        trainer.validate(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
    else:
        trainer.fit(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
236
+
237
+
238
if __name__ == "__main__":
    # First CLI argument: path to the yaml config.
    # Remaining arguments: dotlist overrides applied on top of it.
    train(sys.argv[1], sys.argv[2:])
protify/FastPLMs/boltz/src/boltz/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from importlib.metadata import PackageNotFoundError, version
2
+
3
+ try: # noqa: SIM105
4
+ __version__ = version("boltz")
5
+ except PackageNotFoundError:
6
+ # package is not installed
7
+ pass
protify/FastPLMs/boltz/src/boltz/data/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/const.py ADDED
@@ -0,0 +1,1184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ####################################################################################################
2
+ # CHAINS
3
+ ####################################################################################################
4
+
5
+ chain_types = [
6
+ "PROTEIN",
7
+ "DNA",
8
+ "RNA",
9
+ "NONPOLYMER",
10
+ ]
11
+ chain_type_ids = {chain: i for i, chain in enumerate(chain_types)}
12
+
13
+ out_types = [
14
+ "dna_protein",
15
+ "rna_protein",
16
+ "ligand_protein",
17
+ "dna_ligand",
18
+ "rna_ligand",
19
+ "intra_ligand",
20
+ "intra_dna",
21
+ "intra_rna",
22
+ "intra_protein",
23
+ "protein_protein",
24
+ "modified",
25
+ ]
26
+
27
+ out_types_weights_af3 = {
28
+ "dna_protein": 10.0,
29
+ "rna_protein": 10.0,
30
+ "ligand_protein": 10.0,
31
+ "dna_ligand": 5.0,
32
+ "rna_ligand": 5.0,
33
+ "intra_ligand": 20.0,
34
+ "intra_dna": 4.0,
35
+ "intra_rna": 16.0,
36
+ "intra_protein": 20.0,
37
+ "protein_protein": 20.0,
38
+ "modified": 0.0,
39
+ }
40
+
41
+ out_types_weights = {
42
+ "dna_protein": 5.0,
43
+ "rna_protein": 5.0,
44
+ "ligand_protein": 20.0,
45
+ "dna_ligand": 2.0,
46
+ "rna_ligand": 2.0,
47
+ "intra_ligand": 20.0,
48
+ "intra_dna": 2.0,
49
+ "intra_rna": 8.0,
50
+ "intra_protein": 20.0,
51
+ "protein_protein": 20.0,
52
+ "modified": 0.0,
53
+ }
54
+
55
+
56
+ out_single_types = ["protein", "ligand", "dna", "rna"]
57
+
58
+ clash_types = [
59
+ "dna_protein",
60
+ "rna_protein",
61
+ "ligand_protein",
62
+ "protein_protein",
63
+ "dna_ligand",
64
+ "rna_ligand",
65
+ "ligand_ligand",
66
+ "rna_dna",
67
+ "dna_dna",
68
+ "rna_rna",
69
+ ]
70
+
71
+ chain_types_to_clash_type = {
72
+ frozenset(("PROTEIN", "DNA")): "dna_protein",
73
+ frozenset(("PROTEIN", "RNA")): "rna_protein",
74
+ frozenset(("PROTEIN", "NONPOLYMER")): "ligand_protein",
75
+ frozenset(("PROTEIN",)): "protein_protein",
76
+ frozenset(("NONPOLYMER", "DNA")): "dna_ligand",
77
+ frozenset(("NONPOLYMER", "RNA")): "rna_ligand",
78
+ frozenset(("NONPOLYMER",)): "ligand_ligand",
79
+ frozenset(("DNA", "RNA")): "rna_dna",
80
+ frozenset(("DNA",)): "dna_dna",
81
+ frozenset(("RNA",)): "rna_rna",
82
+ }
83
+
84
+ chain_type_to_out_single_type = {
85
+ "PROTEIN": "protein",
86
+ "DNA": "dna",
87
+ "RNA": "rna",
88
+ "NONPOLYMER": "ligand",
89
+ }
90
+ ####################################################################################################
91
+ # RESIDUES & TOKENS
92
+ ####################################################################################################
93
+
94
+
95
+ canonical_tokens = [
96
+ "ALA",
97
+ "ARG",
98
+ "ASN",
99
+ "ASP",
100
+ "CYS",
101
+ "GLN",
102
+ "GLU",
103
+ "GLY",
104
+ "HIS",
105
+ "ILE",
106
+ "LEU",
107
+ "LYS",
108
+ "MET",
109
+ "PHE",
110
+ "PRO",
111
+ "SER",
112
+ "THR",
113
+ "TRP",
114
+ "TYR",
115
+ "VAL",
116
+ "UNK", # unknown protein token
117
+ ]
118
+
119
+ tokens = [
120
+ "<pad>",
121
+ "-",
122
+ *canonical_tokens,
123
+ "A",
124
+ "G",
125
+ "C",
126
+ "U",
127
+ "N", # unknown rna token
128
+ "DA",
129
+ "DG",
130
+ "DC",
131
+ "DT",
132
+ "DN", # unknown dna token
133
+ ]
134
+
135
+ token_ids = {token: i for i, token in enumerate(tokens)}
136
+ num_tokens = len(tokens)
137
+ unk_token = {"PROTEIN": "UNK", "DNA": "DN", "RNA": "N"}
138
+ unk_token_ids = {m: token_ids[t] for m, t in unk_token.items()}
139
+
140
+ prot_letter_to_token = {
141
+ "A": "ALA",
142
+ "R": "ARG",
143
+ "N": "ASN",
144
+ "D": "ASP",
145
+ "C": "CYS",
146
+ "E": "GLU",
147
+ "Q": "GLN",
148
+ "G": "GLY",
149
+ "H": "HIS",
150
+ "I": "ILE",
151
+ "L": "LEU",
152
+ "K": "LYS",
153
+ "M": "MET",
154
+ "F": "PHE",
155
+ "P": "PRO",
156
+ "S": "SER",
157
+ "T": "THR",
158
+ "W": "TRP",
159
+ "Y": "TYR",
160
+ "V": "VAL",
161
+ "X": "UNK",
162
+ "J": "UNK",
163
+ "B": "UNK",
164
+ "Z": "UNK",
165
+ "O": "UNK",
166
+ "U": "UNK",
167
+ "-": "-",
168
+ }
169
+
170
+ prot_token_to_letter = {v: k for k, v in prot_letter_to_token.items()}
171
+ prot_token_to_letter["UNK"] = "X"
172
+
173
+ rna_letter_to_token = {
174
+ "A": "A",
175
+ "G": "G",
176
+ "C": "C",
177
+ "U": "U",
178
+ "N": "N",
179
+ }
180
+ rna_token_to_letter = {v: k for k, v in rna_letter_to_token.items()}
181
+
182
+ dna_letter_to_token = {
183
+ "A": "DA",
184
+ "G": "DG",
185
+ "C": "DC",
186
+ "T": "DT",
187
+ "N": "DN",
188
+ }
189
+ dna_token_to_letter = {v: k for k, v in dna_letter_to_token.items()}
190
+
191
+ ####################################################################################################
192
+ # ATOMS
193
+ ####################################################################################################
194
+
195
+ num_elements = 128
196
+
197
+ chirality_types = [
198
+ "CHI_UNSPECIFIED",
199
+ "CHI_TETRAHEDRAL_CW",
200
+ "CHI_TETRAHEDRAL_CCW",
201
+ "CHI_SQUAREPLANAR",
202
+ "CHI_OCTAHEDRAL",
203
+ "CHI_TRIGONALBIPYRAMIDAL",
204
+ "CHI_OTHER",
205
+ ]
206
+ chirality_type_ids = {chirality: i for i, chirality in enumerate(chirality_types)}
207
+ unk_chirality_type = "CHI_OTHER"
208
+
209
+ hybridization_map = [
210
+ "S",
211
+ "SP",
212
+ "SP2",
213
+ "SP2D",
214
+ "SP3",
215
+ "SP3D",
216
+ "SP3D2",
217
+ "OTHER",
218
+ "UNSPECIFIED",
219
+ ]
220
+ hybridization_type_ids = {hybrid: i for i, hybrid in enumerate(hybridization_map)}
221
+ unk_hybridization_type = "UNSPECIFIED"
222
+
223
+ # fmt: off
224
+ ref_atoms = {
225
+ "PAD": [],
226
+ "UNK": ["N", "CA", "C", "O", "CB"],
227
+ "-": [],
228
+ "ALA": ["N", "CA", "C", "O", "CB"],
229
+ "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
230
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
231
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
232
+ "CYS": ["N", "CA", "C", "O", "CB", "SG"],
233
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
234
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
235
+ "GLY": ["N", "CA", "C", "O"],
236
+ "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
237
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
238
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
239
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
240
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
241
+ "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
242
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
243
+ "SER": ["N", "CA", "C", "O", "CB", "OG"],
244
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
245
+ "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], # noqa: E501
246
+ "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
247
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
248
+ "A": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
249
+ "G": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
250
+ "C": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
251
+ "U": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"], # noqa: E501
252
+ "N": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"], # noqa: E501
253
+ "DA": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
254
+ "DG": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
255
+ "DC": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
256
+ "DT": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"], # noqa: E501
257
+ "DN": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'"]
258
+ }
259
+
260
+ protein_backbone_atom_names = ["N", "CA", "C", "O"]
261
+ nucleic_backbone_atom_names = ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"]
262
+
263
+ protein_backbone_atom_index = {name: i for i, name in enumerate(protein_backbone_atom_names)}
264
+ nucleic_backbone_atom_index = {name: i for i, name in enumerate(nucleic_backbone_atom_names)}
265
+
266
+ ref_symmetries = {
267
+ "PAD": [],
268
+ "ALA": [],
269
+ "ARG": [],
270
+ "ASN": [],
271
+ "ASP": [[(6, 7), (7, 6)]],
272
+ "CYS": [],
273
+ "GLN": [],
274
+ "GLU": [[(7, 8), (8, 7)]],
275
+ "GLY": [],
276
+ "HIS": [],
277
+ "ILE": [],
278
+ "LEU": [],
279
+ "LYS": [],
280
+ "MET": [],
281
+ "PHE": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
282
+ "PRO": [],
283
+ "SER": [],
284
+ "THR": [],
285
+ "TRP": [],
286
+ "TYR": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
287
+ "VAL": [],
288
+ "A": [[(1, 2), (2, 1)]],
289
+ "G": [[(1, 2), (2, 1)]],
290
+ "C": [[(1, 2), (2, 1)]],
291
+ "U": [[(1, 2), (2, 1)]],
292
+ #"N": [[(1, 2), (2, 1)]],
293
+ "DA": [[(1, 2), (2, 1)]],
294
+ "DG": [[(1, 2), (2, 1)]],
295
+ "DC": [[(1, 2), (2, 1)]],
296
+ "DT": [[(1, 2), (2, 1)]],
297
+ #"DN": [[(1, 2), (2, 1)]]
298
+ }
299
+
300
+
301
+ res_to_center_atom = {
302
+ "UNK": "CA",
303
+ "ALA": "CA",
304
+ "ARG": "CA",
305
+ "ASN": "CA",
306
+ "ASP": "CA",
307
+ "CYS": "CA",
308
+ "GLN": "CA",
309
+ "GLU": "CA",
310
+ "GLY": "CA",
311
+ "HIS": "CA",
312
+ "ILE": "CA",
313
+ "LEU": "CA",
314
+ "LYS": "CA",
315
+ "MET": "CA",
316
+ "PHE": "CA",
317
+ "PRO": "CA",
318
+ "SER": "CA",
319
+ "THR": "CA",
320
+ "TRP": "CA",
321
+ "TYR": "CA",
322
+ "VAL": "CA",
323
+ "A": "C1'",
324
+ "G": "C1'",
325
+ "C": "C1'",
326
+ "U": "C1'",
327
+ "N": "C1'",
328
+ "DA": "C1'",
329
+ "DG": "C1'",
330
+ "DC": "C1'",
331
+ "DT": "C1'",
332
+ "DN": "C1'"
333
+ }
334
+
335
+ res_to_disto_atom = {
336
+ "UNK": "CB",
337
+ "ALA": "CB",
338
+ "ARG": "CB",
339
+ "ASN": "CB",
340
+ "ASP": "CB",
341
+ "CYS": "CB",
342
+ "GLN": "CB",
343
+ "GLU": "CB",
344
+ "GLY": "CA",
345
+ "HIS": "CB",
346
+ "ILE": "CB",
347
+ "LEU": "CB",
348
+ "LYS": "CB",
349
+ "MET": "CB",
350
+ "PHE": "CB",
351
+ "PRO": "CB",
352
+ "SER": "CB",
353
+ "THR": "CB",
354
+ "TRP": "CB",
355
+ "TYR": "CB",
356
+ "VAL": "CB",
357
+ "A": "C4",
358
+ "G": "C4",
359
+ "C": "C2",
360
+ "U": "C2",
361
+ "N": "C1'",
362
+ "DA": "C4",
363
+ "DG": "C4",
364
+ "DC": "C2",
365
+ "DT": "C2",
366
+ "DN": "C1'"
367
+ }
368
+
369
+ res_to_center_atom_id = {
370
+ res: ref_atoms[res].index(atom)
371
+ for res, atom in res_to_center_atom.items()
372
+ }
373
+
374
+ res_to_disto_atom_id = {
375
+ res: ref_atoms[res].index(atom)
376
+ for res, atom in res_to_disto_atom.items()
377
+ }
378
+
379
+ # fmt: on
380
+
381
+ ####################################################################################################
382
+ # BONDS
383
+ ####################################################################################################
384
+
385
+ atom_interface_cutoff = 5.0
386
+ interface_cutoff = 15.0
387
+
388
+ bond_types = [
389
+ "OTHER",
390
+ "SINGLE",
391
+ "DOUBLE",
392
+ "TRIPLE",
393
+ "AROMATIC",
394
+ "COVALENT",
395
+ ]
396
+ bond_type_ids = {bond: i for i, bond in enumerate(bond_types)}
397
+ unk_bond_type = "OTHER"
398
+
399
+
400
+ ####################################################################################################
401
+ # Contacts
402
+ ####################################################################################################
403
+
404
+
405
+ pocket_contact_info = {
406
+ "UNSPECIFIED": 0,
407
+ "UNSELECTED": 1,
408
+ "POCKET": 2,
409
+ "BINDER": 3,
410
+ }
411
+
412
+ contact_conditioning_info = {
413
+ "UNSPECIFIED": 0,
414
+ "UNSELECTED": 1,
415
+ "POCKET>BINDER": 2,
416
+ "BINDER>POCKET": 3,
417
+ "CONTACT": 4,
418
+ }
419
+
420
+
421
+ ####################################################################################################
422
+ # MSA
423
+ ####################################################################################################
424
+
425
+ max_msa_seqs = 16384
426
+ max_paired_seqs = 8192
427
+
428
+
429
+ ####################################################################################################
430
+ # CHUNKING
431
+ ####################################################################################################
432
+
433
+ chunk_size_threshold = 384
434
+
435
+ ####################################################################################################
436
+ # Method conditioning
437
+ ####################################################################################################
438
+
439
+ # Methods
440
+ method_types_ids = {
441
+ "MD": 0,
442
+ "X-RAY DIFFRACTION": 1,
443
+ "ELECTRON MICROSCOPY": 2,
444
+ "SOLUTION NMR": 3,
445
+ "SOLID-STATE NMR": 4,
446
+ "NEUTRON DIFFRACTION": 4,
447
+ "ELECTRON CRYSTALLOGRAPHY": 4,
448
+ "FIBER DIFFRACTION": 4,
449
+ "POWDER DIFFRACTION": 4,
450
+ "INFRARED SPECTROSCOPY": 4,
451
+ "FLUORESCENCE TRANSFER": 4,
452
+ "EPR": 4,
453
+ "THEORETICAL MODEL": 4,
454
+ "SOLUTION SCATTERING": 4,
455
+ "OTHER": 4,
456
+ "AFDB": 5,
457
+ "BOLTZ-1": 6,
458
+ "FUTURE1": 7, # Placeholder for future supervision sources
459
+ "FUTURE2": 8,
460
+ "FUTURE3": 9,
461
+ "FUTURE4": 10,
462
+ "FUTURE5": 11,
463
+ }
464
+ method_types_ids = {k.lower(): v for k, v in method_types_ids.items()}
465
+ num_method_types = len(set(method_types_ids.values()))
466
+
467
+ # Temperature
468
+ temperature_bins = [(265, 280), (280, 295), (295, 310)]
469
+ temperature_bins_ids = {temp: i for i, temp in enumerate(temperature_bins)}
470
+ temperature_bins_ids["other"] = len(temperature_bins)
471
+ num_temp_bins = len(temperature_bins_ids)
472
+
473
+
474
+ # pH
475
+ ph_bins = [(0, 6), (6, 8), (8, 14)]
476
+ ph_bins_ids = {ph: i for i, ph in enumerate(ph_bins)}
477
+ ph_bins_ids["other"] = len(ph_bins)
478
+ num_ph_bins = len(ph_bins_ids)
479
+
480
+ ####################################################################################################
481
+ # VDW_RADII
482
+ ####################################################################################################
483
+
484
+ # fmt: off
485
+ vdw_radii = [
486
+ 1.2, 1.4, 2.2, 1.9, 1.8, 1.7, 1.6, 1.55, 1.5, 1.54,
487
+ 2.4, 2.2, 2.1, 2.1, 1.95, 1.8, 1.8, 1.88, 2.8, 2.4,
488
+ 2.3, 2.15, 2.05, 2.05, 2.05, 2.05, 2.0, 2.0, 2.0, 2.1,
489
+ 2.1, 2.1, 2.05, 1.9, 1.9, 2.02, 2.9, 2.55, 2.4, 2.3,
490
+ 2.15, 2.1, 2.05, 2.05, 2.0, 2.05, 2.1, 2.2, 2.2, 2.25,
491
+ 2.2, 2.1, 2.1, 2.16, 3.0, 2.7, 2.5, 2.48, 2.47, 2.45,
492
+ 2.43, 2.42, 2.4, 2.38, 2.37, 2.35, 2.33, 2.32, 2.3, 2.28,
493
+ 2.27, 2.25, 2.2, 2.1, 2.05, 2.0, 2.0, 2.05, 2.1, 2.05,
494
+ 2.2, 2.3, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.4,
495
+ 2.0, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
496
+ 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
497
+ 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0
498
+ ]
499
+ # fmt: on
500
+
501
+ ####################################################################################################
502
+ # Excluded ligands
503
+ ####################################################################################################
504
+
505
+ ligand_exclusion = {
506
+ "144",
507
+ "15P",
508
+ "1PE",
509
+ "2F2",
510
+ "2JC",
511
+ "3HR",
512
+ "3SY",
513
+ "7N5",
514
+ "7PE",
515
+ "9JE",
516
+ "AAE",
517
+ "ABA",
518
+ "ACE",
519
+ "ACN",
520
+ "ACT",
521
+ "ACY",
522
+ "AZI",
523
+ "BAM",
524
+ "BCN",
525
+ "BCT",
526
+ "BDN",
527
+ "BEN",
528
+ "BME",
529
+ "BO3",
530
+ "BTB",
531
+ "BTC",
532
+ "BU1",
533
+ "C8E",
534
+ "CAD",
535
+ "CAQ",
536
+ "CBM",
537
+ "CCN",
538
+ "CIT",
539
+ "CL",
540
+ "CLR",
541
+ "CM",
542
+ "CMO",
543
+ "CO3",
544
+ "CPT",
545
+ "CXS",
546
+ "D10",
547
+ "DEP",
548
+ "DIO",
549
+ "DMS",
550
+ "DN",
551
+ "DOD",
552
+ "DOX",
553
+ "EDO",
554
+ "EEE",
555
+ "EGL",
556
+ "EOH",
557
+ "EOX",
558
+ "EPE",
559
+ "ETF",
560
+ "FCY",
561
+ "FJO",
562
+ "FLC",
563
+ "FMT",
564
+ "FW5",
565
+ "GOL",
566
+ "GSH",
567
+ "GTT",
568
+ "GYF",
569
+ "HED",
570
+ "IHP",
571
+ "IHS",
572
+ "IMD",
573
+ "IOD",
574
+ "IPA",
575
+ "IPH",
576
+ "LDA",
577
+ "MB3",
578
+ "MEG",
579
+ "MES",
580
+ "MLA",
581
+ "MLI",
582
+ "MOH",
583
+ "MPD",
584
+ "MRD",
585
+ "MSE",
586
+ "MYR",
587
+ "N",
588
+ "NA",
589
+ "NH2",
590
+ "NH4",
591
+ "NHE",
592
+ "NO3",
593
+ "O4B",
594
+ "OHE",
595
+ "OLA",
596
+ "OLC",
597
+ "OMB",
598
+ "OME",
599
+ "OXA",
600
+ "P6G",
601
+ "PE3",
602
+ "PE4",
603
+ "PEG",
604
+ "PEO",
605
+ "PEP",
606
+ "PG0",
607
+ "PG4",
608
+ "PGE",
609
+ "PGR",
610
+ "PLM",
611
+ "PO4",
612
+ "POL",
613
+ "POP",
614
+ "PVO",
615
+ "SAR",
616
+ "SCN",
617
+ "SEO",
618
+ "SEP",
619
+ "SIN",
620
+ "SO4",
621
+ "SPD",
622
+ "SPM",
623
+ "SR",
624
+ "STE",
625
+ "STO",
626
+ "STU",
627
+ "TAR",
628
+ "TBU",
629
+ "TME",
630
+ "TPO",
631
+ "TRS",
632
+ "UNK",
633
+ "UNL",
634
+ "UNX",
635
+ "UPL",
636
+ "URE",
637
+ }
638
+
639
+
640
+ ####################################################################################################
641
+ # TEMPLATES
642
+ ####################################################################################################
643
+
644
+ min_coverage_residues = 10
645
+ min_coverage_fraction = 0.1
646
+
647
+
648
+ ####################################################################################################
649
+ # Ambiguous atoms
650
+ ####################################################################################################
651
+
652
+ ambiguous_atoms = {
653
+ "CA": {
654
+ "*": "C",
655
+ "OEX": "CA",
656
+ "OEC": "CA",
657
+ "543": "CA",
658
+ "OC6": "CA",
659
+ "OC1": "CA",
660
+ "OC7": "CA",
661
+ "OEY": "CA",
662
+ "OC4": "CA",
663
+ "OC3": "CA",
664
+ "ICA": "CA",
665
+ "CA": "CA",
666
+ "OC2": "CA",
667
+ "OC5": "CA",
668
+ },
669
+ "CD": {"*": "C", "CD": "CD", "CD3": "CD", "CD5": "CD", "CD1": "CD"},
670
+ "BR": "BR",
671
+ "CL": {
672
+ "*": "CL",
673
+ "C8P": "C",
674
+ "L3T": "C",
675
+ "TLC": "C",
676
+ "TZ0": "C",
677
+ "471": "C",
678
+ "NLK": "C",
679
+ "PGM": "C",
680
+ "PNE": "C",
681
+ "RCY": "C",
682
+ "11F": "C",
683
+ "PII": "C",
684
+ "C1Q": "C",
685
+ "4MD": "C",
686
+ "R5A": "C",
687
+ "KW2": "C",
688
+ "I7M": "C",
689
+ "R48": "C",
690
+ "FC3": "C",
691
+ "55V": "C",
692
+ "KPF": "C",
693
+ "SPZ": "C",
694
+ "0TT": "C",
695
+ "R9A": "C",
696
+ "5NA": "C",
697
+ "C55": "C",
698
+ "NIX": "C",
699
+ "5PM": "C",
700
+ "PP8": "C",
701
+ "544": "C",
702
+ "812": "C",
703
+ "NPM": "C",
704
+ "KU8": "C",
705
+ "A1AMM": "C",
706
+ "4S0": "C",
707
+ "AQC": "C",
708
+ "2JK": "C",
709
+ "WJR": "C",
710
+ "A1AAW": "C",
711
+ "85E": "C",
712
+ "MB0": "C",
713
+ "ZAB": "C",
714
+ "85K": "C",
715
+ "GBP": "C",
716
+ "A1H80": "C",
717
+ "A1AFR": "C",
718
+ "L9M": "C",
719
+ "MYK": "C",
720
+ "MB9": "C",
721
+ "38R": "C",
722
+ "EKB": "C",
723
+ "NKF": "C",
724
+ "UMQ": "C",
725
+ "T4K": "C",
726
+ "3PT": "C",
727
+ "A1A7S": "C",
728
+ "1Q9": "C",
729
+ "11R": "C",
730
+ "D2V": "C",
731
+ "SM8": "C",
732
+ "IFC": "C",
733
+ "DB5": "C",
734
+ "L2T": "C",
735
+ "GNB": "C",
736
+ "PP7": "C",
737
+ "072": "C",
738
+ "P88": "C",
739
+ "DRL": "C",
740
+ "C9W": "C",
741
+ "NTP": "C",
742
+ "4HJ": "C",
743
+ "7NA": "C",
744
+ "LPC": "C",
745
+ "T8W": "C",
746
+ "63R": "C",
747
+ "570": "C",
748
+ "R4A": "C",
749
+ "3BG": "C",
750
+ "4RB": "C",
751
+ "GSO": "C",
752
+ "BQ6": "C",
753
+ "R4P": "C",
754
+ "5CP": "C",
755
+ "TTR": "C",
756
+ "6UZ": "C",
757
+ "SPJ": "C",
758
+ "0SA": "C",
759
+ "ZL1": "C",
760
+ "BYG": "C",
761
+ "F0E": "C",
762
+ "PC0": "C",
763
+ "B2Q": "C",
764
+ "KV6": "C",
765
+ "NTO": "C",
766
+ "CLG": "C",
767
+ "R7U": "C",
768
+ "SMQ": "C",
769
+ "GM2": "C",
770
+ "Z7P": "C",
771
+ "NXF": "C",
772
+ "C6Q": "C",
773
+ "A1G": "C",
774
+ "433": "C",
775
+ "L9N": "C",
776
+ "7OX": "C",
777
+ "A1H84": "C",
778
+ "97L": "C",
779
+ "HDV": "C",
780
+ "LUO": "C",
781
+ "R6A": "C",
782
+ "1PC": "C",
783
+ "4PT": "C",
784
+ "SBZ": "C",
785
+ "EAB": "C",
786
+ "FL4": "C",
787
+ "OPS": "C",
788
+ "C2X": "C",
789
+ "SLL": "C",
790
+ "BFC": "C",
791
+ "GIP": "C",
792
+ "7CP": "C",
793
+ "CLH": "C",
794
+ "34E": "C",
795
+ "5NE": "C",
796
+ "PBF": "C",
797
+ "ABD": "C",
798
+ "ABC": "C",
799
+ "LPF": "C",
800
+ "TIZ": "C",
801
+ "4HH": "C",
802
+ "AFC": "C",
803
+ "WQH": "C",
804
+ "9JL": "C",
805
+ "CS3": "C",
806
+ "NL0": "C",
807
+ "KPY": "C",
808
+ "DNA": "C",
809
+ "B3C": "C",
810
+ "TKL": "C",
811
+ "KVS": "C",
812
+ "HO6": "C",
813
+ "NLH": "C",
814
+ "1PB": "C",
815
+ "CYF": "C",
816
+ "G4M": "C",
817
+ "R5B": "C",
818
+ "N4S": "C",
819
+ "N11": "C",
820
+ "C8F": "C",
821
+ "PIJ": "C",
822
+ "WIN": "C",
823
+ "NT1": "C",
824
+ "WJW": "C",
825
+ "HF7": "C",
826
+ "TY1": "C",
827
+ "VM1": "C",
828
+ },
829
+ "OS": {"*": "O", "DWC": "OS", "OHX": "OS", "OS": "OS", "8WV": "OS", "OS4": "OS"},
830
+ "PB": {"*": "P", "ZN9": "PB", "ZN7": "PB", "PBM": "PB", "PB": "PB", "CSB": "PB"},
831
+ "CE": {"*": "C", "CE": "CE"},
832
+ "FE": {"*": "FE", "TFR": "F", "PF5": "F", "IFC": "F", "F5C": "F"},
833
+ "NA": {"*": "N", "CGO": "NA", "R2K": "NA", "LVQ": "NA", "NA": "NA"},
834
+ "ND": {"*": "N", "ND": "ND"},
835
+ "CF": {"*": "C", "CF": "CF"},
836
+ "RU": "RU",
837
+ "BRAF": "BR",
838
+ "EU": "EU",
839
+ "CLAA": "CL",
840
+ "CLBQ": "CL",
841
+ "CM": {"*": "C", "ZCM": "CM"},
842
+ "SN": {"*": "SN", "TAP": "S", "SND": "S", "TAD": "S", "XPT": "S"},
843
+ "AG": "AG",
844
+ "CLN": "CL",
845
+ "CLM": "CL",
846
+ "CLA": {"*": "CL", "PII": "C", "TDL": "C", "D0J": "C", "GM2": "C", "PIJ": "C"},
847
+ "CLB": {
848
+ "*": "CL",
849
+ "TD5": "C",
850
+ "PII": "C",
851
+ "TDL": "C",
852
+ "GM2": "C",
853
+ "TD7": "C",
854
+ "TD6": "C",
855
+ "PIJ": "C",
856
+ },
857
+ "CR": {
858
+ "*": "C",
859
+ "BW9": "CR",
860
+ "CQ4": "CR",
861
+ "AC9": "CR",
862
+ "TIL": "CR",
863
+ "J7U": "CR",
864
+ "CR": "CR",
865
+ },
866
+ "CLAY": "CL",
867
+ "CLBC": "CL",
868
+ "PD": {
869
+ "*": "P",
870
+ "F6Q": "PD",
871
+ "SVP": "PD",
872
+ "SXC": "PD",
873
+ "U5U": "PD",
874
+ "PD": "PD",
875
+ "PLL": "PD",
876
+ },
877
+ "CO": {
878
+ "*": "C",
879
+ "J1S": "CO",
880
+ "OCN": "CO",
881
+ "OL3": "CO",
882
+ "OL4": "CO",
883
+ "B12": "CO",
884
+ "XCO": "CO",
885
+ "UFU": "CO",
886
+ "CON": "CO",
887
+ "OL5": "CO",
888
+ "B13": "CO",
889
+ "7KI": "CO",
890
+ "PL1": "CO",
891
+ "OCO": "CO",
892
+ "J1R": "CO",
893
+ "COH": "CO",
894
+ "SIR": "CO",
895
+ "6KI": "CO",
896
+ "NCO": "CO",
897
+ "9CO": "CO",
898
+ "PC3": "CO",
899
+ "BWU": "CO",
900
+ "B1Z": "CO",
901
+ "J83": "CO",
902
+ "CO": "CO",
903
+ "COY": "CO",
904
+ "CNC": "CO",
905
+ "3CO": "CO",
906
+ "OCL": "CO",
907
+ "R5Q": "CO",
908
+ "X5Z": "CO",
909
+ "CBY": "CO",
910
+ "OLS": "CO",
911
+ "F0X": "CO",
912
+ "I2A": "CO",
913
+ "OCM": "CO",
914
+ },
915
+ "CU": {
916
+ "*": "C",
917
+ "8ZR": "CU",
918
+ "K7E": "CU",
919
+ "CU3": "CU",
920
+ "SI9": "CU",
921
+ "35N": "CU",
922
+ "C2O": "CU",
923
+ "SI7": "CU",
924
+ "B15": "CU",
925
+ "SI0": "CU",
926
+ "CUP": "CU",
927
+ "SQ1": "CU",
928
+ "CUK": "CU",
929
+ "CUL": "CU",
930
+ "SI8": "CU",
931
+ "IC4": "CU",
932
+ "CUM": "CU",
933
+ "MM2": "CU",
934
+ "B30": "CU",
935
+ "S32": "CU",
936
+ "V79": "CU",
937
+ "IMF": "CU",
938
+ "CUN": "CU",
939
+ "MM1": "CU",
940
+ "MP1": "CU",
941
+ "IME": "CU",
942
+ "B17": "CU",
943
+ "C2C": "CU",
944
+ "1CU": "CU",
945
+ "CU6": "CU",
946
+ "C1O": "CU",
947
+ "CU1": "CU",
948
+ "B22": "CU",
949
+ "CUS": "CU",
950
+ "RUQ": "CU",
951
+ "CUF": "CU",
952
+ "CUA": "CU",
953
+ "CU": "CU",
954
+ "CUO": "CU",
955
+ "0TE": "CU",
956
+ "SI4": "CU",
957
+ },
958
+ "CS": {"*": "C", "CS": "CS"},
959
+ "CLQ": "CL",
960
+ "CLR": "CL",
961
+ "CLU": "CL",
962
+ "TE": "TE",
963
+ "NI": {
964
+ "*": "N",
965
+ "USN": "NI",
966
+ "NFO": "NI",
967
+ "NI2": "NI",
968
+ "NFS": "NI",
969
+ "NFR": "NI",
970
+ "82N": "NI",
971
+ "R5N": "NI",
972
+ "NFU": "NI",
973
+ "A1ICD": "NI",
974
+ "NI3": "NI",
975
+ "M43": "NI",
976
+ "MM5": "NI",
977
+ "BF8": "NI",
978
+ "TCN": "NI",
979
+ "NIK": "NI",
980
+ "CUV": "NI",
981
+ "MM6": "NI",
982
+ "J52": "NI",
983
+ "NI": "NI",
984
+ "SNF": "NI",
985
+ "XCC": "NI",
986
+ "F0L": "NI",
987
+ "UWE": "NI",
988
+ "NFC": "NI",
989
+ "3NI": "NI",
990
+ "HNI": "NI",
991
+ "F43": "NI",
992
+ "RQM": "NI",
993
+ "NFE": "NI",
994
+ "NFB": "NI",
995
+ "B51": "NI",
996
+ "NI1": "NI",
997
+ "WCC": "NI",
998
+ "NUF": "NI",
999
+ },
1000
+ "SB": {"*": "S", "UJI": "SB", "SB": "SB", "118": "SB", "SBO": "SB", "3CG": "SB"},
1001
+ "MO": "MO",
1002
+ "SEG": "SE",
1003
+ "CLL": "CL",
1004
+ "CLAH": "CL",
1005
+ "CLC": {
1006
+ "*": "CL",
1007
+ "TD5": "C",
1008
+ "PII": "C",
1009
+ "TDL": "C",
1010
+ "GM2": "C",
1011
+ "TD7": "C",
1012
+ "TD6": "C",
1013
+ "PIJ": "C",
1014
+ },
1015
+ "CLD": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
1016
+ "CLAD": "CL",
1017
+ "CLAE": "CL",
1018
+ "LA": "LA",
1019
+ "RH": "RH",
1020
+ "BRAC": "BR",
1021
+ "BRAD": "BR",
1022
+ "CLBN": "CL",
1023
+ "CLAC": "CL",
1024
+ "BRAB": "BR",
1025
+ "BRAE": "BR",
1026
+ "MG": "MG",
1027
+ "IR": "IR",
1028
+ "SE": {
1029
+ "*": "SE",
1030
+ "HII": "S",
1031
+ "NT2": "S",
1032
+ "R2P": "S",
1033
+ "S2P": "S",
1034
+ "0IU": "S",
1035
+ "QMB": "S",
1036
+ "81S": "S",
1037
+ "0QB": "S",
1038
+ "UB4": "S",
1039
+ "OHS": "S",
1040
+ "Q78": "S",
1041
+ "0Y2": "S",
1042
+ "B3M": "S",
1043
+ "NT1": "S",
1044
+ "81R": "S",
1045
+ },
1046
+ "BRAG": "BR",
1047
+ "CLF": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
1048
+ "CLE": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
1049
+ "BRAX": "BR",
1050
+ "CLK": "CL",
1051
+ "ZN": "ZN",
1052
+ "AS": "AS",
1053
+ "AU": "AU",
1054
+ "PT": "PT",
1055
+ "CLAS": "CL",
1056
+ "MN": "MN",
1057
+ "CLBE": "CL",
1058
+ "CLBF": "CL",
1059
+ "CLAF": "CL",
1060
+ "NA'": {"*": "N", "CGO": "NA"},
1061
+ "BRAH": "BR",
1062
+ "BRAI": "BR",
1063
+ "BRA": "BR",
1064
+ "BRB": "BR",
1065
+ "BRAV": "BR",
1066
+ "HG": {
1067
+ "*": "HG",
1068
+ "BBA": "H",
1069
+ "MID": "H",
1070
+ "APM": "H",
1071
+ "4QQ": "H",
1072
+ "0ZG": "H",
1073
+ "APH": "H",
1074
+ },
1075
+ "AR": "AR",
1076
+ "D": "H",
1077
+ "CLAN": "CL",
1078
+ "SI": "SI",
1079
+ "CLS": "CL",
1080
+ "ZR": "ZR",
1081
+ "CLAR": {"*": "CL", "ZM4": "C"},
1082
+ "HO": "HO",
1083
+ "CLI": {"*": "CL", "GM2": "C"},
1084
+ "CLH": {"*": "CL", "GM2": "C"},
1085
+ "CLAP": "CL",
1086
+ "CLBL": "CL",
1087
+ "CLBM": "CL",
1088
+ "PR": {"*": "PR", "UF0": "P", "252": "P"},
1089
+ "IN": "IN",
1090
+ "CLJ": "CL",
1091
+ "BRU": "BR",
1092
+ "SC": {"*": "S", "SFL": "SC"},
1093
+ "CLG": {"*": "CL", "GM2": "C"},
1094
+ "BRAT": "BR",
1095
+ "BRAR": "BR",
1096
+ "CLAG": "CL",
1097
+ "CLAB": "CL",
1098
+ "CLV": "CL",
1099
+ "TI": "TI",
1100
+ "CLAX": "CL",
1101
+ "CLAJ": "CL",
1102
+ "CL'": {"*": "CL", "BNR": "C", "25A": "C", "BDA": "C"},
1103
+ "CLAW": "CL",
1104
+ "BRF": "BR",
1105
+ "BRE": "BR",
1106
+ "RE": "RE",
1107
+ "GD": "GD",
1108
+ "SM": {"*": "S", "SM": "SM"},
1109
+ "CLBH": "CL",
1110
+ "CLBI": "CL",
1111
+ "CLAI": "CL",
1112
+ "CLY": "CL",
1113
+ "CLZ": "CL",
1114
+ "AC": "AC",
1115
+ "BR'": "BR",
1116
+ "CLT": "CL",
1117
+ "CLO": "CL",
1118
+ "CLP": "CL",
1119
+ "LU": "LU",
1120
+ "BA": {"*": "B", "BA": "BA"},
1121
+ "CLAU": "CL",
1122
+ "RB": "RB",
1123
+ "LI": "LI",
1124
+ "MOM": "MO",
1125
+ "BRAQ": "BR",
1126
+ "SR": {"*": "S", "SR": "SR", "OER": "SR"},
1127
+ "CLAT": "CL",
1128
+ "BRAL": "BR",
1129
+ "SEB": "SE",
1130
+ "CLW": "CL",
1131
+ "CLX": "CL",
1132
+ "BE": "BE",
1133
+ "BRG": "BR",
1134
+ "SEA": "SE",
1135
+ "BRAW": "BR",
1136
+ "BRBB": "BR",
1137
+ "ER": "ER",
1138
+ "TH": "TH",
1139
+ "BRR": "BR",
1140
+ "CLBV": "CL",
1141
+ "AL": "AL",
1142
+ "CLAV": "CL",
1143
+ "BRH": "BR",
1144
+ "CLAQ": "CL",
1145
+ "GA": "GA",
1146
+ "X": "*",
1147
+ "TL": "TL",
1148
+ "CLBB": "CL",
1149
+ "TB": "TB",
1150
+ "CLAK": "CL",
1151
+ "XE": {"*": "*", "XE": "XE"},
1152
+ "SEL": "SE",
1153
+ "PU": {"*": "P", "4PU": "PU"},
1154
+ "CLAZ": "CL",
1155
+ "SE'": "SE",
1156
+ "CLBA": "CL",
1157
+ "SEN": "SE",
1158
+ "SNN": "SN",
1159
+ "MOB": "MO",
1160
+ "YB": "YB",
1161
+ "BRC": "BR",
1162
+ "BRD": "BR",
1163
+ "CLAM": "CL",
1164
+ "DA": "H",
1165
+ "DB": "H",
1166
+ "DC": "H",
1167
+ "DXT": "H",
1168
+ "DXU": "H",
1169
+ "DXX": "H",
1170
+ "DXY": "H",
1171
+ "DXZ": "H",
1172
+ "DY": "DY",
1173
+ "TA": "TA",
1174
+ "XD": "*",
1175
+ "SED": "SE",
1176
+ "CLAL": "CL",
1177
+ "BRAJ": "BR",
1178
+ "AM": "AM",
1179
+ "CLAO": "CL",
1180
+ "BI": "BI",
1181
+ "KR": "KR",
1182
+ "BRBJ": "BR",
1183
+ "UNK": "*",
1184
+ }
protify/FastPLMs/boltz/src/boltz/data/crop/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/crop/affinity.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ from boltz.data import const
7
+ from boltz.data.crop.cropper import Cropper
8
+ from boltz.data.types import Tokenized
9
+
10
+
11
class AffinityCropper(Cropper):
    """Interpolate between contiguous and spatial crops."""

    def __init__(
        self,
        neighborhood_size: int = 10,
        max_tokens_protein: int = 200,
    ) -> None:
        """Initialize the cropper.

        Parameters
        ----------
        neighborhood_size : int
            Modulates the type of cropping to be performed.
            Smaller neighborhoods result in more spatial
            cropping. Larger neighborhoods result in more
            continuous cropping.
        max_tokens_protein : int
            Upper bound on the number of non-ligand tokens
            kept in the crop.

        """
        self.neighborhood_size = neighborhood_size
        self.max_tokens_protein = max_tokens_protein

    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        max_atoms: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Tokens are visited in order of increasing distance to the
        affinity ligand, and a contiguous residue neighborhood is
        added around each visited token.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        max_atoms : Optional[int]
            The maximum number of atoms to consider.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        # Get token data
        token_data = data.tokens
        token_bonds = data.bonds

        # Filter to resolved tokens
        valid_tokens = token_data[token_data["resolved_mask"]]

        # Check if we have any valid tokens
        if not valid_tokens.size:
            msg = "No valid tokens in structure"
            raise ValueError(msg)

        # Compute minimum Euclidean distance from each valid token's
        # center to any ligand (affinity-masked) token center.
        ligand_coords = valid_tokens[valid_tokens["affinity_mask"]]["center_coords"]
        dists = np.min(
            np.sum(
                (valid_tokens["center_coords"][:, None] - ligand_coords[None]) ** 2,
                axis=-1,
            )
            ** 0.5,
            axis=1,
        )

        # Visit the tokens closest to the ligand first.
        indices = np.argsort(dists)

        # Select cropped indices
        cropped: set[int] = set()
        total_atoms = 0

        # protein tokens (i.e. anything that is not a NONPOLYMER token)
        cropped_protein: set[int] = set()
        ligand_ids = set(
            valid_tokens[
                valid_tokens["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]["token_idx"]
        )

        for idx in indices:
            # Get the token
            token = valid_tokens[idx]

            # Get all tokens from this chain
            chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]

            # Pick the whole chain if possible, otherwise select
            # a contiguous subset centered at the query token
            if len(chain_tokens) <= self.neighborhood_size:
                new_tokens = chain_tokens
            else:
                # First limit to the maximum set of tokens, with the
                # neighborhood on both sides to handle edges. This
                # is mostly for efficiency with the while loop below.
                min_idx = token["res_idx"] - self.neighborhood_size
                max_idx = token["res_idx"] + self.neighborhood_size

                max_token_set = chain_tokens
                max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
                max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]

                # Start by adding just the query token
                new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]

                # Expand the neighborhood until we have enough tokens, one
                # by one to handle some edge cases with non-standard chains.
                # We switch to the res_idx instead of the token_idx to always
                # include all tokens from modified residues or from ligands.
                min_idx = max_idx = token["res_idx"]
                while new_tokens.size < self.neighborhood_size:
                    min_idx = min_idx - 1
                    max_idx = max_idx + 1
                    new_tokens = max_token_set
                    new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
                    new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]

            # Compute new tokens and new atoms
            # NOTE(review): indexing token_data by token_idx values assumes
            # token_idx equals the positional index in token_data — confirm
            # against the tokenizer that produces these arrays.
            new_indices = set(new_tokens["token_idx"]) - cropped
            new_tokens = token_data[list(new_indices)]
            new_atoms = np.sum(new_tokens["atom_num"])

            # Stop if we exceed the max number of tokens or atoms, or the
            # protein-token budget (ligand tokens do not count against it).
            if (
                (len(new_indices) > (max_tokens - len(cropped)))
                or ((max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms))
                or (
                    len(cropped_protein | new_indices - ligand_ids)
                    > self.max_tokens_protein
                )
            ):
                break

            # Add new indices
            cropped.update(new_indices)
            total_atoms += new_atoms

            # Add protein indices
            cropped_protein.update(new_indices - ligand_ids)

        # Get the cropped tokens sorted by index
        token_data = token_data[sorted(cropped)]

        # Only keep bonds within the cropped tokens
        indices = token_data["token_idx"]
        token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        # Return the cropped tokens
        return replace(data, tokens=token_data, bonds=token_bonds)
protify/FastPLMs/boltz/src/boltz/data/crop/boltz.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import replace
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ from scipy.spatial.distance import cdist
6
+
7
+ from boltz.data import const
8
+ from boltz.data.crop.cropper import Cropper
9
+ from boltz.data.types import Tokenized
10
+
11
+
12
def pick_random_token(
    tokens: np.ndarray,
    random: np.random.RandomState,
) -> np.ndarray:
    """Draw one token uniformly at random.

    Parameters
    ----------
    tokens : np.ndarray
        The token data to sample from.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    choice = random.randint(len(tokens))
    return tokens[choice]
32
+
33
+
34
def pick_chain_token(
    tokens: np.ndarray,
    chain_id: int,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from a chain, with a global fallback.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    chain_id : int
        The chain ID.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    # Restrict sampling to the requested chain; if it has no
    # tokens, fall back to the full token set.
    chain_subset = tokens[tokens["asym_id"] == chain_id]
    pool = chain_subset if chain_subset.size else tokens
    return pick_random_token(pool, random)
66
+
67
+
68
def pick_interface_token(
    tokens: np.ndarray,
    interface: np.ndarray,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from an interface, with chain fallbacks.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    interface : np.ndarray
        The interface record, carrying ``chain_1`` and ``chain_2``.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    chain_1 = int(interface["chain_1"])
    chain_2 = int(interface["chain_2"])

    tokens_1 = tokens[tokens["asym_id"] == chain_1]
    tokens_2 = tokens[tokens["asym_id"] == chain_2]

    # Degenerate cases: one or both chains contribute no tokens,
    # so fall back to chain-level or global sampling.
    if (not tokens_1.size) and (not tokens_2.size):
        return pick_random_token(tokens, random)
    if not tokens_2.size:
        return pick_random_token(tokens_1, random)
    if not tokens_1.size:
        return pick_random_token(tokens_2, random)

    # Both chains have tokens: keep only those actually in contact.
    pair_dists = cdist(tokens_1["center_coords"], tokens_2["center_coords"])
    in_contact = pair_dists < const.interface_cutoff

    # In rare cases, the interface cutoff is slightly too small;
    # expand it a bit when no contact is found.
    if not np.any(in_contact):
        in_contact = pair_dists < (const.interface_cutoff + 5.0)

    near_1 = tokens_1[np.any(in_contact, axis=1)]
    near_2 = tokens_2[np.any(in_contact, axis=0)]

    # Select a random token among the contacting ones.
    candidates = np.concatenate([near_1, near_2])
    return pick_random_token(candidates, random)
125
+
126
+
127
class BoltzCropper(Cropper):
    """Interpolate between contiguous and spatial crops."""

    def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
        """Initialize the cropper.

        Modulates the type of cropping to be performed.
        Smaller neighborhoods result in more spatial
        cropping. Larger neighborhoods result in more
        continuous cropping. A mix can be achieved by
        providing a range over which to sample.

        Parameters
        ----------
        min_neighborhood : int
            The minimum neighborhood size, by default 0.
        max_neighborhood : int
            The maximum neighborhood size, by default 40.

        """
        # Even sizes only; one is sampled uniformly per crop() call.
        sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
        self.neighborhood_sizes = sizes

    def crop(  # noqa: PLR0915
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : int, optional
            The maximum number of atoms to consider.
        chain_id : int, optional
            The chain ID to crop.
        interface_id : int, optional
            The interface ID to crop.

        Returns
        -------
        Tokenized
            The cropped data.

        Raises
        ------
        ValueError
            If both chain_id and interface_id are given, or if the
            structure has no resolved tokens.

        """
        # Check inputs
        if chain_id is not None and interface_id is not None:
            msg = "Only one of chain_id or interface_id can be provided."
            raise ValueError(msg)

        # Randomly select a neighborhood size
        neighborhood_size = random.choice(self.neighborhood_sizes)

        # Get token data
        token_data = data.tokens
        token_bonds = data.bonds
        mask = data.structure.mask
        chains = data.structure.chains
        interfaces = data.structure.interfaces

        # Filter to valid chains
        valid_chains = chains[mask]

        # Filter to valid interfaces (both endpoint chains must be valid)
        valid_interfaces = interfaces
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]

        # Filter to resolved tokens
        valid_tokens = token_data[token_data["resolved_mask"]]

        # Check if we have any valid tokens
        if not valid_tokens.size:
            msg = "No valid tokens in structure"
            raise ValueError(msg)

        # Pick a query token: from the requested chain or interface if
        # given, otherwise from a random valid interface, otherwise from
        # a random valid chain.
        if chain_id is not None:
            query = pick_chain_token(valid_tokens, chain_id, random)
        elif interface_id is not None:
            interface = interfaces[interface_id]
            query = pick_interface_token(valid_tokens, interface, random)
        elif valid_interfaces.size:
            idx = random.randint(len(valid_interfaces))
            interface = valid_interfaces[idx]
            query = pick_interface_token(valid_tokens, interface, random)
        else:
            idx = random.randint(len(valid_chains))
            chain_id = valid_chains[idx]["asym_id"]
            query = pick_chain_token(valid_tokens, chain_id, random)

        # Sort all tokens by distance to query_coords
        dists = valid_tokens["center_coords"] - query["center_coords"]
        indices = np.argsort(np.linalg.norm(dists, axis=1))

        # Select cropped indices
        cropped: set[int] = set()
        total_atoms = 0
        for idx in indices:
            # Get the token
            token = valid_tokens[idx]

            # Get all tokens from this chain
            chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]

            # Pick the whole chain if possible, otherwise select
            # a contiguous subset centered at the query token
            if len(chain_tokens) <= neighborhood_size:
                new_tokens = chain_tokens
            else:
                # First limit to the maximum set of tokens, with the
                # neighborhood on both sides to handle edges. This
                # is mostly for efficiency with the while loop below.
                min_idx = token["res_idx"] - neighborhood_size
                max_idx = token["res_idx"] + neighborhood_size

                max_token_set = chain_tokens
                max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
                max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]

                # Start by adding just the query token
                new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]

                # Expand the neighborhood until we have enough tokens, one
                # by one to handle some edge cases with non-standard chains.
                # We switch to the res_idx instead of the token_idx to always
                # include all tokens from modified residues or from ligands.
                min_idx = max_idx = token["res_idx"]
                while new_tokens.size < neighborhood_size:
                    min_idx = min_idx - 1
                    max_idx = max_idx + 1
                    new_tokens = max_token_set
                    new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
                    new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]

            # Compute new tokens and new atoms
            # NOTE(review): indexing token_data by token_idx values assumes
            # token_idx equals the positional index in token_data — confirm
            # against the tokenizer that produces these arrays.
            new_indices = set(new_tokens["token_idx"]) - cropped
            new_tokens = token_data[list(new_indices)]
            new_atoms = np.sum(new_tokens["atom_num"])

            # Stop if we exceed the max number of tokens or atoms
            if (len(new_indices) > (max_tokens - len(cropped))) or (
                (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
            ):
                break

            # Add new indices
            cropped.update(new_indices)
            total_atoms += new_atoms

        # Get the cropped tokens sorted by index
        token_data = token_data[sorted(cropped)]

        # Only keep bonds within the cropped tokens
        indices = token_data["token_idx"]
        token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        # Return the cropped tokens
        return replace(data, tokens=token_data, bonds=token_bonds)
protify/FastPLMs/boltz/src/boltz/data/crop/cropper.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ from boltz.data.types import Tokenized
7
+
8
+
9
class Cropper(ABC):
    """Interface for cropping strategies over tokenized structures."""

    @abstractmethod
    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop ``data`` down to at most ``max_tokens`` tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data to crop.
        max_tokens : int
            Upper bound on the number of tokens kept.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : Optional[int]
            Optional upper bound on the number of atoms kept.
        chain_id : Optional[int]
            If given, center the crop on this chain.
        interface_id : Optional[int]
            If given, center the crop on this interface.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        raise NotImplementedError
protify/FastPLMs/boltz/src/boltz/data/feature/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/feature/featurizer.py ADDED
@@ -0,0 +1,1225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ from typing import Optional
4
+ from collections import deque
5
+ import numba
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ import torch
9
+ from numba import types
10
+ from torch import Tensor, from_numpy
11
+ from torch.nn.functional import one_hot
12
+
13
+ from boltz.data import const
14
+ from boltz.data.feature.symmetry import (
15
+ get_amino_acids_symmetries,
16
+ get_chain_symmetries,
17
+ get_ligand_symmetries,
18
+ )
19
+ from boltz.data.pad import pad_dim
20
+ from boltz.data.types import (
21
+ MSA,
22
+ MSADeletion,
23
+ MSAResidue,
24
+ MSASequence,
25
+ Tokenized,
26
+ )
27
+ from boltz.model.modules.utils import center_random_augmentation
28
+
29
+ ####################################################################################################
30
+ # HELPERS
31
+ ####################################################################################################
32
+
33
+
34
def compute_frames_nonpolymer(
    data: Tokenized,
    coords: np.ndarray,
    resolved_mask: np.ndarray,
    atom_to_token: np.ndarray,
    frame_data: list,
    resolved_frame_data: list,
) -> tuple[np.ndarray, np.ndarray]:
    """Get the frames for non-polymer tokens.

    For each non-polymer chain with at least three atoms, every token's
    frame is replaced by (nearest resolved atom, the atom itself, second
    nearest resolved atom). Frames of other chains are left untouched.
    All frames are then invalidated if their spanning vectors are nearly
    collinear or degenerate.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    coords : np.ndarray
        Atom coordinates; reshaped to (-1, 3) internally.
    resolved_mask : np.ndarray
        Per-atom resolved mask.
    atom_to_token : np.ndarray
        Owning token index for each atom.
    frame_data : list
        The frame data (one atom-index triplet per token).
    resolved_frame_data : list
        The resolved frame data.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The frame data and the resolved-frame mask AND-ed with the
        non-collinearity mask.

    """
    frame_data = np.array(frame_data)
    resolved_frame_data = np.array(resolved_frame_data)
    asym_id_token = data.tokens["asym_id"]
    asym_id_atom = data.tokens["asym_id"][atom_to_token]
    token_idx = 0
    atom_idx = 0
    # Walk chains in asym_id order, tracking running token/atom offsets.
    for id in np.unique(data.tokens["asym_id"]):
        mask_chain_token = asym_id_token == id
        mask_chain_atom = asym_id_atom == id
        num_tokens = mask_chain_token.sum()
        num_atoms = mask_chain_atom.sum()
        # Skip polymer chains and chains too small to define a frame.
        if (
            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
            or num_atoms < 3
        ):
            token_idx += num_tokens
            atom_idx += num_atoms
            continue
        # Pairwise Euclidean distances between the chain's atoms.
        dist_mat = (
            (
                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
            )
            ** 2
        ).sum(-1) ** 0.5
        # Push pairs involving an unresolved atom to infinite distance so
        # the argsort below prefers resolved atoms as frame neighbors.
        resolved_pair = 1 - (
            resolved_mask[mask_chain_atom][None, :]
            * resolved_mask[mask_chain_atom][:, None]
        ).astype(np.float32)
        resolved_pair[resolved_pair == 1] = math.inf
        indices = np.argsort(dist_mat + resolved_pair, axis=1)
        # Frame = (nearest neighbor, the atom itself, second neighbor),
        # shifted into global atom indexing via atom_idx.
        frames = (
            np.concatenate(
                [
                    indices[:, 1:2],
                    indices[:, 0:1],
                    indices[:, 2:3],
                ],
                axis=1,
            )
            + atom_idx
        )
        # NOTE(review): the row slice uses num_atoms on the token-indexed
        # frame arrays — this relies on non-polymer chains having exactly
        # one token per atom (num_tokens == num_atoms); confirm upstream.
        frame_data[token_idx : token_idx + num_atoms, :] = frames
        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
            frames
        ].all(axis=1)
        token_idx += num_tokens
        atom_idx += num_atoms
    frames_expanded = coords.reshape(-1, 3)[frame_data]

    # Invalidate frames whose two spanning vectors are nearly collinear
    # or degenerate (see compute_collinear_mask).
    mask_collinear = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    )
    return frame_data, resolved_frame_data & mask_collinear
114
+
115
+
116
def compute_collinear_mask(v1, v2):
    """Mask out frames whose spanning vectors are near-collinear or tiny.

    A frame is kept only if the absolute cosine between its two spanning
    vectors is below 0.9063 (roughly 25 degrees away from collinear) and
    both vectors are longer than 1e-2.
    """
    len1 = np.linalg.norm(v1, axis=1, keepdims=True)
    len2 = np.linalg.norm(v2, axis=1, keepdims=True)
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    cos_angle = np.abs(np.sum(unit1 * unit2, axis=1))
    well_angled = cos_angle < 0.9063
    long_enough = (len1.reshape(-1) > 1e-2) & (len2.reshape(-1) > 1e-2)
    return well_angled & long_enough
125
+
126
+
127
def dummy_msa(residues: np.ndarray) -> MSA:
    """Create a single-sequence placeholder MSA for a chain.

    Parameters
    ----------
    residues : np.ndarray
        The residues for the chain.

    Returns
    -------
    MSA
        An MSA holding only the chain's own residue types, with no
        deletions and an unknown (-1) taxonomy.

    """
    # The only row is the chain's own residue-type sequence.
    res_types = [entry["res_type"] for entry in residues]
    # One sequence record spanning all residues, taxonomy -1.
    seq_records = [(0, -1, 0, len(res_types), 0, 0)]
    return MSA(
        residues=np.array(res_types, dtype=MSAResidue),
        deletions=np.array([], dtype=MSADeletion),
        sequences=np.array(seq_records, dtype=MSASequence),
    )
149
+
150
+
151
def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
    data: Tokenized,
    max_seqs: int,
    max_pairs: int = 8192,
    max_total: int = 16384,
    random_subset: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Pair the MSA data across chains by taxonomy.

    Rows are built greedily: the query row first, then up to
    ``max_pairs`` taxonomy-paired rows, then unpaired rows up to
    ``max_total`` rows in total, finally downsampled to ``max_seqs``.

    Parameters
    ----------
    data : Tokenized
        The input data.
    max_seqs : int
        Maximum number of MSA rows kept after downsampling.
    max_pairs : int
        Maximum number of taxonomy-paired rows, by default 8192.
    max_total : int
        Maximum total number of rows, by default 16384.
    random_subset : bool
        Whether to sample a random subset of rows (always keeping the
        query row) instead of truncating, by default False.

    Returns
    -------
    Tensor
        The MSA data.
    Tensor
        The deletion data.
    Tensor
        Mask indicating paired sequences.

    """
    # Get unique chains (ensuring monotonicity in the order)
    assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
    chain_ids = np.unique(data.tokens["asym_id"])

    # Get relevant MSA, and create a dummy for chains without
    msa = {k: data.msa[k] for k in chain_ids if k in data.msa}
    for chain_id in chain_ids:
        if chain_id not in msa:
            chain = data.structure.chains[chain_id]
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = data.structure.residues[res_start:res_end]
            msa[chain_id] = dummy_msa(residues)

    # Map taxonomies to (chain_id, seq_idx)
    taxonomy_map: dict[str, list] = {}
    for chain_id, chain_msa in msa.items():
        sequences = chain_msa.sequences
        sequences = sequences[sequences["taxonomy"] != -1]
        for sequence in sequences:
            seq_idx = sequence["seq_idx"]
            taxon = sequence["taxonomy"]
            taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))

    # Remove taxonomies with only one sequence and sort by the
    # number of chain_id present in each of the taxonomies
    taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
    taxonomy_map = sorted(
        taxonomy_map.items(),
        key=lambda x: len({c for c, _ in x[1]}),
        reverse=True,
    )

    # Keep track of the sequences available per chain, keeping the original
    # order of the sequences in the MSA to favor the best matching sequences.
    # BUGFIX: each item is already a (chain_id, seq_idx) pair; the previous
    # code built (taxonomy, pair) tuples, so the membership test below never
    # matched and paired sequences could be re-used as unpaired rows.
    visited = {pair for _, pairs in taxonomy_map for pair in pairs}
    available = {}
    for c in chain_ids:
        available[c] = deque(
            i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
        )

    # Create sequence pairs
    is_paired = []
    pairing = []

    # Start with the first sequence for each chain
    is_paired.append({c: 1 for c in chain_ids})
    pairing.append({c: 0 for c in chain_ids})

    # Then add up to max_pairs - 1 taxonomy-paired rows
    for _, pairs in taxonomy_map:
        # Group occurences by chain_id in case we have multiple
        # sequences from the same chain and same taxonomy
        chain_occurences = {}
        for chain_id, seq_idx in pairs:
            chain_occurences.setdefault(chain_id, []).append(seq_idx)

        # We create as many pairings as the maximum number of occurences
        max_occurences = max(len(v) for v in chain_occurences.values())
        for i in range(max_occurences):
            row_pairing = {}
            row_is_paired = {}

            # Add the chains present in the taxonomy
            for chain_id, seq_idxs in chain_occurences.items():
                # Roll over the sequence index to maximize diversity
                idx = i % len(seq_idxs)
                seq_idx = seq_idxs[idx]

                # Add the sequence to the pairing
                row_pairing[chain_id] = seq_idx
                row_is_paired[chain_id] = 1

            # Add any missing chains
            for chain_id in chain_ids:
                if chain_id not in row_pairing:
                    row_is_paired[chain_id] = 0
                    if available[chain_id]:
                        # Add the next available sequence
                        row_pairing[chain_id] = available[chain_id].popleft()
                    else:
                        # No more sequences available, we place a gap
                        row_pairing[chain_id] = -1

            pairing.append(row_pairing)
            is_paired.append(row_is_paired)

            # Break if we have enough pairs
            if len(pairing) >= max_pairs:
                break

        # Break if we have enough pairs
        if len(pairing) >= max_pairs:
            break

    # Now add unpaired rows until we reach max_total rows in total
    max_left = max(len(v) for v in available.values())
    for _ in range(min(max_total - len(pairing), max_left)):
        row_pairing = {}
        row_is_paired = {}
        for chain_id in chain_ids:
            row_is_paired[chain_id] = 0
            if available[chain_id]:
                # Add the next available sequence
                row_pairing[chain_id] = available[chain_id].popleft()
            else:
                # No more sequences available, we place a gap
                row_pairing[chain_id] = -1

        pairing.append(row_pairing)
        is_paired.append(row_is_paired)

        # Break if we have enough sequences
        if len(pairing) >= max_total:
            break

    # Randomly sample a subset of the pairs
    # ensuring the first row is always present
    if random_subset:
        num_seqs = len(pairing)
        if num_seqs > max_seqs:
            indices = np.random.choice(  # noqa: NPY002
                list(range(1, num_seqs)), size=max_seqs - 1, replace=False
            )
            pairing = [pairing[0]] + [pairing[i] for i in indices]
            is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
    else:
        # Deterministic downsample to max_seqs
        pairing = pairing[:max_seqs]
        is_paired = is_paired[:max_seqs]

    # Map (chain_id, seq_idx, res_idx) to deletion, using a numba typed
    # dict so the jitted inner routine can consume it.
    deletions = numba.typed.Dict.empty(
        key_type=numba.types.Tuple(
            [numba.types.int64, numba.types.int64, numba.types.int64]),
        value_type=numba.types.int64
    )
    for chain_id, chain_msa in msa.items():
        for sequence in chain_msa.sequences:
            seq_idx = sequence["seq_idx"]
            del_start = sequence["del_start"]
            del_end = sequence["del_end"]
            # BUGFIX: slice the chain's full deletion table for every
            # sequence. The previous code re-sliced the result of the
            # prior iteration, so the absolute [del_start, del_end)
            # offsets were wrong for every sequence after the first.
            seq_deletions = chain_msa.deletions[del_start:del_end]
            for deletion_data in seq_deletions:
                res_idx = deletion_data["res_idx"]
                deletion_values = deletion_data["deletion"]
                deletions[(chain_id, seq_idx, res_idx)] = deletion_values

    # Add all the token MSA data
    msa_data, del_data, paired_data = prepare_msa_arrays(
        data.tokens, pairing, is_paired, deletions, msa
    )

    msa_data = torch.tensor(msa_data, dtype=torch.long)
    del_data = torch.tensor(del_data, dtype=torch.float)
    paired_data = torch.tensor(paired_data, dtype=torch.float)

    return msa_data, del_data, paired_data
335
+
336
+
337
def prepare_msa_arrays(
    tokens,
    pairing: list[dict[int, int]],
    is_paired: list[dict[int, int]],
    deletions: dict[tuple[int, int, int], int],
    msa: dict[int, MSA],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Reshape data to play nicely with numba jit.

    Converts the per-row pairing dictionaries and per-chain MSA objects
    into dense int64 arrays, then delegates to the jitted inner routine.
    """
    token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
    token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)

    chain_ids = sorted(msa.keys())

    # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
    # This allows us to look up a chain_id by its index in the chain_ids list.
    chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
    token_asym_ids_idx_arr = np.array(
        [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
    )

    # Densify the row -> chain pairing tables.
    pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
    is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)

    for i, row_pairing in enumerate(pairing):
        for chain_id in chain_ids:
            pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]

    for i, row_is_paired in enumerate(is_paired):
        for chain_id in chain_ids:
            is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]

    max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)

    # we want res_start from sequences; -1 marks missing entries
    msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        for i, seq in enumerate(msa[chain_id].sequences):
            msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]

    # Per-chain residue tables, padded with -1 to the longest chain.
    max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
    msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        residues = msa[chain_id].residues.astype(np.int64)
        idxs = np.arange(len(residues))
        chain_idx = chain_id_to_idx[chain_id]
        msa_residues[chain_idx, idxs] = residues

    return _prepare_msa_arrays_inner(
        token_asym_ids_arr,
        token_res_idxs_arr,
        token_asym_ids_idx_arr,
        pairing_arr,
        is_paired_arr,
        deletions,
        msa_sequences,
        msa_residues,
        const.token_ids["-"],
    )
395
+
396
+
397
# Numba typed-dict signature for the `deletions` argument of
# `_prepare_msa_arrays_inner`: (asym_id, seq_idx, res_idx) -> deletion count.
deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
398
+
399
+
400
@numba.njit(
    [
        types.Tuple(
            (
                types.int64[:, ::1],  # msa_data
                types.int64[:, ::1],  # del_data
                types.int64[:, ::1],  # paired_data
            )
        )(
            types.int64[::1],  # token_asym_ids
            types.int64[::1],  # token_res_idxs
            types.int64[::1],  # token_asym_ids_idx
            types.int64[:, ::1],  # pairing
            types.int64[:, ::1],  # is_paired
            deletions_dict_type,  # deletions
            types.int64[:, ::1],  # msa_sequences
            types.int64[:, ::1],  # msa_residues
            types.int64,  # gap_token
        )
    ],
    cache=True,
)
def _prepare_msa_arrays_inner(
    token_asym_ids: npt.NDArray[np.int64],
    token_res_idxs: npt.NDArray[np.int64],
    token_asym_ids_idx: npt.NDArray[np.int64],
    pairing: npt.NDArray[np.int64],
    is_paired: npt.NDArray[np.int64],
    deletions: dict[tuple[int, int, int], int],
    msa_sequences: npt.NDArray[np.int64],
    msa_residues: npt.NDArray[np.int64],
    gap_token: int,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Fill per-token MSA, deletion, and paired matrices (numba kernel).

    msa_data defaults to `gap_token`; a cell is overwritten only when the
    pairing table assigns a real sequence (seq_idx != -1) to that token's
    chain for that MSA row.
    """
    n_tokens = len(token_asym_ids)
    n_pairs = len(pairing)
    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)

    # Add all the token MSA data
    for token_idx in range(n_tokens):
        chain_id_idx = token_asym_ids_idx[token_idx]
        chain_id = token_asym_ids[token_idx]
        res_idx = token_res_idxs[token_idx]

        for pair_idx in range(n_pairs):
            seq_idx = pairing[pair_idx, chain_id_idx]
            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]

            # Add residue type: res_start offsets this sequence inside the
            # chain's flat residue table built by prepare_msa_arrays.
            if seq_idx != -1:
                res_start = msa_sequences[chain_id_idx, seq_idx]
                res_type = msa_residues[chain_id_idx, res_start + res_idx]
                k = (chain_id, seq_idx, res_idx)
                if k in deletions:
                    del_data[token_idx, pair_idx] = deletions[k]
                msa_data[token_idx, pair_idx] = res_type

    return msa_data, del_data, paired_data
459
+
460
+
461
+ ####################################################################################################
462
+ # FEATURES
463
+ ####################################################################################################
464
+
465
+
466
def select_subset_from_mask(mask, p):
    """Randomly keep a subset of the True entries of ``mask``.

    The subset size is drawn as ``np.random.geometric(p) + 1`` and capped at
    the number of True entries. NOTE(review): since the geometric draw is
    already >= 1, the effective minimum size is 2 (when available) — the
    call-site comment says "one as minimum"; confirm which is intended.

    Parameters
    ----------
    mask : np.ndarray
        Boolean (or 0/1) mask to subsample.
    p : float
        Success probability of the geometric draw controlling subset size.

    Returns
    -------
    np.ndarray
        Mask of the same shape and dtype with only the sampled entries set.
    """
    available = np.sum(mask)
    # Same RNG call order as before: geometric draw first, then choice.
    sample_size = min(np.random.geometric(p) + 1, available)

    candidate_indices = np.where(mask)[0]
    chosen = np.random.choice(candidate_indices, size=sample_size, replace=False)

    subset_mask = np.zeros_like(mask)
    subset_mask[chosen] = 1
    return subset_mask
480
+
481
+
482
def process_token_features(
    data: Tokenized,
    max_tokens: Optional[int] = None,
    binder_pocket_conditioned_prop: Optional[float] = 0.0,
    binder_pocket_cutoff: Optional[float] = 6.0,
    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
    only_ligand_binder_pocket: Optional[bool] = False,
    inference_binder: Optional[list[int]] = None,
    inference_pocket: Optional[list[tuple[int, int]]] = None,
) -> dict[str, Tensor]:
    """Get the token features.

    Builds per-token index/type tensors, a dense token-bond matrix, and an
    (optionally sampled) pocket-conditioning feature, padding everything to
    `max_tokens` when given.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_tokens : int
        The maximum number of tokens.
    binder_pocket_conditioned_prop : float
        Probability of sampling a pocket-conditioning signal at training time.
    binder_pocket_cutoff : float
        Distance cutoff (presumably Angstrom — confirm) for pocket membership.
    binder_pocket_sampling_geometric_p : float
        If > 0, subsample the pocket with a geometric-sized subset.
    only_ligand_binder_pocket : bool
        If True, never fall back to a protein chain as the binder.
    inference_binder : int, optional
        Chain asym_id to treat as the binder at inference time.
    inference_pocket : list[tuple[int, int]], optional
        (asym_id, res_idx) pairs marking pocket residues at inference time.

    Returns
    -------
    dict[str, Tensor]
        The token features.

    """
    # Token data
    token_data = data.tokens
    token_bonds = data.bonds

    # Token core features
    token_index = torch.arange(len(token_data), dtype=torch.long)
    residue_index = from_numpy(token_data["res_idx"].copy()).long()
    asym_id = from_numpy(token_data["asym_id"].copy()).long()
    entity_id = from_numpy(token_data["entity_id"].copy()).long()
    sym_id = from_numpy(token_data["sym_id"].copy()).long()
    mol_type = from_numpy(token_data["mol_type"].copy()).long()
    res_type = from_numpy(token_data["res_type"].copy()).long()
    res_type = one_hot(res_type, num_classes=const.num_tokens)
    disto_center = from_numpy(token_data["disto_coords"].copy())

    # Token mask features
    pad_mask = torch.ones(len(token_data), dtype=torch.float)
    resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float()
    disto_mask = from_numpy(token_data["disto_mask"].copy()).float()
    cyclic_period = from_numpy(token_data["cyclic_period"].copy())

    # Token bond features: size the square bond matrix up-front so no
    # re-padding is needed later.
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        num_tokens = max_tokens if pad_len > 0 else len(token_data)
    else:
        num_tokens = len(token_data)

    # Bonds reference global token_idx values; remap to crop-local positions.
    tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
    bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
    for token_bond in token_bonds:
        token_1 = tok_to_idx[token_bond["token_1"]]
        token_2 = tok_to_idx[token_bond["token_2"]]
        bonds[token_1, token_2] = 1
        bonds[token_2, token_1] = 1

    bonds = bonds.unsqueeze(-1)

    # Pocket conditioned feature
    pocket_feature = (
        np.zeros(len(token_data)) + const.pocket_contact_info["UNSPECIFIED"]
    )
    if inference_binder is not None:
        # Inference path: binder chain and pocket residues supplied by caller.
        assert inference_pocket is not None
        pocket_residues = set(inference_pocket)
        for idx, token in enumerate(token_data):
            if token["asym_id"] == inference_binder:
                pocket_feature[idx] = const.pocket_contact_info["BINDER"]
            elif (token["asym_id"], token["res_idx"]) in pocket_residues:
                pocket_feature[idx] = const.pocket_contact_info["POCKET"]
            else:
                pocket_feature[idx] = const.pocket_contact_info["UNSELECTED"]
    elif (
        binder_pocket_conditioned_prop > 0.0
        and random.random() < binder_pocket_conditioned_prop
    ):
        # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
        binder_asym_ids = np.unique(
            token_data["asym_id"][
                token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]
        )

        if len(binder_asym_ids) == 0:
            if not only_ligand_binder_pocket:
                binder_asym_ids = np.unique(token_data["asym_id"])

        if len(binder_asym_ids) > 0:
            pocket_asym_id = random.choice(binder_asym_ids)
            binder_mask = token_data["asym_id"] == pocket_asym_id

            binder_coords = []
            for token in token_data:
                if token["asym_id"] == pocket_asym_id:
                    binder_coords.append(
                        data.structure.atoms["coords"][
                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                        ]
                    )
            binder_coords = np.concatenate(binder_coords, axis=0)

            # find the tokens in the pocket: min atom-atom distance from any
            # non-binder, resolved, polymer token to any binder atom.
            token_dist = np.zeros(len(token_data)) + 1000
            for i, token in enumerate(token_data):
                if (
                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and token["asym_id"] != pocket_asym_id
                    and token["resolved_mask"] == 1
                ):
                    token_coords = data.structure.atoms["coords"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]

                    # find chain and apply chain transformation
                    # NOTE(review): `chain` found below is never used — no
                    # transformation is actually applied; looks like an
                    # unfinished/vestigial step. Confirm before removing.
                    for chain in data.structure.chains:
                        if chain["asym_id"] == token["asym_id"]:
                            break

                    token_dist[i] = np.min(
                        np.linalg.norm(
                            token_coords[:, None, :] - binder_coords[None, :, :],
                            axis=-1,
                        )
                    )

            pocket_mask = token_dist < binder_pocket_cutoff

            if np.sum(pocket_mask) > 0:
                pocket_feature = (
                    np.zeros(len(token_data)) + const.pocket_contact_info["UNSELECTED"]
                )
                pocket_feature[binder_mask] = const.pocket_contact_info["BINDER"]

                if binder_pocket_sampling_geometric_p > 0.0:
                    # select a subset of the pocket, according
                    # to a geometric distribution with one as minimum
                    pocket_mask = select_subset_from_mask(
                        pocket_mask, binder_pocket_sampling_geometric_p
                    )

                pocket_feature[pocket_mask] = const.pocket_contact_info["POCKET"]
    pocket_feature = from_numpy(pocket_feature).long()
    pocket_feature = one_hot(pocket_feature, num_classes=len(const.pocket_contact_info))

    # Pad to max tokens if given
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        if pad_len > 0:
            token_index = pad_dim(token_index, 0, pad_len)
            residue_index = pad_dim(residue_index, 0, pad_len)
            asym_id = pad_dim(asym_id, 0, pad_len)
            entity_id = pad_dim(entity_id, 0, pad_len)
            sym_id = pad_dim(sym_id, 0, pad_len)
            mol_type = pad_dim(mol_type, 0, pad_len)
            res_type = pad_dim(res_type, 0, pad_len)
            disto_center = pad_dim(disto_center, 0, pad_len)
            pad_mask = pad_dim(pad_mask, 0, pad_len)
            resolved_mask = pad_dim(resolved_mask, 0, pad_len)
            disto_mask = pad_dim(disto_mask, 0, pad_len)
            pocket_feature = pad_dim(pocket_feature, 0, pad_len)
            cyclic_period = pad_dim(cyclic_period, 0, pad_len)

    token_features = {
        "token_index": token_index,
        "residue_index": residue_index,
        "asym_id": asym_id,
        "entity_id": entity_id,
        "sym_id": sym_id,
        "mol_type": mol_type,
        "res_type": res_type,
        "disto_center": disto_center,
        "token_bonds": bonds,
        "token_pad_mask": pad_mask,
        "token_resolved_mask": resolved_mask,
        "token_disto_mask": disto_mask,
        "pocket_feature": pocket_feature,
        "cyclic_period": cyclic_period,
    }
    return token_features
666
+
667
+
668
def process_atom_features(
    data: Tokenized,
    atoms_per_window_queries: int = 32,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
    max_atoms: Optional[int] = None,
    max_tokens: Optional[int] = None,
) -> dict[str, Tensor]:
    """Get the atom features.

    Gathers each token's atoms into flat arrays, computes the distogram
    target, per-token reference frames, atom/token index maps, and pads the
    atom dimension up to a multiple of `atoms_per_window_queries`.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    atoms_per_window_queries : int
        Window size the padded atom count must be a multiple of.
    min_dist, max_dist : float
        Range of the distogram bin boundaries.
    num_bins : int
        Number of distogram bins.
    max_atoms : int, optional
        The maximum number of atoms (must divide by atoms_per_window_queries).
    max_tokens : int, optional
        The maximum number of tokens (pads token-indexed outputs).

    Returns
    -------
    dict[str, Tensor]
        The atom features.

    """
    # Filter to tokens' atoms
    atom_data = []
    ref_space_uid = []
    coord_data = []
    frame_data = []
    resolved_frame_data = []
    atom_to_token = []
    token_to_rep_atom = []  # index on cropped atom table
    r_set_to_rep_atom = []
    disto_coords = []
    atom_idx = 0

    # NOTE(review): `token_id` from this loop is used after it (one_hot
    # num_classes) — an empty data.tokens would raise NameError; presumably
    # callers guarantee at least one token.
    chain_res_ids = {}
    for token_id, token in enumerate(data.tokens):
        # Get the chain residue ids; (chain, residue) pairs share one
        # ref_space_uid so their reference conformers live in one frame.
        chain_idx, res_id = token["asym_id"], token["res_idx"]
        chain = data.structure.chains[chain_idx]

        if (chain_idx, res_id) not in chain_res_ids:
            new_idx = len(chain_res_ids)
            chain_res_ids[(chain_idx, res_id)] = new_idx
        else:
            new_idx = chain_res_ids[(chain_idx, res_id)]

        # Map atoms to token indices
        ref_space_uid.extend([new_idx] * token["atom_num"])
        atom_to_token.extend([token_id] * token["atom_num"])

        # Add atom data
        start = token["atom_idx"]
        end = token["atom_idx"] + token["atom_num"]
        token_atoms = data.structure.atoms[start:end]

        # Map token to representative atom (crop-local index)
        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
            "resolved_mask"
        ]:
            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)

        # Get token coordinates
        token_coords = np.array([token_atoms["coords"]])
        coord_data.append(token_coords)

        # Get frame data: backbone triplet (protein N/CA/C, nucleic
        # C1'/C3'/C4'), valid only if all three atoms are present.
        res_type = const.tokens[token["res_type"]]

        if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
            res_type in const.ref_atoms
        ):
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("N"),
                const.ref_atoms[res_type].index("CA"),
                const.ref_atoms[res_type].index("C"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        elif (
            token["mol_type"] == const.chain_type_ids["DNA"]
            or token["mol_type"] == const.chain_type_ids["RNA"]
        ) and (res_type in const.ref_atoms):
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("C1'"),
                const.ref_atoms[res_type].index("C3'"),
                const.ref_atoms[res_type].index("C4'"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        else:
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        frame_data.append(
            [idx_frame_a + atom_idx, idx_frame_b + atom_idx, idx_frame_c + atom_idx]
        )
        resolved_frame_data.append(mask_frame)

        # Get distogram coordinates
        disto_coords_tok = data.structure.atoms[token["disto_idx"]]["coords"]
        disto_coords.append(disto_coords_tok)

        # Update atom data. This is technically never used again (we rely on coord_data),
        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
        token_atoms = token_atoms.copy()
        token_atoms["coords"] = token_coords[0]  # atom has a copy of first coords
        atom_data.append(token_atoms)
        atom_idx += len(token_atoms)

    disto_coords = np.array(disto_coords)

    # Compute distogram: bin index = number of boundaries below the distance.
    t_center = torch.Tensor(disto_coords)
    t_dists = torch.cdist(t_center, t_center)
    boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
    distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
    disto_target = one_hot(distogram, num_classes=num_bins)

    atom_data = np.concatenate(atom_data)
    coord_data = np.concatenate(coord_data, axis=1)
    ref_space_uid = np.array(ref_space_uid)

    # Compute features
    ref_atom_name_chars = from_numpy(atom_data["name"]).long()
    ref_element = from_numpy(atom_data["element"]).long()
    ref_charge = from_numpy(atom_data["charge"])
    ref_pos = from_numpy(
        atom_data["conformer"].copy()
    )  # not sure why I need to copy here..
    ref_space_uid = from_numpy(ref_space_uid)
    coords = from_numpy(coord_data.copy())
    resolved_mask = from_numpy(atom_data["is_present"])
    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
    frame_data, resolved_frame_data = compute_frames_nonpolymer(
        data,
        coord_data,
        atom_data["is_present"],
        atom_to_token,
        frame_data,
        resolved_frame_data,
    )  # Compute frames for NONPOLYMER tokens
    frames = from_numpy(frame_data.copy())
    frame_resolved_mask = from_numpy(resolved_frame_data.copy())
    # Convert to one-hot
    ref_atom_name_chars = one_hot(
        ref_atom_name_chars % num_bins, num_classes=num_bins
    )  # added for lower case letters
    ref_element = one_hot(ref_element, num_classes=const.num_elements)
    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
    token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
    r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))

    # Center the ground truth coordinates (mean over resolved atoms only)
    center = (coords * resolved_mask[None, :, None]).sum(dim=1)
    center = center / resolved_mask.sum().clamp(min=1)
    coords = coords - center[:, None]

    # Apply random roto-translation to the input atoms
    ref_pos = center_random_augmentation(
        ref_pos[None], resolved_mask[None], centering=False
    )[0]

    # Compute padding and apply: atom count must be a multiple of
    # atoms_per_window_queries (round up when max_atoms is not given).
    if max_atoms is not None:
        assert max_atoms % atoms_per_window_queries == 0
        pad_len = max_atoms - len(atom_data)
    else:
        pad_len = (
            (len(atom_data) - 1) // atoms_per_window_queries + 1
        ) * atoms_per_window_queries - len(atom_data)

    if pad_len > 0:
        pad_mask = pad_dim(pad_mask, 0, pad_len)
        ref_pos = pad_dim(ref_pos, 0, pad_len)
        resolved_mask = pad_dim(resolved_mask, 0, pad_len)
        ref_element = pad_dim(ref_element, 0, pad_len)
        ref_charge = pad_dim(ref_charge, 0, pad_len)
        ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
        ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
        coords = pad_dim(coords, 1, pad_len)
        atom_to_token = pad_dim(atom_to_token, 0, pad_len)
        token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
        r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)

    # Pad the token dimension of token-indexed outputs.
    if max_tokens is not None:
        pad_len = max_tokens - token_to_rep_atom.shape[0]
        if pad_len > 0:
            atom_to_token = pad_dim(atom_to_token, 1, pad_len)
            token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
            r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
            disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
            frames = pad_dim(frames, 0, pad_len)
            frame_resolved_mask = pad_dim(frame_resolved_mask, 0, pad_len)

    return {
        "ref_pos": ref_pos,
        "atom_resolved_mask": resolved_mask,
        "ref_element": ref_element,
        "ref_charge": ref_charge,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_space_uid": ref_space_uid,
        "coords": coords,
        "atom_pad_mask": pad_mask,
        "atom_to_token": atom_to_token,
        "token_to_rep_atom": token_to_rep_atom,
        "r_set_to_rep_atom": r_set_to_rep_atom,
        "disto_target": disto_target,
        "frames_idx": frames,
        "frame_resolved_mask": frame_resolved_mask,
    }
892
+
893
+
894
def process_msa_features(
    data: Tokenized,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_seqs_batch : int
        Number of MSA rows to build for this sample (randomized at training).
    max_seqs : int
        The maximum number of MSA sequences.
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Created paired MSA
    msa, deletion, paired = construct_paired_msa(data, max_seqs_batch)
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa[:, :, 0])
    profile = msa.float().mean(dim=0)
    has_deletion = deletion > 0
    # Squash raw deletion counts into a bounded value.
    # NOTE(review): AF-style squashing is (2/pi)*arctan(d/3) in [0, 1); this
    # uses pi/2 * arctan(d/3) — presumably intentional here, do not "fix"
    # without checking training checkpoints.
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)

    return {
        "msa": msa,
        "msa_paired": paired,
        "deletion_value": deletion,
        "has_deletion": has_deletion,
        "deletion_mean": deletion_mean,
        "profile": profile,
        "msa_mask": msa_mask,
    }
967
+
968
+
969
def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Get the symmetry features.

    Merges chain-, amino-acid-, and ligand-level symmetry features into a
    single dict (later sources override earlier ones on key collision, as
    with dict.update).

    Parameters
    ----------
    cropped : Tokenized
        The cropped, tokenized data.
    symmetries : dict
        Precomputed ligand symmetry tables.

    Returns
    -------
    dict[str, Tensor]
        The symmetry features.

    """
    return {
        **get_chain_symmetries(cropped),
        **get_amino_acids_symmetries(cropped),
        **get_ligand_symmetries(cropped, symmetries),
    }
990
+
991
+
992
def _constraint_tensor(arr, field: str, dtype) -> Tensor:
    """Copy structured-array column ``arr[field]`` into a tensor of ``dtype``."""
    return torch.tensor(arr[field].copy(), dtype=dtype)


def _constraint_index(arr) -> Tensor:
    """Return the ``atom_idxs`` column of ``arr`` as a (k, n) long tensor."""
    return _constraint_tensor(arr, "atom_idxs", torch.long).T


def process_residue_constraint_features(
    data: Tokenized,
) -> dict[str, Tensor]:
    """Build stereochemistry constraint features from ``data.residue_constraints``.

    Converts RDKit distance bounds, chiral-center, stereo-bond, and planarity
    constraint tables into index tensors (shape (k, n), one column per
    constraint) and per-constraint mask/value tensors. When no constraints
    are attached, returns empty tensors with the correct leading arity
    (2 atoms per bound, 4 per chiral center / stereo bond, 6 per planar bond,
    5/6 per planar ring) so downstream code can still index/concatenate.

    Parameters
    ----------
    data : Tokenized
        The tokenized data; ``residue_constraints`` may be None.

    Returns
    -------
    dict[str, Tensor]
        The residue constraint features.
    """
    residue_constraints = data.residue_constraints
    if residue_constraints is not None:
        rdkit_bounds = residue_constraints.rdkit_bounds_constraints
        chiral_atoms = residue_constraints.chiral_atom_constraints
        stereo_bonds = residue_constraints.stereo_bond_constraints

        rdkit_bounds_index = _constraint_index(rdkit_bounds)
        rdkit_bounds_bond_mask = _constraint_tensor(rdkit_bounds, "is_bond", torch.bool)
        rdkit_bounds_angle_mask = _constraint_tensor(
            rdkit_bounds, "is_angle", torch.bool
        )
        rdkit_upper_bounds = _constraint_tensor(
            rdkit_bounds, "upper_bound", torch.float
        )
        rdkit_lower_bounds = _constraint_tensor(
            rdkit_bounds, "lower_bound", torch.float
        )

        chiral_atom_index = _constraint_index(chiral_atoms)
        chiral_reference_mask = _constraint_tensor(
            chiral_atoms, "is_reference", torch.bool
        )
        chiral_atom_orientations = _constraint_tensor(chiral_atoms, "is_r", torch.bool)

        stereo_bond_index = _constraint_index(stereo_bonds)
        stereo_reference_mask = _constraint_tensor(
            stereo_bonds, "is_reference", torch.bool
        )
        stereo_bond_orientations = _constraint_tensor(stereo_bonds, "is_e", torch.bool)

        planar_bond_index = _constraint_index(
            residue_constraints.planar_bond_constraints
        )
        planar_ring_5_index = _constraint_index(
            residue_constraints.planar_ring_5_constraints
        )
        planar_ring_6_index = _constraint_index(
            residue_constraints.planar_ring_6_constraints
        )
    else:
        # No constraints: empty tensors, keeping the per-constraint arity.
        rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
        rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
        rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
        chiral_atom_index = torch.empty((4, 0), dtype=torch.long)
        chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
        chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
        stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
        stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
        stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
        planar_bond_index = torch.empty((6, 0), dtype=torch.long)
        planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
        planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)

    return {
        "rdkit_bounds_index": rdkit_bounds_index,
        "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
        "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
        "rdkit_upper_bounds": rdkit_upper_bounds,
        "rdkit_lower_bounds": rdkit_lower_bounds,
        "chiral_atom_index": chiral_atom_index,
        "chiral_reference_mask": chiral_reference_mask,
        "chiral_atom_orientations": chiral_atom_orientations,
        "stereo_bond_index": stereo_bond_index,
        "stereo_reference_mask": stereo_reference_mask,
        "stereo_bond_orientations": stereo_bond_orientations,
        "planar_bond_index": planar_bond_index,
        "planar_ring_5_index": planar_ring_5_index,
        "planar_ring_6_index": planar_ring_6_index,
    }
1087
+
1088
+
1089
def process_chain_feature_constraints(
    data: Tokenized,
) -> dict[str, Tensor]:
    """Extract chain-level connectivity and symmetry index features.

    Produces (2, n) long tensors: one pair of chain indices and one pair of
    atom indices per covalent inter-chain connection, plus every unordered
    pair of chains sharing an entity id. Each tensor is (2, 0) when empty.
    """
    structure = data.structure

    # Covalent connections between chains.
    if structure.connections.shape[0] > 0:
        chain_pairs = [
            [conn["chain_1"], conn["chain_2"]] for conn in structure.connections
        ]
        atom_pairs = [
            [conn["atom_1"], conn["atom_2"]] for conn in structure.connections
        ]
        connected_chain_index = torch.tensor(chain_pairs, dtype=torch.long).T
        connected_atom_index = torch.tensor(atom_pairs, dtype=torch.long).T
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    # Chains sharing an entity id are symmetric copies; enumerate each
    # unordered pair (i < j) exactly once.
    sym_pairs = [
        [i, j]
        for i, chain_i in enumerate(structure.chains)
        for j, chain_j in enumerate(structure.chains)
        if i < j and chain_i["entity_id"] == chain_j["entity_id"]
    ]
    if sym_pairs:
        symmetric_chain_index = torch.tensor(sym_pairs, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)

    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }
1120
+
1121
+
1122
class BoltzFeaturizer:
    """Boltz featurizer.

    Orchestrates token, atom, MSA, symmetry, and constraint feature
    extraction for one tokenized sample and merges them into a single dict.
    """

    def process(
        self,
        data: Tokenized,
        training: bool,
        max_seqs: int = 4096,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        symmetries: Optional[dict] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        inference_binder: Optional[int] = None,
        inference_pocket: Optional[list[tuple[int, int]]] = None,
        compute_constraint_features: bool = False,
    ) -> dict[str, Tensor]:
        """Compute features.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        training : bool
            Whether the model is in training mode. When True, the number of
            MSA rows used is randomized in [1, max_seqs] per sample.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms
        max_seqs : int, optional
            The maximum number of sequences.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # Compute random number of sequences (MSA subsampling augmentation)
        if training and max_seqs is not None:
            max_seqs_batch = np.random.randint(1, max_seqs + 1)  # noqa: NPY002
        else:
            max_seqs_batch = max_seqs

        # Compute token features
        token_features = process_token_features(
            data,
            max_tokens,
            binder_pocket_conditioned_prop,
            binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p,
            only_ligand_binder_pocket,
            inference_binder=inference_binder,
            inference_pocket=inference_pocket,
        )

        # Compute atom features
        atom_features = process_atom_features(
            data,
            atoms_per_window_queries,
            min_dist,
            max_dist,
            num_bins,
            max_atoms,
            max_tokens,
        )

        # Compute MSA features
        msa_features = process_msa_features(
            data,
            max_seqs_batch,
            max_seqs,
            max_tokens,
            pad_to_max_seqs,
        )

        # Compute symmetry features (optional)
        symmetry_features = {}
        if compute_symmetries:
            symmetry_features = process_symmetry_features(data, symmetries)

        # Compute constraint features (optional)
        residue_constraint_features = {}
        chain_constraint_features = {}
        if compute_constraint_features:
            residue_constraint_features = process_residue_constraint_features(data)
            chain_constraint_features = process_chain_feature_constraints(data)

        # Later dicts override earlier ones on any key collision.
        return {
            **token_features,
            **atom_features,
            **msa_features,
            **symmetry_features,
            **residue_constraint_features,
            **chain_constraint_features,
        }
protify/FastPLMs/boltz/src/boltz/data/feature/featurizerv2.py ADDED
@@ -0,0 +1,2354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+ from collections import deque
4
+ import numba
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+ import rdkit.Chem.Descriptors
8
+ import torch
9
+ from numba import types
10
+ from rdkit.Chem import Mol
11
+ from scipy.spatial.distance import cdist
12
+ from torch import Tensor, from_numpy
13
+ from torch.nn.functional import one_hot
14
+
15
+ from boltz.data import const
16
+ from boltz.data.mol import (
17
+ get_amino_acids_symmetries,
18
+ get_chain_symmetries,
19
+ get_ligand_symmetries,
20
+ get_symmetries,
21
+ )
22
+ from boltz.data.pad import pad_dim
23
+ from boltz.data.types import (
24
+ MSA,
25
+ MSADeletion,
26
+ MSAResidue,
27
+ MSASequence,
28
+ TemplateInfo,
29
+ Tokenized,
30
+ )
31
+ from boltz.model.modules.utils import center_random_augmentation
32
+
33
+ ####################################################################################################
34
+ # HELPERS
35
+ ####################################################################################################
36
+
37
+
38
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as four fixed-width integer character codes.

    Each character is mapped to ``ord(c) - 32`` (so a space maps to 0) and
    the sequence is right-padded with zeros to length four, for atom names
    of at most four characters.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The converted atom name.

    """
    stripped = str(name).strip()
    codes = [ord(char) - 32 for char in stripped]
    padding = [0] * (4 - len(codes))
    return tuple(codes + padding)
56
+
57
+
58
def sample_d(
    min_d: float,
    max_d: float,
    n_samples: int,
    random: np.random.Generator,
) -> np.ndarray:
    """Generate samples from a 1/d distribution between min_d and max_d.

    Parameters
    ----------
    min_d : float
        Minimum value of d
    max_d : float
        Maximum value of d
    n_samples : int
        Number of samples to generate
    random : numpy.random.Generator
        Random number generator

    Returns
    -------
    numpy.ndarray
        Array of samples drawn from the distribution

    Notes
    -----
    The probability density function is:
    f(d) = 1/(d * ln(max_d/min_d)) for d in [min_d, max_d]

    The inverse CDF transform is:
    d = min_d * (max_d/min_d)**u where u ~ Uniform(0,1)

    """
    # Uniform draws in [0, 1), pushed through the inverse CDF.
    uniform = random.random(n_samples)
    ratio = max_d / min_d
    return min_d * ratio**uniform
95
+
96
+
97
def compute_frames_nonpolymer(
    data: Tokenized,
    coords,
    resolved_mask,
    atom_to_token,
    frame_data: list,
    resolved_frame_data: list,
) -> tuple[list, list]:
    """Get the frames for non-polymer tokens.

    For each non-polymer chain with at least 3 atoms, replaces the
    incoming per-token frames with atom triplets built from nearest
    resolved neighbors, then invalidates frames whose spanning vectors
    are (nearly) collinear or degenerate.

    Parameters
    ----------
    data : Tokenized
        The input data to the model.
    coords : np.ndarray
        Atom coordinates; reshaped to (-1, 3) internally.
    resolved_mask : np.ndarray
        Per-atom resolved flags (1 = resolved).
    atom_to_token : np.ndarray
        Map from atom index to token index.
    frame_data : list
        The frame data (one [a, b, c] atom-index triplet per token).
    resolved_frame_data : list
        The resolved frame data.

    Returns
    -------
    tuple[list, list]
        The frame data and resolved frame data.

    """
    frame_data = np.array(frame_data)
    resolved_frame_data = np.array(resolved_frame_data)
    asym_id_token = data.tokens["asym_id"]
    # Broadcast per-token chain ids onto atoms via the atom->token map.
    asym_id_atom = data.tokens["asym_id"][atom_to_token]
    token_idx = 0
    atom_idx = 0
    for id in np.unique(data.tokens["asym_id"]):
        mask_chain_token = asym_id_token == id
        mask_chain_atom = asym_id_atom == id
        num_tokens = mask_chain_token.sum()
        num_atoms = mask_chain_atom.sum()
        # Only rebuild frames for non-polymer chains with enough atoms
        # to form a triplet; otherwise keep the incoming frames.
        if (
            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
            or num_atoms < 3  # noqa: PLR2004
        ):
            token_idx += num_tokens
            atom_idx += num_atoms
            continue
        # Pairwise Euclidean distances among this chain's atoms.
        dist_mat = (
            (
                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
            )
            ** 2
        ).sum(-1) ** 0.5
        # Pairs involving any unresolved atom get an infinite penalty so
        # the argsort below prefers resolved neighbors.
        resolved_pair = 1 - (
            resolved_mask[mask_chain_atom][None, :]
            * resolved_mask[mask_chain_atom][:, None]
        ).astype(np.float32)
        resolved_pair[resolved_pair == 1] = math.inf
        indices = np.argsort(dist_mat + resolved_pair, axis=1)
        # Frame per atom: [nearest neighbor, the atom itself, second
        # nearest], shifted from chain-local to global atom indices.
        # (Column 0 of the argsort is the atom itself at distance 0.)
        frames = (
            np.concatenate(
                [
                    indices[:, 1:2],
                    indices[:, 0:1],
                    indices[:, 2:3],
                ],
                axis=1,
            )
            + atom_idx
        )
        # NOTE(review): the destination slice length uses num_atoms while
        # indexing token positions — this assumes one token per atom for
        # non-polymer chains (num_tokens == num_atoms). Confirm upstream
        # tokenization guarantees this.
        frame_data[token_idx : token_idx + num_atoms, :] = frames
        # A frame counts as resolved only if all three atoms are resolved.
        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
            frames
        ].all(axis=1)
        token_idx += num_tokens
        atom_idx += num_atoms
    # Gather the actual coordinates for every frame triplet.
    frames_expanded = coords.reshape(-1, 3)[frame_data]

    # Drop frames whose two spanning vectors are collinear or degenerate.
    mask_collinear = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    )
    return frame_data, resolved_frame_data & mask_collinear
177
+
178
+
179
def compute_collinear_mask(v1, v2):
    """Mask out degenerate frames from two batches of spanning vectors.

    A frame is kept when the angle between its two vectors is at least
    ~25 degrees (|cos| < 0.9063) and both vectors are longer than 1e-2.
    Returns a boolean array of shape (N,).
    """
    len1 = np.linalg.norm(v1, axis=1, keepdims=True)
    len2 = np.linalg.norm(v2, axis=1, keepdims=True)
    # Small epsilon keeps the normalization finite for zero vectors.
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    cos_angle = np.sum(unit1 * unit2, axis=1)
    not_collinear = np.abs(cos_angle) < 0.9063
    long_enough1 = len1.reshape(-1) > 1e-2
    long_enough2 = len2.reshape(-1) > 1e-2
    return not_collinear & long_enough1 & long_enough2
188
+
189
+
190
def dummy_msa(residues: np.ndarray) -> MSA:
    """Create a dummy MSA for a chain.

    Builds a single-row MSA containing only the chain's own sequence,
    with no deletions and no taxonomy.

    Parameters
    ----------
    residues : np.ndarray
        The residues for the chain.

    Returns
    -------
    MSA
        The dummy MSA.

    """
    res_types = [residue["res_type"] for residue in residues]
    # One sequence record: (seq_idx, taxonomy, res_start, res_end,
    # del_start, del_end) covering the whole chain.
    single_sequence = [(0, -1, 0, len(res_types), 0, 0)]
    return MSA(
        residues=np.array(res_types, dtype=MSAResidue),
        deletions=np.array([], dtype=MSADeletion),
        sequences=np.array(single_sequence, dtype=MSASequence),
    )
212
+
213
+
214
def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
    data: Tokenized,
    random: np.random.Generator,
    max_seqs: int,
    max_pairs: int = 8192,
    max_total: int = 16384,
    random_subset: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Pair the MSA data across chains by taxonomy.

    Parameters
    ----------
    data : Tokenized
        The input data to the model.
    random : np.random.Generator
        Random number generator.
    max_seqs : int
        Maximum number of MSA rows to keep.
    max_pairs : int, optional
        Maximum number of taxonomy-paired rows.
    max_total : int, optional
        Maximum total number of rows (paired + unpaired).
    random_subset : bool, optional
        If True, randomly subsample rows down to max_seqs (always
        keeping the query row); otherwise truncate deterministically.

    Returns
    -------
    Tensor
        The MSA data.
    Tensor
        The deletion data.
    Tensor
        Mask indicating paired sequences.

    """
    # Get unique chains (ensuring monotonicity in the order)
    assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
    chain_ids = np.unique(data.tokens["asym_id"])

    # Get relevant MSA, and create a dummy for chains without
    msa: dict[int, MSA] = {}
    for chain_id in chain_ids:
        # Get input sequence
        chain = data.structure.chains[chain_id]
        res_start = chain["res_idx"]
        res_end = res_start + chain["res_num"]
        residues = data.structure.residues[res_start:res_end]

        # Check if we have an MSA, and that the
        # first sequence matches the input sequence
        if chain_id in data.msa:
            # Set the MSA
            msa[chain_id] = data.msa[chain_id]

            # Run length and residue type checks
            first = data.msa[chain_id].sequences[0]
            first_start = first["res_start"]
            first_end = first["res_end"]
            msa_residues = data.msa[chain_id].residues
            first_residues = msa_residues[first_start:first_end]

            warning = "Warning: MSA does not match input sequence, creating dummy."
            if len(residues) == len(first_residues):
                # If there is a mismatch, check if it is between MET & UNK
                # If so, replace the first sequence with the input sequence.
                # Otherwise, replace with a dummy MSA for this chain.
                mismatches = residues["res_type"] != first_residues["res_type"]
                if mismatches.sum().item():
                    idx = np.where(mismatches)[0]
                    is_met = residues["res_type"][idx] == const.token_ids["MET"]
                    is_unk = residues["res_type"][idx] == const.token_ids["UNK"]
                    is_msa_unk = (
                        first_residues["res_type"][idx] == const.token_ids["UNK"]
                    )
                    if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk):
                        msa_residues[first_start:first_end]["res_type"] = residues[
                            "res_type"
                        ]
                    else:
                        print(
                            warning,
                            "1",
                            residues["res_type"],
                            first_residues["res_type"],
                            data.record.id,
                        )
                        msa[chain_id] = dummy_msa(residues)
            else:
                print(
                    warning,
                    "2",
                    residues["res_type"],
                    first_residues["res_type"],
                    data.record.id,
                )
                msa[chain_id] = dummy_msa(residues)
        else:
            msa[chain_id] = dummy_msa(residues)

    # Map taxonomies to (chain_id, seq_idx)
    taxonomy_map: dict[str, list] = {}
    for chain_id, chain_msa in msa.items():
        sequences = chain_msa.sequences
        sequences = sequences[sequences["taxonomy"] != -1]
        for sequence in sequences:
            seq_idx = sequence["seq_idx"]
            taxon = sequence["taxonomy"]
            taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))

    # Remove taxonomies with only one sequence and sort by the
    # number of chain_id present in each of the taxonomies
    taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
    taxonomy_map = sorted(
        taxonomy_map.items(),
        key=lambda x: len({c for c, _ in x[1]}),
        reverse=True,
    )

    # Keep track of the sequences available per chain, keeping the original
    # order of the sequences in the MSA to favor the best matching sequences.
    # BUGFIX: visited must hold (chain_id, seq_idx) pairs. The previous
    # comprehension iterated `for c, items ... for s in items`, producing
    # (taxonomy, (chain_id, seq_idx)) tuples, so the `(c, i) not in visited`
    # test below never matched and paired sequences were re-used as unpaired.
    visited = {pair for _, items in taxonomy_map for pair in items}
    available = {}
    for c in chain_ids:
        available[c] = deque(
            i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
        )

    # Create sequence pairs
    is_paired = []
    pairing = []

    # Start with the first sequence for each chain
    is_paired.append({c: 1 for c in chain_ids})
    pairing.append({c: 0 for c in chain_ids})

    # Then add up to 8191 paired rows
    for _, pairs in taxonomy_map:
        # Group occurences by chain_id in case we have multiple
        # sequences from the same chain and same taxonomy
        chain_occurences = {}
        for chain_id, seq_idx in pairs:
            chain_occurences.setdefault(chain_id, []).append(seq_idx)

        # We create as many pairings as the maximum number of occurences
        max_occurences = max(len(v) for v in chain_occurences.values())
        for i in range(max_occurences):
            row_pairing = {}
            row_is_paired = {}

            # Add the chains present in the taxonomy
            for chain_id, seq_idxs in chain_occurences.items():
                # Roll over the sequence index to maximize diversity
                idx = i % len(seq_idxs)
                seq_idx = seq_idxs[idx]

                # Add the sequence to the pairing
                row_pairing[chain_id] = seq_idx
                row_is_paired[chain_id] = 1

            # Add any missing chains
            for chain_id in chain_ids:
                if chain_id not in row_pairing:
                    row_is_paired[chain_id] = 0
                    if available[chain_id]:
                        # Add the next available sequence
                        row_pairing[chain_id] = available[chain_id].popleft()
                    else:
                        # No more sequences available, we place a gap
                        row_pairing[chain_id] = -1

            pairing.append(row_pairing)
            is_paired.append(row_is_paired)

            # Break if we have enough pairs
            if len(pairing) >= max_pairs:
                break

        # Break if we have enough pairs
        if len(pairing) >= max_pairs:
            break

    # Now add up to 16384 unpaired rows total
    max_left = max(len(v) for v in available.values())
    for _ in range(min(max_total - len(pairing), max_left)):
        row_pairing = {}
        row_is_paired = {}
        for chain_id in chain_ids:
            row_is_paired[chain_id] = 0
            if available[chain_id]:
                # Add the next available sequence
                row_pairing[chain_id] = available[chain_id].popleft()
            else:
                # No more sequences available, we place a gap
                row_pairing[chain_id] = -1

        pairing.append(row_pairing)
        is_paired.append(row_is_paired)

        # Break if we have enough sequences
        if len(pairing) >= max_total:
            break

    # Randomly sample a subset of the pairs
    # ensuring the first row is always present
    if random_subset:
        num_seqs = len(pairing)
        if num_seqs > max_seqs:
            indices = random.choice(
                np.arange(1, num_seqs), size=max_seqs - 1, replace=False
            )  # noqa: NPY002
            pairing = [pairing[0]] + [pairing[i] for i in indices]
            is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
    else:
        # Deterministic downsample to max_seqs
        pairing = pairing[:max_seqs]
        is_paired = is_paired[:max_seqs]

    # Map (chain_id, seq_idx, res_idx) to deletion
    deletions = numba.typed.Dict.empty(
        key_type=numba.types.Tuple(
            [numba.types.int64, numba.types.int64, numba.types.int64]),
        value_type=numba.types.int64
    )
    for chain_id, chain_msa in msa.items():
        chain_deletions = chain_msa.deletions
        for sequence in chain_msa.sequences:
            seq_idx = sequence["seq_idx"]
            del_start = sequence["del_start"]
            del_end = sequence["del_end"]
            # BUGFIX: slice the full per-chain deletion table for every
            # sequence. The previous code rebound `chain_deletions` to the
            # slice, so later sequences indexed into an already-sliced
            # array and read the wrong (or no) deletions.
            seq_deletions = chain_deletions[del_start:del_end]
            for deletion_data in seq_deletions:
                res_idx = deletion_data["res_idx"]
                deletion_values = deletion_data["deletion"]
                deletions[(chain_id, seq_idx, res_idx)] = deletion_values

    # Add all the token MSA data
    msa_data, del_data, paired_data = prepare_msa_arrays(
        data.tokens, pairing, is_paired, deletions, msa
    )

    msa_data = torch.tensor(msa_data, dtype=torch.long)
    del_data = torch.tensor(del_data, dtype=torch.float)
    paired_data = torch.tensor(paired_data, dtype=torch.float)

    return msa_data, del_data, paired_data
449
+
450
+
451
def prepare_msa_arrays(
    tokens,
    pairing: list[dict[int, int]],
    is_paired: list[dict[int, int]],
    deletions: dict[tuple[int, int, int], int],
    msa: dict[int, MSA],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Reshape data to play nicely with numba jit.

    Packs the per-chain MSA dictionaries into dense int64 arrays
    (padded with -1 where chains have fewer sequences/residues than the
    longest chain) and dispatches to the jitted inner kernel.

    Parameters
    ----------
    tokens
        Structured token records with at least "asym_id" and "res_idx".
    pairing : list[dict[int, int]]
        Per-row mapping chain_id -> sequence index (-1 for a gap).
    is_paired : list[dict[int, int]]
        Per-row mapping chain_id -> 1 if taxonomy-paired, else 0.
    deletions : dict[tuple[int, int, int], int]
        (chain_id, seq_idx, res_idx) -> deletion count (numba typed dict).
    msa : dict[int, MSA]
        Per-chain MSA data.

    Returns
    -------
    tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]
        Per-token MSA residue types, deletion counts, and paired flags,
        each of shape (n_tokens, n_rows).

    """
    token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
    token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)

    chain_ids = sorted(msa.keys())

    # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
    # This allows us to look up a chain_id by it's index in the chain_ids list.
    chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
    token_asym_ids_idx_arr = np.array(
        [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
    )

    # Densify the row dictionaries into (n_rows, n_chains) arrays.
    pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
    is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)

    for i, row_pairing in enumerate(pairing):
        for chain_id in chain_ids:
            pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]

    for i, row_is_paired in enumerate(is_paired):
        for chain_id in chain_ids:
            is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]

    max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)

    # we want res_start from sequences
    msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        for i, seq in enumerate(msa[chain_id].sequences):
            msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]

    # Dense residue table, padded with -1 past each chain's length.
    max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
    msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        residues = msa[chain_id].residues.astype(np.int64)
        idxs = np.arange(len(residues))
        chain_idx = chain_id_to_idx[chain_id]
        msa_residues[chain_idx, idxs] = residues

    return _prepare_msa_arrays_inner(
        token_asym_ids_arr,
        token_res_idxs_arr,
        token_asym_ids_idx_arr,
        pairing_arr,
        is_paired_arr,
        deletions,
        msa_sequences,
        msa_residues,
        const.token_ids["-"],
    )
509
+
510
+
511
+ deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
512
+
513
+
514
@numba.njit(
    [
        types.Tuple(
            (
                types.int64[:, ::1],  # msa_data
                types.int64[:, ::1],  # del_data
                types.int64[:, ::1],  # paired_data
            )
        )(
            types.int64[::1],  # token_asym_ids
            types.int64[::1],  # token_res_idxs
            types.int64[::1],  # token_asym_ids_idx
            types.int64[:, ::1],  # pairing
            types.int64[:, ::1],  # is_paired
            deletions_dict_type,  # deletions
            types.int64[:, ::1],  # msa_sequences
            types.int64[:, ::1],  # msa_residues
            types.int64,  # gap_token
        )
    ],
    cache=True,
)
def _prepare_msa_arrays_inner(
    token_asym_ids: npt.NDArray[np.int64],
    token_res_idxs: npt.NDArray[np.int64],
    token_asym_ids_idx: npt.NDArray[np.int64],
    pairing: npt.NDArray[np.int64],
    is_paired: npt.NDArray[np.int64],
    deletions: dict[tuple[int, int, int], int],
    msa_sequences: npt.NDArray[np.int64],
    msa_residues: npt.NDArray[np.int64],
    gap_token: int,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Fill per-token MSA, deletion, and paired arrays (jitted kernel).

    For each (token, row) cell, looks up the row's sequence for the
    token's chain and copies the residue type and deletion count.
    Cells whose row has no sequence for the chain (seq_idx == -1) keep
    the gap token and zero deletion.
    """
    n_tokens = len(token_asym_ids)
    n_pairs = len(pairing)
    # Default: gap residue, unpaired, no deletion.
    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)

    # Add all the token MSA data
    for token_idx in range(n_tokens):
        chain_id_idx = token_asym_ids_idx[token_idx]
        chain_id = token_asym_ids[token_idx]
        res_idx = token_res_idxs[token_idx]

        for pair_idx in range(n_pairs):
            seq_idx = pairing[pair_idx, chain_id_idx]
            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]

            # Add residue type
            if seq_idx != -1:
                # res_start locates this sequence in the chain's dense
                # residue table; offset by the token's residue index.
                res_start = msa_sequences[chain_id_idx, seq_idx]
                res_type = msa_residues[chain_id_idx, res_start + res_idx]
                k = (chain_id, seq_idx, res_idx)
                if k in deletions:
                    del_data[token_idx, pair_idx] = deletions[k]
                msa_data[token_idx, pair_idx] = res_type

    return msa_data, del_data, paired_data
573
+
574
+
575
+ ####################################################################################################
576
+ # FEATURES
577
+ ####################################################################################################
578
+
579
+
580
def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray:
    """Randomly keep a geometric-sized subset of the mask's true entries.

    The subset size is ``random.geometric(p) + 1`` (so at least two when
    enough entries exist), capped at the number of true entries. Returns
    a new mask of the same shape with only the kept entries set.
    """
    candidates = np.where(mask)[0]
    # Draw the subset size first to keep the RNG call order stable.
    sample_size = min(random.geometric(p) + 1, len(candidates))
    kept = random.choice(candidates, size=sample_size, replace=False)

    subset = np.zeros_like(mask)
    subset[kept] = 1
    return subset
594
+
595
+
596
def get_range_bin(value: float, range_dict: dict[tuple[float, float], int], default=0):
    """Return the bin index whose half-open interval [low, high) contains value.

    Keys equal to the string "other" are skipped; when no interval
    matches, ``default`` is returned.
    """
    target = float(value)
    for bounds, bin_idx in range_dict.items():
        if bounds == "other":
            continue
        low, high = bounds
        if low <= target < high:
            return bin_idx
    return default
606
+
607
+
608
+ def process_token_features( # noqa: C901, PLR0915, PLR0912
609
+ data: Tokenized,
610
+ random: np.random.Generator,
611
+ max_tokens: Optional[int] = None,
612
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
613
+ contact_conditioned_prop: Optional[float] = 0.0,
614
+ binder_pocket_cutoff_min: Optional[float] = 4.0,
615
+ binder_pocket_cutoff_max: Optional[float] = 20.0,
616
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
617
+ only_ligand_binder_pocket: Optional[bool] = False,
618
+ only_pp_contact: Optional[bool] = False,
619
+ inference_pocket_constraints: Optional[
620
+ list[tuple[int, list[tuple[int, int]], float]]
621
+ ] = False,
622
+ inference_contact_constraints: Optional[
623
+ list[tuple[tuple[int, int], tuple[int, int], float]]
624
+ ] = False,
625
+ override_method: Optional[str] = None,
626
+ ) -> dict[str, Tensor]:
627
+ """Get the token features.
628
+
629
+ Parameters
630
+ ----------
631
+ data : Tokenized
632
+ The input data to the model.
633
+ max_tokens : int
634
+ The maximum number of tokens.
635
+
636
+ Returns
637
+ -------
638
+ dict[str, Tensor]
639
+ The token features.
640
+
641
+ """
642
+ # Token data
643
+ token_data = data.tokens
644
+ token_bonds = data.bonds
645
+
646
+ # Token core features
647
+ token_index = torch.arange(len(token_data), dtype=torch.long)
648
+ residue_index = from_numpy(token_data["res_idx"]).long()
649
+ asym_id = from_numpy(token_data["asym_id"]).long()
650
+ entity_id = from_numpy(token_data["entity_id"]).long()
651
+ sym_id = from_numpy(token_data["sym_id"]).long()
652
+ mol_type = from_numpy(token_data["mol_type"]).long()
653
+ res_type = from_numpy(token_data["res_type"]).long()
654
+ res_type = one_hot(res_type, num_classes=const.num_tokens)
655
+ disto_center = from_numpy(token_data["disto_coords"])
656
+ modified = from_numpy(token_data["modified"]).long() # float()
657
+ cyclic_period = from_numpy(token_data["cyclic_period"].copy())
658
+ affinity_mask = from_numpy(token_data["affinity_mask"]).float()
659
+
660
+ ## Conditioning features ##
661
+ method = (
662
+ np.zeros(len(token_data))
663
+ + const.method_types_ids[
664
+ (
665
+ "x-ray diffraction"
666
+ if override_method is None
667
+ else override_method.lower()
668
+ )
669
+ ]
670
+ )
671
+ if data.record is not None:
672
+ if (
673
+ override_method is None
674
+ and data.record.structure.method is not None
675
+ and data.record.structure.method.lower() in const.method_types_ids
676
+ ):
677
+ method = (method * 0) + const.method_types_ids[
678
+ data.record.structure.method.lower()
679
+ ]
680
+
681
+ method_feature = from_numpy(method).long()
682
+
683
+ # Token mask features
684
+ pad_mask = torch.ones(len(token_data), dtype=torch.float)
685
+ resolved_mask = from_numpy(token_data["resolved_mask"]).float()
686
+ disto_mask = from_numpy(token_data["disto_mask"]).float()
687
+
688
+ # Token bond features
689
+ if max_tokens is not None:
690
+ pad_len = max_tokens - len(token_data)
691
+ num_tokens = max_tokens if pad_len > 0 else len(token_data)
692
+ else:
693
+ num_tokens = len(token_data)
694
+
695
+ tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
696
+ bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
697
+ bonds_type = torch.zeros(num_tokens, num_tokens, dtype=torch.long)
698
+ for token_bond in token_bonds:
699
+ token_1 = tok_to_idx[token_bond["token_1"]]
700
+ token_2 = tok_to_idx[token_bond["token_2"]]
701
+ bonds[token_1, token_2] = 1
702
+ bonds[token_2, token_1] = 1
703
+ bond_type = token_bond["type"]
704
+ bonds_type[token_1, token_2] = bond_type
705
+ bonds_type[token_2, token_1] = bond_type
706
+
707
+ bonds = bonds.unsqueeze(-1)
708
+
709
+ # Pocket conditioned feature
710
+ contact_conditioning = (
711
+ np.zeros((len(token_data), len(token_data)))
712
+ + const.contact_conditioning_info["UNSELECTED"]
713
+ )
714
+ contact_threshold = np.zeros((len(token_data), len(token_data)))
715
+
716
+ if inference_pocket_constraints is not None:
717
+ for binder, contacts, max_distance, force in inference_pocket_constraints:
718
+ binder_mask = token_data["asym_id"] == binder
719
+
720
+ for idx, token in enumerate(token_data):
721
+ if (
722
+ token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
723
+ and (token["asym_id"], token["res_idx"]) in contacts
724
+ ) or (
725
+ token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
726
+ and (token["asym_id"], token["atom_idx"]) in contacts
727
+ ):
728
+ contact_conditioning[binder_mask, idx] = (
729
+ const.contact_conditioning_info["BINDER>POCKET"]
730
+ )
731
+ contact_conditioning[idx, binder_mask] = (
732
+ const.contact_conditioning_info["POCKET>BINDER"]
733
+ )
734
+ contact_threshold[binder_mask, idx] = max_distance
735
+ contact_threshold[idx, binder_mask] = max_distance
736
+
737
+ if inference_contact_constraints is not None:
738
+ for token1, token2, max_distance, force in inference_contact_constraints:
739
+ for idx1, _token1 in enumerate(token_data):
740
+ if (
741
+ _token1["mol_type"] != const.chain_type_ids["NONPOLYMER"]
742
+ and (_token1["asym_id"], _token1["res_idx"]) == token1
743
+ ) or (
744
+ _token1["mol_type"] == const.chain_type_ids["NONPOLYMER"]
745
+ and (_token1["asym_id"], _token1["atom_idx"]) == token1
746
+ ):
747
+ for idx2, _token2 in enumerate(token_data):
748
+ if (
749
+ _token2["mol_type"] != const.chain_type_ids["NONPOLYMER"]
750
+ and (_token2["asym_id"], _token2["res_idx"]) == token2
751
+ ) or (
752
+ _token2["mol_type"] == const.chain_type_ids["NONPOLYMER"]
753
+ and (_token2["asym_id"], _token2["atom_idx"]) == token2
754
+ ):
755
+ contact_conditioning[idx1, idx2] = (
756
+ const.contact_conditioning_info["CONTACT"]
757
+ )
758
+ contact_conditioning[idx2, idx1] = (
759
+ const.contact_conditioning_info["CONTACT"]
760
+ )
761
+ contact_threshold[idx1, idx2] = max_distance
762
+ contact_threshold[idx2, idx1] = max_distance
763
+ break
764
+ break
765
+
766
+ if binder_pocket_conditioned_prop > 0.0:
767
+ # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
768
+ binder_asym_ids = np.unique(
769
+ token_data["asym_id"][
770
+ token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
771
+ ]
772
+ )
773
+
774
+ if len(binder_asym_ids) == 0:
775
+ if not only_ligand_binder_pocket:
776
+ binder_asym_ids = np.unique(token_data["asym_id"])
777
+
778
+ while random.random() < binder_pocket_conditioned_prop:
779
+ if len(binder_asym_ids) == 0:
780
+ break
781
+
782
+ pocket_asym_id = random.choice(binder_asym_ids)
783
+ binder_asym_ids = binder_asym_ids[binder_asym_ids != pocket_asym_id]
784
+
785
+ binder_pocket_cutoff = sample_d(
786
+ min_d=binder_pocket_cutoff_min,
787
+ max_d=binder_pocket_cutoff_max,
788
+ n_samples=1,
789
+ random=random,
790
+ )
791
+
792
+ binder_mask = token_data["asym_id"] == pocket_asym_id
793
+
794
+ binder_coords = []
795
+ for token in token_data:
796
+ if token["asym_id"] == pocket_asym_id:
797
+ _coords = data.structure.atoms["coords"][
798
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
799
+ ]
800
+ _is_present = data.structure.atoms["is_present"][
801
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
802
+ ]
803
+ binder_coords.append(_coords[_is_present])
804
+ binder_coords = np.concatenate(binder_coords, axis=0)
805
+
806
+ # find the tokens in the pocket
807
+ token_dist = np.zeros(len(token_data)) + 1000
808
+ for i, token in enumerate(token_data):
809
+ if (
810
+ token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
811
+ and token["asym_id"] != pocket_asym_id
812
+ and token["resolved_mask"] == 1
813
+ ):
814
+ token_coords = data.structure.atoms["coords"][
815
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
816
+ ]
817
+ token_is_present = data.structure.atoms["is_present"][
818
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
819
+ ]
820
+ token_coords = token_coords[token_is_present]
821
+
822
+ # find chain and apply chain transformation
823
+ for chain in data.structure.chains:
824
+ if chain["asym_id"] == token["asym_id"]:
825
+ break
826
+
827
+ token_dist[i] = np.min(
828
+ np.linalg.norm(
829
+ token_coords[:, None, :] - binder_coords[None, :, :],
830
+ axis=-1,
831
+ )
832
+ )
833
+
834
+ pocket_mask = token_dist < binder_pocket_cutoff
835
+
836
+ if np.sum(pocket_mask) > 0:
837
+ if binder_pocket_sampling_geometric_p > 0.0:
838
+ # select a subset of the pocket, according
839
+ # to a geometric distribution with one as minimum
840
+ pocket_mask = select_subset_from_mask(
841
+ pocket_mask,
842
+ binder_pocket_sampling_geometric_p,
843
+ random,
844
+ )
845
+
846
+ contact_conditioning[np.ix_(binder_mask, pocket_mask)] = (
847
+ const.contact_conditioning_info["BINDER>POCKET"]
848
+ )
849
+ contact_conditioning[np.ix_(pocket_mask, binder_mask)] = (
850
+ const.contact_conditioning_info["POCKET>BINDER"]
851
+ )
852
+ contact_threshold[np.ix_(binder_mask, pocket_mask)] = (
853
+ binder_pocket_cutoff
854
+ )
855
+ contact_threshold[np.ix_(pocket_mask, binder_mask)] = (
856
+ binder_pocket_cutoff
857
+ )
858
+
859
+ # Contact conditioning feature
860
+ if contact_conditioned_prop > 0.0:
861
+ while random.random() < contact_conditioned_prop:
862
+ contact_cutoff = sample_d(
863
+ min_d=binder_pocket_cutoff_min,
864
+ max_d=binder_pocket_cutoff_max,
865
+ n_samples=1,
866
+ random=random,
867
+ )
868
+ if only_pp_contact:
869
+ chain_asym_ids = np.unique(
870
+ token_data["asym_id"][
871
+ token_data["mol_type"] == const.chain_type_ids["PROTEIN"]
872
+ ]
873
+ )
874
+ else:
875
+ chain_asym_ids = np.unique(token_data["asym_id"])
876
+
877
+ if len(chain_asym_ids) > 1:
878
+ chain_asym_id = random.choice(chain_asym_ids)
879
+
880
+ chain_coords = []
881
+ for token in token_data:
882
+ if token["asym_id"] == chain_asym_id:
883
+ _coords = data.structure.atoms["coords"][
884
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
885
+ ]
886
+ _is_present = data.structure.atoms["is_present"][
887
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
888
+ ]
889
+ chain_coords.append(_coords[_is_present])
890
+ chain_coords = np.concatenate(chain_coords, axis=0)
891
+
892
+ # find contacts in other chains
893
+ possible_other_chains = []
894
+ for other_chain_id in chain_asym_ids[chain_asym_ids != chain_asym_id]:
895
+ for token in token_data:
896
+ if token["asym_id"] == other_chain_id:
897
+ _coords = data.structure.atoms["coords"][
898
+ token["atom_idx"] : token["atom_idx"]
899
+ + token["atom_num"]
900
+ ]
901
+ _is_present = data.structure.atoms["is_present"][
902
+ token["atom_idx"] : token["atom_idx"]
903
+ + token["atom_num"]
904
+ ]
905
+ if _is_present.sum() == 0:
906
+ continue
907
+ token_coords = _coords[_is_present]
908
+
909
+ # check minimum distance
910
+ if (
911
+ np.min(cdist(chain_coords, token_coords))
912
+ < contact_cutoff
913
+ ):
914
+ possible_other_chains.append(other_chain_id)
915
+ break
916
+
917
+ if len(possible_other_chains) > 0:
918
+ other_chain_id = random.choice(possible_other_chains)
919
+
920
+ pairs = []
921
+ for token_1 in token_data:
922
+ if token_1["asym_id"] == chain_asym_id:
923
+ _coords = data.structure.atoms["coords"][
924
+ token_1["atom_idx"] : token_1["atom_idx"]
925
+ + token_1["atom_num"]
926
+ ]
927
+ _is_present = data.structure.atoms["is_present"][
928
+ token_1["atom_idx"] : token_1["atom_idx"]
929
+ + token_1["atom_num"]
930
+ ]
931
+ if _is_present.sum() == 0:
932
+ continue
933
+ token_1_coords = _coords[_is_present]
934
+
935
+ for token_2 in token_data:
936
+ if token_2["asym_id"] == other_chain_id:
937
+ _coords = data.structure.atoms["coords"][
938
+ token_2["atom_idx"] : token_2["atom_idx"]
939
+ + token_2["atom_num"]
940
+ ]
941
+ _is_present = data.structure.atoms["is_present"][
942
+ token_2["atom_idx"] : token_2["atom_idx"]
943
+ + token_2["atom_num"]
944
+ ]
945
+ if _is_present.sum() == 0:
946
+ continue
947
+ token_2_coords = _coords[_is_present]
948
+
949
+ if (
950
+ np.min(cdist(token_1_coords, token_2_coords))
951
+ < contact_cutoff
952
+ ):
953
+ pairs.append(
954
+ (token_1["token_idx"], token_2["token_idx"])
955
+ )
956
+
957
+ assert len(pairs) > 0
958
+
959
+ pair = random.choice(pairs)
960
+ token_1_mask = token_data["token_idx"] == pair[0]
961
+ token_2_mask = token_data["token_idx"] == pair[1]
962
+
963
+ contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
964
+ const.contact_conditioning_info["CONTACT"]
965
+ )
966
+ contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
967
+ const.contact_conditioning_info["CONTACT"]
968
+ )
969
+
970
+ elif not only_pp_contact:
971
+ # only one chain, find contacts within the chain with minimum residue distance
972
+ pairs = []
973
+ for token_1 in token_data:
974
+ _coords = data.structure.atoms["coords"][
975
+ token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
976
+ ]
977
+ _is_present = data.structure.atoms["is_present"][
978
+ token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
979
+ ]
980
+ if _is_present.sum() == 0:
981
+ continue
982
+ token_1_coords = _coords[_is_present]
983
+
984
+ for token_2 in token_data:
985
+ if np.abs(token_1["res_idx"] - token_2["res_idx"]) <= 8:
986
+ continue
987
+
988
+ _coords = data.structure.atoms["coords"][
989
+ token_2["atom_idx"] : token_2["atom_idx"]
990
+ + token_2["atom_num"]
991
+ ]
992
+ _is_present = data.structure.atoms["is_present"][
993
+ token_2["atom_idx"] : token_2["atom_idx"]
994
+ + token_2["atom_num"]
995
+ ]
996
+ if _is_present.sum() == 0:
997
+ continue
998
+ token_2_coords = _coords[_is_present]
999
+
1000
+ if (
1001
+ np.min(cdist(token_1_coords, token_2_coords))
1002
+ < contact_cutoff
1003
+ ):
1004
+ pairs.append((token_1["token_idx"], token_2["token_idx"]))
1005
+
1006
+ if len(pairs) > 0:
1007
+ pair = random.choice(pairs)
1008
+ token_1_mask = token_data["token_idx"] == pair[0]
1009
+ token_2_mask = token_data["token_idx"] == pair[1]
1010
+
1011
+ contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
1012
+ const.contact_conditioning_info["CONTACT"]
1013
+ )
1014
+ contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
1015
+ const.contact_conditioning_info["CONTACT"]
1016
+ )
1017
+
1018
+ if np.all(contact_conditioning == const.contact_conditioning_info["UNSELECTED"]):
1019
+ contact_conditioning = (
1020
+ contact_conditioning
1021
+ - const.contact_conditioning_info["UNSELECTED"]
1022
+ + const.contact_conditioning_info["UNSPECIFIED"]
1023
+ )
1024
+ contact_conditioning = from_numpy(contact_conditioning).long()
1025
+ contact_conditioning = one_hot(
1026
+ contact_conditioning, num_classes=len(const.contact_conditioning_info)
1027
+ )
1028
+ contact_threshold = from_numpy(contact_threshold).float()
1029
+
1030
+ # compute cyclic polymer mask
1031
+ cyclic_ids = {}
1032
+ for idx_chain, asym_id_iter in enumerate(data.structure.chains["asym_id"]):
1033
+ for connection in data.structure.bonds:
1034
+ if (
1035
+ idx_chain == connection["chain_1"] == connection["chain_2"]
1036
+ and data.structure.chains[connection["chain_1"]]["res_num"] > 2
1037
+ and connection["res_1"]
1038
+ != connection["res_2"] # Avoid same residue bonds!
1039
+ ):
1040
+ if (
1041
+ data.structure.chains[connection["chain_1"]]["res_num"]
1042
+ == (connection["res_2"] + 1)
1043
+ and connection["res_1"] == 0
1044
+ ) or (
1045
+ data.structure.chains[connection["chain_1"]]["res_num"]
1046
+ == (connection["res_1"] + 1)
1047
+ and connection["res_2"] == 0
1048
+ ):
1049
+ cyclic_ids[asym_id_iter] = data.structure.chains[
1050
+ connection["chain_1"]
1051
+ ]["res_num"]
1052
+ cyclic = from_numpy(
1053
+ np.array(
1054
+ [
1055
+ (cyclic_ids[asym_id_iter] if asym_id_iter in cyclic_ids else 0)
1056
+ for asym_id_iter in token_data["asym_id"]
1057
+ ]
1058
+ )
1059
+ ).float()
1060
+
1061
+ # cyclic period is either computed from the bonds or given as input flag
1062
+ cyclic_period = torch.maximum(cyclic, cyclic_period)
1063
+
1064
+ # Pad to max tokens if given
1065
+ if max_tokens is not None:
1066
+ pad_len = max_tokens - len(token_data)
1067
+ if pad_len > 0:
1068
+ token_index = pad_dim(token_index, 0, pad_len)
1069
+ residue_index = pad_dim(residue_index, 0, pad_len)
1070
+ asym_id = pad_dim(asym_id, 0, pad_len)
1071
+ entity_id = pad_dim(entity_id, 0, pad_len)
1072
+ sym_id = pad_dim(sym_id, 0, pad_len)
1073
+ mol_type = pad_dim(mol_type, 0, pad_len)
1074
+ res_type = pad_dim(res_type, 0, pad_len)
1075
+ disto_center = pad_dim(disto_center, 0, pad_len)
1076
+ pad_mask = pad_dim(pad_mask, 0, pad_len)
1077
+ resolved_mask = pad_dim(resolved_mask, 0, pad_len)
1078
+ disto_mask = pad_dim(disto_mask, 0, pad_len)
1079
+ contact_conditioning = pad_dim(contact_conditioning, 0, pad_len)
1080
+ contact_conditioning = pad_dim(contact_conditioning, 1, pad_len)
1081
+ contact_threshold = pad_dim(contact_threshold, 0, pad_len)
1082
+ contact_threshold = pad_dim(contact_threshold, 1, pad_len)
1083
+ method_feature = pad_dim(method_feature, 0, pad_len)
1084
+ modified = pad_dim(modified, 0, pad_len)
1085
+ cyclic_period = pad_dim(cyclic_period, 0, pad_len)
1086
+ affinity_mask = pad_dim(affinity_mask, 0, pad_len)
1087
+
1088
+ token_features = {
1089
+ "token_index": token_index,
1090
+ "residue_index": residue_index,
1091
+ "asym_id": asym_id,
1092
+ "entity_id": entity_id,
1093
+ "sym_id": sym_id,
1094
+ "mol_type": mol_type,
1095
+ "res_type": res_type,
1096
+ "disto_center": disto_center,
1097
+ "token_bonds": bonds,
1098
+ "type_bonds": bonds_type,
1099
+ "token_pad_mask": pad_mask,
1100
+ "token_resolved_mask": resolved_mask,
1101
+ "token_disto_mask": disto_mask,
1102
+ "contact_conditioning": contact_conditioning,
1103
+ "contact_threshold": contact_threshold,
1104
+ "method_feature": method_feature,
1105
+ "modified": modified,
1106
+ "cyclic_period": cyclic_period,
1107
+ "affinity_token_mask": affinity_mask,
1108
+ }
1109
+
1110
+ return token_features
1111
+
1112
+
1113
def process_atom_features(
    data: Tokenized,
    random: np.random.Generator,
    ensemble_features: dict,
    molecules: dict[str, Mol],
    atoms_per_window_queries: int = 32,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
    max_atoms: Optional[int] = None,
    max_tokens: Optional[int] = None,
    disto_use_ensemble: Optional[bool] = False,
    override_bfactor: bool = False,
    compute_frames: bool = False,
    override_coords: Optional[Tensor] = None,
    bfactor_md_correction: bool = False,
) -> dict[str, Tensor]:
    """Get the atom features.

    Builds per-atom reference features (conformer positions, element, charge,
    chirality, name encoding), atom<->token mappings, ground-truth ensemble
    coordinates, the distogram target, and optionally per-token frames, then
    pads everything to the requested atom/token sizes.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        RNG used to sample one conformer per residue and the random
        roto-translation of reference conformers.
    ensemble_features : dict
        Must contain "ensemble_ref_idxs", the indices of the sampled
        ensemble structures.
    molecules : dict[str, Mol]
        Mapping from residue name to its RDKit molecule (with conformers).
    atoms_per_window_queries : int
        Window size the padded atom count must be a multiple of.
    min_dist, max_dist, num_bins : float, float, int
        Distogram binning parameters.
    max_atoms : int, optional
        The maximum number of atoms (must be a multiple of
        atoms_per_window_queries when given).
    max_tokens : int, optional
        The maximum number of tokens to pad token-indexed features to.
    disto_use_ensemble : bool, optional
        If True, use all ensemble structures for the distogram target
        instead of only the sampled ones.
    override_bfactor : bool
        If True, zero out the b-factor feature.
    compute_frames : bool
        If True, also compute per-token frame indices and their masks.
    override_coords : Tensor, optional
        If given, replaces the ground-truth coordinates (single ensemble).
    bfactor_md_correction : bool
        If True and the structure method is "md", convert RMSF to b-factor.

    Returns
    -------
    dict[str, Tensor]
        The atom features.

    """
    # Filter to tokens' atoms
    atom_data = []
    atom_name = []
    atom_element = []
    atom_charge = []
    atom_conformer = []
    atom_chirality = []
    ref_space_uid = []
    coord_data = []
    if compute_frames:
        frame_data = []  # three atom indices per token defining its frame
        resolved_frame_data = []  # whether all three frame atoms are present
    atom_to_token = []
    token_to_rep_atom = []  # index on cropped atom table
    r_set_to_rep_atom = []
    disto_coords_ensemble = []
    backbone_feat_index = []
    token_to_center_atom = []

    e_offsets = data.structure.ensemble["atom_coord_idx"]
    atom_idx = 0

    # Start atom idx in full atom table for structures chosen. Up to num_ensembles points.
    ensemble_atom_starts = [
        data.structure.ensemble[idx]["atom_coord_idx"]
        for idx in ensemble_features["ensemble_ref_idxs"]
    ]

    # Set unk chirality id
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    chain_res_ids = {}  # (asym_id, res_idx) -> dense residue index (ref_space_uid)
    res_index_to_conf_id = {}  # (asym_id, res_idx) -> sampled RDKit conformer id
    for token_id, token in enumerate(data.tokens):
        # Get the chain residue ids
        chain_idx, res_id = token["asym_id"], token["res_idx"]
        chain = data.structure.chains[chain_idx]

        if (chain_idx, res_id) not in chain_res_ids:
            new_idx = len(chain_res_ids)
            chain_res_ids[(chain_idx, res_id)] = new_idx
        else:
            new_idx = chain_res_ids[(chain_idx, res_id)]

        # Get the molecule and conformer
        mol = molecules[token["res_name"]]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}

        # Sample a random conformer (once per residue, shared by its tokens)
        if (chain_idx, res_id) not in res_index_to_conf_id:
            conf_ids = [int(conf.GetId()) for conf in mol.GetConformers()]
            conf_id = int(random.choice(conf_ids))
            res_index_to_conf_id[(chain_idx, res_id)] = conf_id

        conf_id = res_index_to_conf_id[(chain_idx, res_id)]
        conformer = mol.GetConformer(conf_id)

        # Map atoms to token indices
        ref_space_uid.extend([new_idx] * token["atom_num"])
        atom_to_token.extend([token_id] * token["atom_num"])

        # Add atom data
        start = token["atom_idx"]
        end = token["atom_idx"] + token["atom_num"]
        token_atoms = data.structure.atoms[start:end]

        # Add atom ref data
        # element, charge, conformer, chirality
        token_atom_name = np.array([convert_atom_name(a["name"]) for a in token_atoms])
        token_atoms_ref = np.array([atom_name_to_ref[a["name"]] for a in token_atoms])
        token_atoms_element = np.array([a.GetAtomicNum() for a in token_atoms_ref])
        token_atoms_charge = np.array([a.GetFormalCharge() for a in token_atoms_ref])
        token_atoms_conformer = np.array(
            [
                (
                    conformer.GetAtomPosition(a.GetIdx()).x,
                    conformer.GetAtomPosition(a.GetIdx()).y,
                    conformer.GetAtomPosition(a.GetIdx()).z,
                )
                for a in token_atoms_ref
            ]
        )
        token_atoms_chirality = np.array(
            [
                const.chirality_type_ids.get(a.GetChiralTag().name, unk_chirality)
                for a in token_atoms_ref
            ]
        )

        # Map token to representative atom
        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
        token_to_center_atom.append(atom_idx + token["center_idx"] - start)
        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
            "resolved_mask"
        ]:
            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)

        # Backbone atom class per atom: 0 = not backbone; protein backbone
        # atoms come first, nucleic backbone atoms are offset after them.
        if chain["mol_type"] == const.chain_type_ids["PROTEIN"]:
            backbone_index = [
                (
                    const.protein_backbone_atom_index[atom_name] + 1
                    if atom_name in const.protein_backbone_atom_index
                    else 0
                )
                for atom_name in token_atoms["name"]
            ]
        elif (
            chain["mol_type"] == const.chain_type_ids["DNA"]
            or chain["mol_type"] == const.chain_type_ids["RNA"]
        ):
            backbone_index = [
                (
                    const.nucleic_backbone_atom_index[atom_name]
                    + 1
                    + len(const.protein_backbone_atom_index)
                    if atom_name in const.nucleic_backbone_atom_index
                    else 0
                )
                for atom_name in token_atoms["name"]
            ]
        else:
            backbone_index = [0] * token["atom_num"]
        backbone_feat_index.extend(backbone_index)

        # Get token coordinates across sampled ensembles and apply transforms
        token_coords = np.array(
            [
                data.structure.coords[
                    ensemble_atom_start + start : ensemble_atom_start + end
                ]["coords"]
                for ensemble_atom_start in ensemble_atom_starts
            ]
        )
        coord_data.append(token_coords)

        if compute_frames:
            # Get frame data
            res_type = const.tokens[token["res_type"]]
            res_name = str(token["res_name"])

            if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
                # Too few atoms or unknown residue: no valid frame.
                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
                mask_frame = False
            elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
                res_name in const.ref_atoms
            ):
                # Standard protein residue: frame from N, CA, C.
                idx_frame_a, idx_frame_b, idx_frame_c = (
                    const.ref_atoms[res_name].index("N"),
                    const.ref_atoms[res_name].index("CA"),
                    const.ref_atoms[res_name].index("C"),
                )
                mask_frame = (
                    token_atoms["is_present"][idx_frame_a]
                    and token_atoms["is_present"][idx_frame_b]
                    and token_atoms["is_present"][idx_frame_c]
                )
            elif (
                token["mol_type"] == const.chain_type_ids["DNA"]
                or token["mol_type"] == const.chain_type_ids["RNA"]
            ) and (res_name in const.ref_atoms):
                # Standard nucleic residue: frame from C1', C3', C4'.
                idx_frame_a, idx_frame_b, idx_frame_c = (
                    const.ref_atoms[res_name].index("C1'"),
                    const.ref_atoms[res_name].index("C3'"),
                    const.ref_atoms[res_name].index("C4'"),
                )
                mask_frame = (
                    token_atoms["is_present"][idx_frame_a]
                    and token_atoms["is_present"][idx_frame_b]
                    and token_atoms["is_present"][idx_frame_c]
                )
            elif token["mol_type"] == const.chain_type_ids["PROTEIN"]:
                # Try to look for the atom names in the modified residue
                is_ca = token_atoms["name"] == "CA"
                idx_frame_a = is_ca.argmax()
                ca_present = (
                    token_atoms[idx_frame_a]["is_present"] if is_ca.any() else False
                )

                is_n = token_atoms["name"] == "N"
                idx_frame_b = is_n.argmax()
                n_present = (
                    token_atoms[idx_frame_b]["is_present"] if is_n.any() else False
                )

                is_c = token_atoms["name"] == "C"
                idx_frame_c = is_c.argmax()
                c_present = (
                    token_atoms[idx_frame_c]["is_present"] if is_c.any() else False
                )
                mask_frame = ca_present and n_present and c_present

            elif (token["mol_type"] == const.chain_type_ids["DNA"]) or (
                token["mol_type"] == const.chain_type_ids["RNA"]
            ):
                # Try to look for the atom names in the modified residue
                is_c1 = token_atoms["name"] == "C1'"
                idx_frame_a = is_c1.argmax()
                c1_present = (
                    token_atoms[idx_frame_a]["is_present"] if is_c1.any() else False
                )

                is_c3 = token_atoms["name"] == "C3'"
                idx_frame_b = is_c3.argmax()
                c3_present = (
                    token_atoms[idx_frame_b]["is_present"] if is_c3.any() else False
                )

                is_c4 = token_atoms["name"] == "C4'"
                idx_frame_c = is_c4.argmax()
                c4_present = (
                    token_atoms[idx_frame_c]["is_present"] if is_c4.any() else False
                )
                mask_frame = c1_present and c3_present and c4_present
            else:
                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
                mask_frame = False
            # Frame indices are offset into the cropped atom table.
            frame_data.append(
                [
                    idx_frame_a + atom_idx,
                    idx_frame_b + atom_idx,
                    idx_frame_c + atom_idx,
                ]
            )
            resolved_frame_data.append(mask_frame)

        # Get distogram coordinates
        disto_coords_ensemble_tok = data.structure.coords[
            e_offsets + token["disto_idx"]
        ]["coords"]
        disto_coords_ensemble.append(disto_coords_ensemble_tok)

        # Update atom data. This is technically never used again (we rely on coord_data),
        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
        token_atoms = token_atoms.copy()
        token_atoms["coords"] = token_coords[
            0
        ]  # atom has a copy of first coords in ensemble
        atom_data.append(token_atoms)
        atom_name.append(token_atom_name)
        atom_element.append(token_atoms_element)
        atom_charge.append(token_atoms_charge)
        atom_conformer.append(token_atoms_conformer)
        atom_chirality.append(token_atoms_chirality)
        atom_idx += len(token_atoms)

    disto_coords_ensemble = np.array(disto_coords_ensemble)  # (N_TOK, N_ENS, 3)

    # Compute ensemble distogram
    L = len(data.tokens)

    if disto_use_ensemble:
        # Use all available structures to create distogram
        idx_list = range(disto_coords_ensemble.shape[1])
    else:
        # Only use a sampled structures to create distogram
        idx_list = ensemble_features["ensemble_ref_idxs"]

    # Create distogram: one-hot binned pairwise distances per selected structure
    disto_target = torch.zeros(L, L, len(idx_list), num_bins)  # TODO1

    # disto_target = torch.zeros(L, L, num_bins)
    for i, e_idx in enumerate(idx_list):
        t_center = torch.Tensor(disto_coords_ensemble[:, e_idx, :])
        t_dists = torch.cdist(t_center, t_center)
        boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
        # Bin index = number of boundaries each distance exceeds.
        distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
        # disto_target += one_hot(distogram, num_classes=num_bins)
        disto_target[:, :, i, :] = one_hot(distogram, num_classes=num_bins)  # TODO1

    # Normalize distogram
    # disto_target = disto_target / disto_target.sum(-1)[..., None] # remove TODO1
    atom_data = np.concatenate(atom_data)
    atom_name = np.concatenate(atom_name)
    atom_element = np.concatenate(atom_element)
    atom_charge = np.concatenate(atom_charge)
    atom_conformer = np.concatenate(atom_conformer)
    atom_chirality = np.concatenate(atom_chirality)
    coord_data = np.concatenate(coord_data, axis=1)  # (N_ENS, N_ATOMS, 3)
    ref_space_uid = np.array(ref_space_uid)

    # Compute features
    disto_coords_ensemble = from_numpy(disto_coords_ensemble)
    disto_coords_ensemble = disto_coords_ensemble[
        :, ensemble_features["ensemble_ref_idxs"]
    ].permute(1, 0, 2)
    backbone_feat_index = from_numpy(np.asarray(backbone_feat_index)).long()
    ref_atom_name_chars = from_numpy(atom_name).long()
    ref_element = from_numpy(atom_element).long()
    ref_charge = from_numpy(atom_charge).float()
    ref_pos = from_numpy(atom_conformer).float()
    ref_space_uid = from_numpy(ref_space_uid)
    ref_chirality = from_numpy(atom_chirality).long()
    coords = from_numpy(coord_data.copy())
    resolved_mask = from_numpy(atom_data["is_present"])
    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
    token_to_center_atom = torch.tensor(token_to_center_atom, dtype=torch.long)
    bfactor = from_numpy(atom_data["bfactor"].copy())
    plddt = from_numpy(atom_data["plddt"].copy())
    if override_bfactor:
        bfactor = bfactor * 0.0

    if bfactor_md_correction and data.record.structure.method.lower() == "md":
        # MD bfactor was computed as RMSF
        # Convert to b-factor
        bfactor = 8 * (np.pi**2) * (bfactor**2)

    # We compute frames within ensemble
    if compute_frames:
        frames = []
        frame_resolved_mask = []
        for i in range(coord_data.shape[0]):
            frame_data_, resolved_frame_data_ = compute_frames_nonpolymer(
                data,
                coord_data[i],
                atom_data["is_present"],
                atom_to_token,
                frame_data,
                resolved_frame_data,
            )  # Compute frames for NONPOLYMER tokens
            frames.append(frame_data_.copy())
            frame_resolved_mask.append(resolved_frame_data_.copy())
        frames = from_numpy(np.stack(frames))  # (N_ENS, N_TOK, 3)
        frame_resolved_mask = from_numpy(np.stack(frame_resolved_mask))

    # Convert to one-hot
    backbone_feat_index = one_hot(
        backbone_feat_index,
        num_classes=1
        + len(const.protein_backbone_atom_index)
        + len(const.nucleic_backbone_atom_index),
    )
    ref_atom_name_chars = one_hot(ref_atom_name_chars, num_classes=64)
    ref_element = one_hot(ref_element, num_classes=const.num_elements)
    # NOTE: token_id is the last loop variable; num_classes is the token count.
    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
    token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
    r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
    token_to_center_atom = one_hot(token_to_center_atom, num_classes=len(atom_data))

    # Center the ground truth coordinates (mean over resolved atoms, per ensemble)
    center = (coords * resolved_mask[None, :, None]).sum(dim=1)
    center = center / resolved_mask.sum().clamp(min=1)
    coords = coords - center[:, None]

    if isinstance(override_coords, Tensor):
        coords = override_coords.unsqueeze(0)

    # Apply random roto-translation to the input conformers
    # NOTE(review): range(torch.max(ref_space_uid)) excludes the largest uid,
    # so the last residue's conformer is never augmented — confirm intentional.
    for i in range(torch.max(ref_space_uid)):
        included = ref_space_uid == i
        if torch.sum(included) > 0 and torch.any(resolved_mask[included]):
            ref_pos[included] = center_random_augmentation(
                ref_pos[included][None], resolved_mask[included][None], centering=True
            )[0]

    # Compute padding and apply
    if max_atoms is not None:
        assert max_atoms % atoms_per_window_queries == 0
        pad_len = max_atoms - len(atom_data)
    else:
        # Round up to the next multiple of atoms_per_window_queries.
        pad_len = (
            (len(atom_data) - 1) // atoms_per_window_queries + 1
        ) * atoms_per_window_queries - len(atom_data)

    if pad_len > 0:
        pad_mask = pad_dim(pad_mask, 0, pad_len)
        ref_pos = pad_dim(ref_pos, 0, pad_len)
        resolved_mask = pad_dim(resolved_mask, 0, pad_len)
        ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
        ref_element = pad_dim(ref_element, 0, pad_len)
        ref_charge = pad_dim(ref_charge, 0, pad_len)
        ref_chirality = pad_dim(ref_chirality, 0, pad_len)
        backbone_feat_index = pad_dim(backbone_feat_index, 0, pad_len)
        ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
        coords = pad_dim(coords, 1, pad_len)
        atom_to_token = pad_dim(atom_to_token, 0, pad_len)
        token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
        token_to_center_atom = pad_dim(token_to_center_atom, 1, pad_len)
        r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
        bfactor = pad_dim(bfactor, 0, pad_len)
        plddt = pad_dim(plddt, 0, pad_len)

    if max_tokens is not None:
        pad_len = max_tokens - token_to_rep_atom.shape[0]
        if pad_len > 0:
            atom_to_token = pad_dim(atom_to_token, 1, pad_len)
            token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
            r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
            token_to_center_atom = pad_dim(token_to_center_atom, 0, pad_len)
            disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
            disto_coords_ensemble = pad_dim(disto_coords_ensemble, 1, pad_len)

            if compute_frames:
                frames = pad_dim(frames, 1, pad_len)
                frame_resolved_mask = pad_dim(frame_resolved_mask, 1, pad_len)

    atom_features = {
        "ref_pos": ref_pos,
        "atom_resolved_mask": resolved_mask,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_element": ref_element,
        "ref_charge": ref_charge,
        "ref_chirality": ref_chirality,
        "atom_backbone_feat": backbone_feat_index,
        "ref_space_uid": ref_space_uid,
        "coords": coords,
        "atom_pad_mask": pad_mask,
        "atom_to_token": atom_to_token,
        "token_to_rep_atom": token_to_rep_atom,
        "r_set_to_rep_atom": r_set_to_rep_atom,
        "token_to_center_atom": token_to_center_atom,
        "disto_target": disto_target,
        "disto_coords_ensemble": disto_coords_ensemble,
        "bfactor": bfactor,
        "plddt": plddt,
    }

    if compute_frames:
        atom_features["frames_idx"] = frames
        atom_features["frame_resolved_mask"] = frame_resolved_mask

    return atom_features
1569
+
1570
+
1571
def process_msa_features(
    data: Tokenized,
    random: np.random.Generator,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
    msa_sampling: bool = False,
    affinity: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Builds the (paired) MSA tensors, derives the per-position profile and
    deletion statistics, and pads along the sequence and token dimensions.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    max_seqs_batch : int
        The number of MSA sequences to construct for this batch
        (passed to construct_paired_msa).
    max_seqs : int
        The maximum number of MSA sequences (padding target when
        pad_to_max_seqs is True).
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.
    msa_sampling : bool
        Whether to sample the MSA.
    affinity : bool
        If True, return only the profile and mean-deletion features
        (keyed with an "_affinity" suffix).

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Created paired MSA
    msa, deletion, paired = construct_paired_msa(
        data=data,
        random=random,
        max_seqs=max_seqs_batch,
        random_subset=msa_sampling,
    )
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    assert torch.all(msa >= 0) and torch.all(msa < const.num_tokens)
    msa_one_hot = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa)
    # Profile = per-position residue-type frequency over the MSA.
    profile = msa_one_hot.float().mean(dim=0)
    has_deletion = deletion > 0
    # Squash raw deletion counts through arctan.
    # NOTE(review): AlphaFold-style featurization uses (2/pi)*arctan(d/3),
    # which maps to [0, 1); here the factor is pi/2 — confirm this matches
    # what the trained weights expect before changing it.
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            # Pad sequences with gap tokens; masks/values pad with zeros.
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)
    if affinity:
        # Affinity head only consumes the summary features.
        return {
            "deletion_mean_affinity": deletion_mean,
            "profile_affinity": profile,
        }
    else:
        return {
            "msa": msa,
            "msa_paired": paired,
            "deletion_value": deletion,
            "has_deletion": has_deletion,
            "deletion_mean": deletion_mean,
            "profile": profile,
            "msa_mask": msa_mask,
        }
1662
+
1663
+
1664
def load_dummy_templates_features(tdim: int, num_tokens: int) -> dict:
    """Build all-zero (dummy) template features for v2.

    Parameters
    ----------
    tdim : int
        Number of template slots.
    num_tokens : int
        Number of tokens per template.

    Returns
    -------
    dict
        Zero-filled template feature tensors; the residue types are
        one-hot encoded (all pointing at class 0), and every mask is zero
        so the templates are fully ignored downstream.

    """

    def _zeros(shape: tuple, dtype) -> torch.Tensor:
        # All features start as zero-filled numpy arrays converted to tensors.
        return torch.from_numpy(np.zeros(shape, dtype=dtype))

    flat = (tdim, num_tokens)

    # Residue types: zero int indices expanded to a one-hot vocabulary axis.
    restype_onehot = one_hot(_zeros(flat, np.int64), num_classes=const.num_tokens)

    return {
        "template_restype": restype_onehot,
        "template_frame_rot": _zeros((tdim, num_tokens, 3, 3), np.float32),
        "template_frame_t": _zeros((tdim, num_tokens, 3), np.float32),
        "template_cb": _zeros((tdim, num_tokens, 3), np.float32),
        "template_ca": _zeros((tdim, num_tokens, 3), np.float32),
        "template_mask_cb": _zeros(flat, np.float32),
        "template_mask_frame": _zeros(flat, np.float32),
        "template_mask": _zeros(flat, np.float32),
        "query_to_template": _zeros(flat, np.int64),
        "visibility_ids": _zeros(flat, np.float32),
    }
1694
+
1695
+
1696
def compute_template_features(
    query_tokens: Tokenized,
    tmpl_tokens: list[dict],
    num_tokens: int,
) -> dict:
    """Compute the template features.

    Parameters
    ----------
    query_tokens : Tokenized
        The tokenized query structure.
    tmpl_tokens : list[dict]
        One entry per templated query position, each with keys
        ``token`` (the template token record), ``pdb_id`` (the template
        index) and ``q_idx`` (the query token index it maps to).
    num_tokens : int
        The (padded) number of query tokens.

    Returns
    -------
    dict
        The per-token template feature tensors.

    """
    # Allocate features
    res_type = np.zeros((num_tokens,), dtype=np.int64)
    frame_rot = np.zeros((num_tokens, 3, 3), dtype=np.float32)
    frame_t = np.zeros((num_tokens, 3), dtype=np.float32)
    cb_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    ca_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    frame_mask = np.zeros((num_tokens,), dtype=np.float32)
    cb_mask = np.zeros((num_tokens,), dtype=np.float32)
    template_mask = np.zeros((num_tokens,), dtype=np.float32)
    # query_to_template is allocated but never written below: it is
    # returned as all zeros.
    query_to_template = np.zeros((num_tokens,), dtype=np.int64)
    visibility_ids = np.zeros((num_tokens,), dtype=np.float32)

    # Now create features per token
    asym_id_to_pdb_id = {}

    for token_dict in tmpl_tokens:
        idx = token_dict["q_idx"]
        pdb_id = token_dict["pdb_id"]
        token = token_dict["token"]
        query_token = query_tokens.tokens[idx]
        # Remember which template filled each query chain (used for
        # visibility ids below).
        asym_id_to_pdb_id[query_token["asym_id"]] = pdb_id
        res_type[idx] = token["res_type"]
        frame_rot[idx] = token["frame_rot"].reshape(3, 3)
        frame_t[idx] = token["frame_t"]
        cb_coords[idx] = token["disto_coords"]
        ca_coords[idx] = token["center_coords"]
        cb_mask[idx] = token["disto_mask"]
        frame_mask[idx] = token["frame_mask"]
        template_mask[idx] = 1.0

    # Set visibility_id for templated chains
    for asym_id, pdb_id in asym_id_to_pdb_id.items():
        indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
        visibility_ids[indices] = pdb_id

    # Set visibility for non templated chain + olygomerics
    for asym_id in np.unique(query_tokens.structure.chains["asym_id"]):
        if asym_id not in asym_id_to_pdb_id:
            # We hack the chain id to be negative to not overlap with the above
            indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
            visibility_ids[indices] = -1 - asym_id

    # Convert to one-hot
    res_type = torch.from_numpy(res_type)
    res_type = one_hot(res_type, num_classes=const.num_tokens)

    return {
        "template_restype": res_type,
        "template_frame_rot": torch.from_numpy(frame_rot),
        "template_frame_t": torch.from_numpy(frame_t),
        "template_cb": torch.from_numpy(cb_coords),
        "template_ca": torch.from_numpy(ca_coords),
        "template_mask_cb": torch.from_numpy(cb_mask),
        "template_mask_frame": torch.from_numpy(frame_mask),
        "template_mask": torch.from_numpy(template_mask),
        "query_to_template": torch.from_numpy(query_to_template),
        "visibility_ids": torch.from_numpy(visibility_ids),
    }
1760
+
1761
+
1762
def process_template_features(
    data: Tokenized,
    max_tokens: int,
) -> dict[str, torch.Tensor]:
    """Compute template features for all templates attached to the input.

    Templates are grouped by name; each group produces one feature row,
    and all rows are stacked along a new leading template dimension.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    max_tokens : int
        The maximum number of tokens.

    Returns
    -------
    dict[str, torch.Tensor]
        The loaded template features.

    """
    # Group templates by name
    name_to_templates: dict[str, list[TemplateInfo]] = {}
    for template_info in data.record.templates:
        name_to_templates.setdefault(template_info.name, []).append(template_info)

    # Map chain name to asym_id
    chain_name_to_asym_id = {}
    for chain in data.structure.chains:
        chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

    # Compute the offset
    template_features = []
    for template_id, (template_name, templates) in enumerate(name_to_templates.items()):
        row_tokens = []
        template_structure = data.templates[template_name]
        template_tokens = data.template_tokens[template_name]
        tmpl_chain_name_to_asym_id = {}
        for chain in template_structure.chains:
            tmpl_chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

        for template in templates:
            # Residue-index offset between template and query numbering.
            offset = template.template_st - template.query_st

            # Get query and template tokens to map residues
            query_tokens = data.tokens
            chain_id = chain_name_to_asym_id[template.query_chain]
            q_tokens = query_tokens[query_tokens["asym_id"] == chain_id]
            q_indices = dict(zip(q_tokens["res_idx"], q_tokens["token_idx"]))

            # Get the template tokens at the query residues
            chain_id = tmpl_chain_name_to_asym_id[template.template_chain]
            toks = template_tokens[template_tokens["asym_id"] == chain_id]
            toks = [t for t in toks if t["res_idx"] - offset in q_indices]
            for t in toks:
                q_idx = q_indices[t["res_idx"] - offset]
                row_tokens.append(
                    {
                        "token": t,
                        "pdb_id": template_id,
                        "q_idx": q_idx,
                    }
                )

        # Compute template features for each row
        row_features = compute_template_features(data, row_tokens, max_tokens)
        # NOTE(review): `template` here is the *last* TemplateInfo of the
        # inner loop above — force/threshold for the whole group are taken
        # from that single entry. Confirm this is intended when a group
        # contains multiple templates with differing force settings.
        row_features["template_force"] = torch.tensor(template.force)
        row_features["template_force_threshold"] = torch.tensor(
            template.threshold if template.threshold is not None else float("inf"),
            dtype=torch.float32,
        )
        template_features.append(row_features)

    # Stack each feature
    # NOTE(review): assumes at least one template group exists; callers
    # guard on `data.templates` before invoking this function.
    out = {}
    for k in template_features[0]:
        out[k] = torch.stack([f[k] for f in template_features])
    return out
1838
+
1839
+
1840
def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Assemble chain-, residue- and ligand-level symmetry features.

    Parameters
    ----------
    cropped : Tokenized
        The cropped input to the model.
    symmetries : dict
        Precomputed ligand symmetry tables.

    Returns
    -------
    dict[str, Tensor]
        The merged symmetry features.

    """
    return {
        **get_chain_symmetries(cropped),
        **get_amino_acids_symmetries(cropped),
        **get_ligand_symmetries(cropped, symmetries),
    }
1861
+
1862
+
1863
def process_ensemble_features(
    data: Tokenized,
    random: np.random.Generator,
    num_ensembles: int,
    ensemble_sample_replacement: bool,
    fix_single_ensemble: bool,
) -> dict[str, Tensor]:
    """Select which structure conformers to use as ensemble references.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    num_ensembles : int
        The maximum number of conformers to sample.
    ensemble_sample_replacement : bool
        Whether to sample with replacement (training behaviour).
    fix_single_ensemble : bool
        Always pick the first conformer (requires ``num_ensembles == 1``).

    Returns
    -------
    dict[str, Tensor]
        A dict with the sampled conformer indices under
        ``"ensemble_ref_idxs"``.

    """
    assert num_ensembles > 0, "Number of conformers sampled must be greater than 0."

    # Number of conformers available in the structure.
    available = len(data.structure.ensemble)

    if fix_single_ensemble:
        # Deterministic path: always the first conformer.
        assert num_ensembles == 1, (
            "Number of conformers sampled must be 1 with fix_single_ensemble=True."
        )
        selected = np.array([0])
    elif ensemble_sample_replacement:
        # Training: sample with replacement.
        selected = random.integers(0, available, (num_ensembles,))
    elif available < num_ensembles:
        # Validation with too few conformers: take them all, in order.
        selected = np.arange(0, available)
    else:
        # Validation: sample without replacement.
        selected = random.choice(available, num_ensembles, replace=False)

    return {
        "ensemble_ref_idxs": torch.Tensor(selected).long(),
    }
1921
+
1922
+
1923
def process_residue_constraint_features(data: Tokenized) -> dict[str, Tensor]:
    """Extract residue-level geometric constraint tensors.

    Falls back to empty tensors (with the correct leading index arity)
    when the tokenized input carries no residue constraints.

    Parameters
    ----------
    data : Tokenized
        The input to the model.

    Returns
    -------
    dict[str, Tensor]
        Constraint index, mask and bound tensors.

    """
    rc = data.residue_constraints
    if rc is None:
        # No constraints available: emit empty, correctly-shaped tensors.
        return {
            "rdkit_bounds_index": torch.empty((2, 0), dtype=torch.long),
            "rdkit_bounds_bond_mask": torch.empty((0,), dtype=torch.bool),
            "rdkit_bounds_angle_mask": torch.empty((0,), dtype=torch.bool),
            "rdkit_upper_bounds": torch.empty((0,), dtype=torch.float),
            "rdkit_lower_bounds": torch.empty((0,), dtype=torch.float),
            "chiral_atom_index": torch.empty((4, 0), dtype=torch.long),
            "chiral_reference_mask": torch.empty((0,), dtype=torch.bool),
            "chiral_atom_orientations": torch.empty((0,), dtype=torch.bool),
            "stereo_bond_index": torch.empty((4, 0), dtype=torch.long),
            "stereo_reference_mask": torch.empty((0,), dtype=torch.bool),
            "stereo_bond_orientations": torch.empty((0,), dtype=torch.bool),
            "planar_bond_index": torch.empty((6, 0), dtype=torch.long),
            "planar_ring_5_index": torch.empty((5, 0), dtype=torch.long),
            "planar_ring_6_index": torch.empty((6, 0), dtype=torch.long),
        }

    def index_tensor(constraints) -> Tensor:
        # Atom-index arrays are stored row-major; transpose to (arity, n).
        return torch.tensor(constraints["atom_idxs"].copy(), dtype=torch.long).T

    def field_tensor(constraints, field: str, dtype) -> Tensor:
        return torch.tensor(constraints[field].copy(), dtype=dtype)

    rdkit = rc.rdkit_bounds_constraints
    chiral = rc.chiral_atom_constraints
    stereo = rc.stereo_bond_constraints

    return {
        "rdkit_bounds_index": index_tensor(rdkit),
        "rdkit_bounds_bond_mask": field_tensor(rdkit, "is_bond", torch.bool),
        "rdkit_bounds_angle_mask": field_tensor(rdkit, "is_angle", torch.bool),
        "rdkit_upper_bounds": field_tensor(rdkit, "upper_bound", torch.float),
        "rdkit_lower_bounds": field_tensor(rdkit, "lower_bound", torch.float),
        "chiral_atom_index": index_tensor(chiral),
        "chiral_reference_mask": field_tensor(chiral, "is_reference", torch.bool),
        "chiral_atom_orientations": field_tensor(chiral, "is_r", torch.bool),
        "stereo_bond_index": index_tensor(stereo),
        "stereo_reference_mask": field_tensor(stereo, "is_reference", torch.bool),
        "stereo_bond_orientations": field_tensor(stereo, "is_e", torch.bool),
        "planar_bond_index": index_tensor(rc.planar_bond_constraints),
        "planar_ring_5_index": index_tensor(rc.planar_ring_5_constraints),
        "planar_ring_6_index": index_tensor(rc.planar_ring_6_constraints),
    }
2016
+
2017
+
2018
def process_chain_feature_constraints(data: Tokenized) -> dict[str, Tensor]:
    """Build inter-chain connectivity and chain-symmetry index tensors.

    Parameters
    ----------
    data : Tokenized
        The input to the model.

    Returns
    -------
    dict[str, Tensor]
        Connected chain/atom pair indices and symmetric chain pairs,
        each shaped ``(2, n)``.

    """
    structure = data.structure

    # Inter-chain covalent connections; intra-chain bonds are skipped.
    chain_pairs, atom_pairs = [], []
    if structure.bonds.shape[0] > 0:
        for bond in structure.bonds:
            if bond["chain_1"] == bond["chain_2"]:
                continue
            chain_pairs.append([bond["chain_1"], bond["chain_2"]])
            atom_pairs.append([bond["atom_1"], bond["atom_2"]])
    if chain_pairs:
        connected_chain_index = torch.tensor(chain_pairs, dtype=torch.long).T
        connected_atom_index = torch.tensor(atom_pairs, dtype=torch.long).T
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    # Unordered pairs of chains sharing an entity id (identical entities).
    sym_pairs = [
        [i, j]
        for i, chain_i in enumerate(structure.chains)
        for j, chain_j in enumerate(structure.chains)
        if j > i and chain_i["entity_id"] == chain_j["entity_id"]
    ]
    if sym_pairs:
        symmetric_chain_index = torch.tensor(sym_pairs, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)

    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }
2057
+
2058
+
2059
def process_contact_feature_constraints(
    data: Tokenized,
    inference_pocket_constraints: list[tuple[int, list[tuple[int, int]], float]],
    inference_contact_constraints: list[tuple[tuple[int, int], tuple[int, int], float]],
):
    """Build atom-pair contact constraint tensors for inference.

    Each forced constraint is expanded into a cartesian product of
    candidate atom pairs; pairs sharing a ``union_idx`` form one "union"
    group (alternatives for satisfying the same constraint).

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    inference_pocket_constraints : list
        Tuples of (binder chain index, contact list, max distance, force).
    inference_contact_constraints : list
        Tuples of (token1 key, token2 key, max distance, force).

    Returns
    -------
    dict
        Pair indices, union group indices, negation mask and distance
        thresholds as tensors.

    """
    token_data = data.tokens
    union_idx = 0
    pair_index, union_index, negation_mask, thresholds = [], [], [], []
    for binder, contacts, max_distance, force in inference_pocket_constraints:
        # Only hard (forced) constraints become potential-style features.
        if not force:
            continue

        binder_chain = data.structure.chains[binder]
        for token in token_data:
            # Polymer contacts are keyed by (asym_id, res_idx); ligand
            # (NONPOLYMER) contacts are keyed by (asym_id, atom_idx).
            if (
                token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                and (token["asym_id"], token["res_idx"]) in contacts
            ) or (
                token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                and (token["asym_id"], token["atom_idx"]) in contacts
            ):
                # All (binder atom, contact-token atom) pairs.
                atom_idx_pairs = torch.cartesian_prod(
                    torch.arange(
                        binder_chain["atom_idx"],
                        binder_chain["atom_idx"] + binder_chain["atom_num"],
                    ),
                    torch.arange(
                        token["atom_idx"], token["atom_idx"] + token["atom_num"]
                    ),
                ).T
                pair_index.append(atom_idx_pairs)
                union_index.append(torch.full((atom_idx_pairs.shape[1],), union_idx))
                negation_mask.append(
                    torch.ones((atom_idx_pairs.shape[1],), dtype=torch.bool)
                )
                thresholds.append(torch.full((atom_idx_pairs.shape[1],), max_distance))
                # One union group per matched contact token: each contact
                # must be satisfied by at least one of its atom pairs.
                union_idx += 1

    for token1, token2, max_distance, force in inference_contact_constraints:
        if not force:
            continue

        for idx1, _token1 in enumerate(token_data):
            if (
                _token1["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                and (_token1["asym_id"], _token1["res_idx"]) == token1
            ) or (
                _token1["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                and (_token1["asym_id"], _token1["atom_idx"]) == token1
            ):
                for idx2, _token2 in enumerate(token_data):
                    if (
                        _token2["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                        and (_token2["asym_id"], _token2["res_idx"]) == token2
                    ) or (
                        _token2["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                        and (_token2["asym_id"], _token2["atom_idx"]) == token2
                    ):
                        atom_idx_pairs = torch.cartesian_prod(
                            torch.arange(
                                _token1["atom_idx"],
                                _token1["atom_idx"] + _token1["atom_num"],
                            ),
                            torch.arange(
                                _token2["atom_idx"],
                                _token2["atom_idx"] + _token2["atom_num"],
                            ),
                        ).T
                        pair_index.append(atom_idx_pairs)
                        union_index.append(
                            torch.full((atom_idx_pairs.shape[1],), union_idx)
                        )
                        negation_mask.append(
                            torch.ones((atom_idx_pairs.shape[1],), dtype=torch.bool)
                        )
                        thresholds.append(
                            torch.full((atom_idx_pairs.shape[1],), max_distance)
                        )
                        union_idx += 1
                        # Only the first matching token pair is used.
                        break
                break

    if len(pair_index) > 0:
        pair_index = torch.cat(pair_index, dim=1)
        union_index = torch.cat(union_index)
        negation_mask = torch.cat(negation_mask)
        thresholds = torch.cat(thresholds)
    else:
        # No forced constraints: empty, correctly-shaped tensors.
        pair_index = torch.empty((2, 0), dtype=torch.long)
        union_index = torch.empty((0,), dtype=torch.long)
        negation_mask = torch.empty((0,), dtype=torch.bool)
        thresholds = torch.empty((0,), dtype=torch.float32)

    return {
        "contact_pair_index": pair_index,
        "contact_union_index": union_index,
        "contact_negation_mask": negation_mask,
        "contact_thresholds": thresholds,
    }
2158
+
2159
+
2160
class Boltz2Featurizer:
    """Boltz2 featurizer."""

    def process(
        self,
        data: Tokenized,
        random: np.random.Generator,
        molecules: dict[str, Mol],
        training: bool,
        max_seqs: int,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        num_ensembles: int = 1,
        ensemble_sample_replacement: bool = False,
        disto_use_ensemble: Optional[bool] = False,
        fix_single_ensemble: Optional[bool] = True,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        contact_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff_min: Optional[float] = 4.0,
        binder_pocket_cutoff_max: Optional[float] = 20.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        only_pp_contact: Optional[bool] = False,
        single_sequence_prop: Optional[float] = 0.0,
        msa_sampling: bool = False,
        # NOTE(review): annotated `float` but defaults to False — this is
        # presumably a bool flag; confirm against process_atom_features.
        override_bfactor: float = False,
        override_method: Optional[str] = None,
        compute_frames: bool = False,
        override_coords: Optional[Tensor] = None,
        bfactor_md_correction: bool = False,
        compute_constraint_features: bool = False,
        inference_pocket_constraints: Optional[
            list[tuple[int, list[tuple[int, int]], float]]
        ] = None,
        inference_contact_constraints: Optional[
            list[tuple[tuple[int, int], tuple[int, int], float]]
        ] = None,
        compute_affinity: bool = False,
    ) -> dict[str, Tensor]:
        """Compute features.

        Orchestrates all per-feature processors (ensemble, token, atom,
        MSA, template, symmetry, constraint) and merges their outputs
        into a single feature dictionary.

        Parameters
        ----------
        data : Tokenized
            The input to the model.
        random : np.random.Generator
            The random number generator.
        molecules : dict[str, Mol]
            CCD molecules keyed by component name.
        training : bool
            Whether the model is in training mode.
        max_seqs : int
            The maximum number of MSA sequences.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms.
        compute_affinity : bool
            Also compute single-sequence affinity MSA features and the
            affinity ligand molecular weight.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # Compute random number of sequences: during training the MSA
        # depth is randomized, and with prob. `single_sequence_prop` the
        # MSA is collapsed to the query sequence alone.
        if training and max_seqs is not None:
            if random.random() > single_sequence_prop:
                max_seqs_batch = random.integers(1, max_seqs + 1)
            else:
                max_seqs_batch = 1
        else:
            max_seqs_batch = max_seqs

        # Compute ensemble features
        ensemble_features = process_ensemble_features(
            data=data,
            random=random,
            num_ensembles=num_ensembles,
            ensemble_sample_replacement=ensemble_sample_replacement,
            fix_single_ensemble=fix_single_ensemble,
        )

        # Compute token features
        token_features = process_token_features(
            data=data,
            random=random,
            max_tokens=max_tokens,
            binder_pocket_conditioned_prop=binder_pocket_conditioned_prop,
            contact_conditioned_prop=contact_conditioned_prop,
            binder_pocket_cutoff_min=binder_pocket_cutoff_min,
            binder_pocket_cutoff_max=binder_pocket_cutoff_max,
            binder_pocket_sampling_geometric_p=binder_pocket_sampling_geometric_p,
            only_ligand_binder_pocket=only_ligand_binder_pocket,
            only_pp_contact=only_pp_contact,
            override_method=override_method,
            inference_pocket_constraints=inference_pocket_constraints,
            inference_contact_constraints=inference_contact_constraints,
        )

        # Compute atom features
        atom_features = process_atom_features(
            data=data,
            random=random,
            molecules=molecules,
            ensemble_features=ensemble_features,
            atoms_per_window_queries=atoms_per_window_queries,
            min_dist=min_dist,
            max_dist=max_dist,
            num_bins=num_bins,
            max_atoms=max_atoms,
            max_tokens=max_tokens,
            disto_use_ensemble=disto_use_ensemble,
            override_bfactor=override_bfactor,
            compute_frames=compute_frames,
            override_coords=override_coords,
            bfactor_md_correction=bfactor_md_correction,
        )

        # Compute MSA features
        msa_features = process_msa_features(
            data=data,
            random=random,
            max_seqs_batch=max_seqs_batch,
            max_seqs=max_seqs,
            max_tokens=max_tokens,
            pad_to_max_seqs=pad_to_max_seqs,
            msa_sampling=training and msa_sampling,
        )

        # Compute single-sequence MSA features for the affinity head
        msa_features_affinity = {}
        if compute_affinity:
            msa_features_affinity = process_msa_features(
                data=data,
                random=random,
                max_seqs_batch=1,
                max_seqs=1,
                max_tokens=max_tokens,
                pad_to_max_seqs=pad_to_max_seqs,
                msa_sampling=training and msa_sampling,
                affinity=True,
            )

        # Compute affinity ligand Molecular Weight
        ligand_to_mw = {}
        if compute_affinity:
            ligand_to_mw["affinity_mw"] = data.record.affinity.mw

        # Compute template features (dummy all-zero features when no
        # templates exist or when computing affinity)
        num_tokens = data.tokens.shape[0] if max_tokens is None else max_tokens
        if data.templates and not compute_affinity:
            template_features = process_template_features(
                data=data,
                max_tokens=num_tokens,
            )
        else:
            template_features = load_dummy_templates_features(
                tdim=1,
                num_tokens=num_tokens,
            )

        # Compute symmetry features
        symmetry_features = {}
        if compute_symmetries:
            symmetries = get_symmetries(molecules)
            symmetry_features = process_symmetry_features(data, symmetries)

        # Compute constraint features
        residue_constraint_features = {}
        chain_constraint_features = {}
        contact_constraint_features = {}
        if compute_constraint_features:
            residue_constraint_features = process_residue_constraint_features(data)
            chain_constraint_features = process_chain_feature_constraints(data)
            contact_constraint_features = process_contact_feature_constraints(
                data=data,
                inference_pocket_constraints=inference_pocket_constraints if inference_pocket_constraints else [],
                inference_contact_constraints=inference_contact_constraints if inference_contact_constraints else [],
            )

        # Merge all feature dictionaries (later entries win on key clash).
        return {
            **token_features,
            **atom_features,
            **msa_features,
            **msa_features_affinity,
            **template_features,
            **symmetry_features,
            **ensemble_features,
            **residue_constraint_features,
            **chain_constraint_features,
            **contact_constraint_features,
            **ligand_to_mw,
        }
protify/FastPLMs/boltz/src/boltz/data/feature/symmetry.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import pickle
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from boltz.data import const
10
+ from boltz.data.pad import pad_dim
11
+ from boltz.model.loss.confidence import lddt_dist
12
+ from boltz.model.loss.validation import weighted_minimum_rmsd_single
13
+
14
+
15
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Encode an atom name as four integers.

    Each character is mapped to ``ord(c) - 32`` and the result is
    zero-padded on the right to length four.

    Parameters
    ----------
    name : str
        The atom name (surrounding whitespace is ignored).

    Returns
    -------
    Tuple[int, int, int, int]
        The converted atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    codes += [0] * (4 - len(codes))
    return tuple(codes)
33
+
34
+
35
def get_symmetries(path: str) -> dict:
    """Create a dictionary for the ligand symmetries.

    Parameters
    ----------
    path : str
        The path to the pickled ligand symmetries.

    Returns
    -------
    dict
        Mapping from component key to (symmetry permutations, atom names).

    """
    with Path(path).open("rb") as handle:
        molecules: dict = pickle.load(handle)  # noqa: S301

    symmetries = {}
    for key, mol in molecules.items():
        try:
            # Symmetry permutations are stored hex-encoded on the molecule.
            sym = pickle.loads(bytes.fromhex(mol.GetProp("symmetries")))  # noqa: S301
            atom_names = [
                convert_atom_name(atom.GetProp("name")) for atom in mol.GetAtoms()
            ]
        except Exception:  # noqa: BLE001
            # Molecules without usable symmetry metadata are skipped.
            continue
        symmetries[key] = (sym, atom_names)

    return symmetries
68
+
69
+
70
def compute_symmetry_idx_dictionary(data):
    """Assign atom start offsets and collect all atom coordinates.

    Side effects: sets ``chain.start_idx`` (global atom offset of the
    chain) and ``token.start_idx`` (atom offset of the token within its
    chain) on the input objects.

    Parameters
    ----------
    data
        An object with ``chains``, each chain with ``tokens``, each token
        with ``atoms`` carrying ``coords.x/y/z``.

    Returns
    -------
    list[list[float]]
        Flat list of ``[x, y, z]`` coordinates for every atom, in order.

    """
    total_count = 0
    all_coords = []
    # Fix: drop unused enumerate indices from the original loops.
    for chain in data.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            all_coords.extend(
                [atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms
            )
            total_count += len(token.atoms)
    return all_coords
83
+
84
+
85
def get_current_idx_list(data):
    """Return global atom indices for every token currently in the crop.

    Requires ``compute_symmetry_idx_dictionary`` (or equivalent) to have
    populated ``start_idx`` on chains and tokens beforehand.
    """
    indices = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if token.in_crop:
                base = chain.start_idx + token.start_idx
                indices.extend(range(base, base + len(token.atoms)))
    return indices
98
+
99
+
100
def all_different_after_swap(l):
    """Return True if the last element of each sub-sequence is unique."""
    last_items = [entry[-1] for entry in l]
    return len(set(last_items)) == len(last_items)
103
+
104
+
105
+ def minimum_symmetry_coords(
106
+ coords: torch.Tensor,
107
+ feats: dict,
108
+ index_batch: int,
109
+ **args_rmsd,
110
+ ):
111
+ all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
112
+ all_resolved_mask = (
113
+ feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
114
+ )
115
+ crop_to_all_atom_map = (
116
+ feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
117
+ )
118
+ chain_symmetries = feats["chain_symmetries"][index_batch]
119
+ amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
120
+ ligand_symmetries = feats["ligand_symmetries"][index_batch]
121
+
122
+ # Check best symmetry on chain swap
123
+ best_true_coords = None
124
+ best_rmsd = float("inf")
125
+ best_align_weights = None
126
+ for c in chain_symmetries:
127
+ true_all_coords = all_coords.clone()
128
+ true_all_resolved_mask = all_resolved_mask.clone()
129
+ for start1, end1, start2, end2, chainidx1, chainidx2 in c:
130
+ true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
131
+ true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
132
+ true_coords = true_all_coords[:, crop_to_all_atom_map]
133
+ true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
134
+ true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
135
+ true_resolved_mask = pad_dim(
136
+ true_resolved_mask,
137
+ 0,
138
+ coords.shape[1] - true_resolved_mask.shape[0],
139
+ )
140
+ try:
141
+ rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
142
+ coords,
143
+ true_coords,
144
+ atom_mask=true_resolved_mask,
145
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
146
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
147
+ **args_rmsd,
148
+ )
149
+ except:
150
+ print("Warning: error in rmsd computation inside symmetry code")
151
+ continue
152
+ rmsd = rmsd.item()
153
+
154
+ if rmsd < best_rmsd:
155
+ best_rmsd = rmsd
156
+ best_true_coords = aligned_coords
157
+ best_align_weights = align_weights
158
+ best_true_resolved_mask = true_resolved_mask
159
+
160
+ # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
161
+ true_coords = best_true_coords.clone()
162
+ true_resolved_mask = best_true_resolved_mask.clone()
163
+ for symmetric_amino in amino_acids_symmetries:
164
+ for c in symmetric_amino:
165
+ # starting from greedy best, try to swap the atoms
166
+ new_true_coords = true_coords.clone()
167
+ new_true_resolved_mask = true_resolved_mask.clone()
168
+ for i, j in c:
169
+ new_true_coords[:, i] = true_coords[:, j]
170
+ new_true_resolved_mask[i] = true_resolved_mask[j]
171
+
172
+ # compute squared distance, for efficiency we do not recompute the alignment
173
+ best_mse_loss = torch.sum(
174
+ ((coords - best_true_coords) ** 2).sum(dim=-1)
175
+ * best_align_weights
176
+ * best_true_resolved_mask,
177
+ dim=-1,
178
+ ) / torch.sum(best_align_weights * best_true_resolved_mask, dim=-1)
179
+ new_mse_loss = torch.sum(
180
+ ((coords - new_true_coords) ** 2).sum(dim=-1)
181
+ * best_align_weights
182
+ * new_true_resolved_mask,
183
+ dim=-1,
184
+ ) / torch.sum(best_align_weights * new_true_resolved_mask, dim=-1)
185
+
186
+ if best_mse_loss > new_mse_loss:
187
+ best_true_coords = new_true_coords
188
+ best_true_resolved_mask = new_true_resolved_mask
189
+
190
+ # greedily update best coordinates after each amino acid
191
+ true_coords = best_true_coords.clone()
192
+ true_resolved_mask = best_true_resolved_mask.clone()
193
+
194
+ # Recomputing alignment
195
+ rmsd, true_coords, best_align_weights = weighted_minimum_rmsd_single(
196
+ coords,
197
+ true_coords,
198
+ atom_mask=true_resolved_mask,
199
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
200
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
201
+ **args_rmsd,
202
+ )
203
+ best_rmsd = rmsd.item()
204
+
205
+ # atom symmetries (ligand and non-standard), resolved greedily recomputing alignment
206
+ for symmetric_ligand in ligand_symmetries:
207
+ for c in symmetric_ligand:
208
+ new_true_coords = true_coords.clone()
209
+ new_true_resolved_mask = true_resolved_mask.clone()
210
+ for i, j in c:
211
+ new_true_coords[:, j] = true_coords[:, i]
212
+ new_true_resolved_mask[j] = true_resolved_mask[i]
213
+ try:
214
+ # TODO if this is too slow maybe we can get away with not recomputing alignment
215
+ rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
216
+ coords,
217
+ new_true_coords,
218
+ atom_mask=new_true_resolved_mask,
219
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
220
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
221
+ **args_rmsd,
222
+ )
223
+ except Exception as e:
224
+ raise e
225
+ print(e)
226
+ continue
227
+ rmsd = rmsd.item()
228
+ if rmsd < best_rmsd:
229
+ best_true_coords = aligned_coords
230
+ best_rmsd = rmsd
231
+ best_true_resolved_mask = new_true_resolved_mask
232
+
233
+ true_coords = best_true_coords.clone()
234
+ true_resolved_mask = best_true_resolved_mask.clone()
235
+
236
+ return best_true_coords, best_rmsd, best_true_resolved_mask.unsqueeze(0)
237
+
238
+
239
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Pick the symmetry assignment of the ground truth that best matches
    the prediction under lDDT, then align it to the prediction.

    Chain-swap symmetries are scored exhaustively by lDDT. Residue- and
    ligand-level atom swaps are then resolved greedily on local distance
    submatrices (no realignment per swap). Finally, a single weighted
    RMSD alignment is computed on the winning assignment.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates with a leading batch dim of 1,
        shape (1, num_atoms_padded, 3).
    feats : dict
        Feature dict holding ground-truth coordinates, resolved masks and
        precomputed symmetry tables for each batch element.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Extra keyword arguments forwarded to
        ``weighted_minimum_rmsd_single``.

    Returns
    -------
    tuple
        ``(true_coords, best_rmsd, true_resolved_mask)`` where the mask
        carries a leading batch dimension. ``best_rmsd`` is set to 1000
        if the final alignment fails.
    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    # Index map: move to the right device but do NOT round-trip through the
    # float dtype of `coords` (the original `.to(coords)` did) — float32
    # cannot represent integers above 2**24 exactly, which would silently
    # corrupt atom indices for very large structures.
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords.device).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    dmat_predicted = torch.cdist(
        coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
    )

    # Check best symmetry on chain swap.
    best_true_coords = None
    best_lddt = 0
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Pairwise validity: both atoms resolved, self-pairs excluded.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        if lddt > best_lddt:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # Atom symmetries (nucleic acid / protein residues and ligands),
    # resolved greedily without recomputing any alignment.
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        for c in symmetric_amino_or_lig:
            # Starting from the current greedy best, try to swap the atoms.
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            indices = []
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]
                indices.append(i)

            indices = (
                torch.from_numpy(np.asarray(indices)).to(new_true_coords.device).long()
            )

            # Score only the swapped columns against all atoms: cheaper than
            # recomputing the full distance matrix per candidate swap.
            pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            sub_dmat_pred = torch.cdist(
                coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
            )
            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            sub_true_pair_lddt[indices] = (
                sub_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]
            new_lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]

            if new_lddt > lddt:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

            # Greedily accept the best assignment after each residue/ligand.
            true_coords = best_true_coords.clone()
            true_resolved_mask = best_true_resolved_mask.clone()

    # Final alignment of the winning assignment to the prediction.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    try:
        rmsd, true_coords, _ = weighted_minimum_rmsd_single(
            coords,
            true_coords,
            atom_mask=true_resolved_mask,
            atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
            mol_type=feats["mol_type"][index_batch : index_batch + 1],
            **args_rmsd,
        )
        best_rmsd = rmsd.item()
    except Exception as e:
        print("Failed proper RMSD computation, returning inf. Error: ", e)
        best_rmsd = 1000

    return true_coords, best_rmsd, true_resolved_mask.unsqueeze(0)
378
+
379
+
380
def compute_all_coords_mask(structure):
    """Flatten per-token atom data into whole-structure parallel lists.

    Walks every chain and token of ``structure``, collecting each atom's
    xyz coordinates plus two per-atom masks: whether the owning token is
    inside the crop, and whether it is resolved (present). As a side
    effect, annotates each chain with ``start_idx`` (index of its first
    atom in the flattened arrays) and each token with ``start_idx``
    (its first atom's index relative to the chain start).

    Parameters
    ----------
    structure
        Object exposing ``chains``, each with ``tokens``; each token has
        ``atoms`` (with ``.coords.x/.y/.z``), an ``in_crop`` flag and an
        ``is_present`` flag.

    Returns
    -------
    tuple[list, list, list]
        ``(all_coords, all_coords_crop_mask, all_resolved_mask)``, one
        entry per atom; all three lists have equal length by
        construction (the original dead length-check was removed).
    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            all_coords.extend(
                [atom.coords.x, atom.coords.y, atom.coords.z]
                for atom in token.atoms
            )
            num_atoms = len(token.atoms)
            all_coords_crop_mask.extend([token.in_crop] * num_atoms)
            all_resolved_mask.extend([token.is_present] * num_atoms)
            total_count += num_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
403
+
404
+
405
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Enumerate chain-swap symmetries for a cropped structure.

    Concatenates the ground-truth coordinates and resolved masks of every
    chain, builds the crop-to-all-atom index map, and lists up to
    ``max_n_symmetries`` valid chain permutations. A chain may only swap
    with a chain of the same entity and identical atom count.

    Parameters
    ----------
    cropped
        Cropped structure with ``structure`` and ``tokens`` attributes.
    max_n_symmetries : int
        Upper bound on the number of permutations returned.

    Returns
    -------
    dict
        Keys ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map`` (tensors) and ``chain_symmetries``
        (list of swap tuples; contains one empty entry if no valid
        permutation was found).
    """
    structure = cropped.structure

    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx = chain["atom_idx"]
        atom_num = chain["atom_num"]

        # Slice out this chain's coordinates and resolved flags.
        resolved_mask = structure.atoms["is_present"][atom_idx : atom_idx + atom_num]
        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        # A chain counts as "in crop" if any cropped token references it.
        in_crop = any(
            token["asym_id"] == chain["asym_id"] for token in cropped.tokens
        )

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(in_crop)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    # Backmapping from each cropped token to the concatenated atom array.
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))

    # Candidate swaps: for every in-crop chain, every same-entity chain of
    # identical length is a legal replacement (including itself).
    swaps = []
    for i, chain in enumerate(structure.chains):
        if not chain_in_crop[i]:
            continue
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]
        possible_swaps = []
        for j, chain2 in enumerate(structure.chains):
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            same_entity = chain["entity_id"] == chain2["entity_id"]
            same_length = (end - start) == (end2 - start2)
            if same_entity and same_length:
                possible_swaps.append((start, end, start2, end2, i, j))
        swaps.append(possible_swaps)

    # Bound how many combinations are even considered, to avoid a
    # combinatorial explosion for highly symmetric assemblies.
    combinations = itertools.islice(itertools.product(*swaps), max_n_symmetries * 10)
    # Keep only permutations where every chain gets a distinct assignment.
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)
    if len(combinations) == 0:
        combinations.append([])

    features = {
        "all_coords": torch.Tensor(
            np.concatenate(all_coords, axis=0)
        ),  # axis=1 with ensemble
        "all_resolved_mask": torch.Tensor(np.concatenate(all_resolved_mask, axis=0)),
        "crop_to_all_atom_map": torch.Tensor(
            np.concatenate(crop_to_all_atom_map, axis=0)
        ),
        "chain_symmetries": combinations,
    }
    return features
500
+
501
+
502
def get_amino_acids_symmetries(cropped):
    """Collect intra-residue atom-swap symmetries for standard residues.

    For each token whose residue type has known reference symmetries,
    offsets the per-residue (i, j) swap pairs by the token's position in
    the flattened crop atom order.

    Parameters
    ----------
    cropped
        Cropped structure exposing ``tokens`` with ``res_type`` and
        ``atom_num`` fields.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}`` — one entry per symmetric
        residue, each a list of swap lists of global (i, j) index pairs.
    """
    swaps = []
    offset = 0
    for token in cropped.tokens:
        symmetries = const.ref_symmetries.get(const.tokens[token["res_type"]], [])
        if symmetries:
            residue_swaps = [
                [(i + offset, j + offset) for i, j in sym] for sym in symmetries
            ]
            swaps.append(residue_swaps)
        offset += token["atom_num"]

    return {"amino_acids_symmetries": swaps}
520
+
521
+
522
def get_ligand_symmetries(cropped, symmetries):
    """Compute ligand and non-standard residue atom-swap symmetries.

    Matches each cropped molecule's atom names against the CCD reference
    ordering in ``symmetries``, translates the CCD symmetry permutations
    into indices over the crop's atom order, and keeps only non-identity
    swaps.

    Parameters
    ----------
    cropped
        Cropped structure with ``structure`` and ``tokens``.
    symmetries : dict
        Mapping of molecule name to ``(syms_ccd, mol_atom_names_ccd)``.

    Returns
    -------
    dict
        ``{"ligand_symmetries": molecule_symmetries}``.

    Raises
    ------
    Exception
        If a translated permutation's length disagrees with the number of
        atoms recorded for the molecule.
    """
    structure = cropped.structure

    # Identify each molecule once, keyed by (asym_id, res_idx), and track
    # its total atom count across tokens.
    added_molecules = {}
    index_mols = []
    atom_count = 0
    for token in cropped.tokens:
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # Get the molecule name and its atom names from the structure.
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        mol_name = structure.residues[residue_idx]["name"]
        atom_idx = structure.residues[residue_idx]["atom_idx"]
        mol_atom_names = structure.atoms[
            atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
        ]["name"]
        mol_atom_names = [tuple(m) for m in mol_atom_names]
        # Standard residues are handled by get_amino_acids_symmetries.
        if mol_name not in const.ref_symmetries:
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    # For each molecule, translate its CCD symmetries.
    molecule_symmetries = []
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if mol_name not in symmetries:
            continue
        swaps = []
        syms_ccd, mol_atom_names_ccd = symmetries[mol_name]
        # First occurrence of each CCD atom name -> its CCD index, built
        # in one pass instead of calling list.index per atom (O(n) vs O(n^2)).
        ccd_name_to_idx = {}
        for k, name in enumerate(mol_atom_names_ccd):
            ccd_name_to_idx.setdefault(name, k)
        ccd_to_valid_ids = {
            ccd_name_to_idx[name]: i for i, name in enumerate(mol_atom_names)
        }
        ccd_valid_ids = set(ccd_to_valid_ids.keys())

        # Keep only permutations that stay within the atoms present here.
        syms = []
        for sym_ccd in syms_ccd:
            sym_dict = {}
            bool_add = True
            for i, j in enumerate(sym_ccd):
                if i in ccd_valid_ids:
                    if j in ccd_valid_ids:
                        i_true = ccd_to_valid_ids[i]
                        j_true = ccd_to_valid_ids[j]
                        sym_dict[i_true] = j_true
                    else:
                        bool_add = False
                        break
            if bool_add:
                syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])

        for sym in syms:
            if len(sym) != added_molecules[mol_id]:
                raise Exception(
                    f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                )
            # Keep only the moving pairs, offset into crop atom order.
            sym_new_idx = []
            for i, j in enumerate(sym):
                if i != int(j):
                    sym_new_idx.append((i + start_mol, int(j) + start_mol))
            if len(sym_new_idx) > 0:
                swaps.append(sym_new_idx)
        if len(swaps) > 0:
            molecule_symmetries.append(swaps)

    return {"ligand_symmetries": molecule_symmetries}
protify/FastPLMs/boltz/src/boltz/data/filter/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/date.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import Literal
3
+
4
+ from boltz.data.types import Record
5
+ from boltz.data.filter.dynamic.filter import DynamicFilter
6
+
7
+
8
class DateFilter(DynamicFilter):
    """A filter that filters complexes based on their date.

    The date can be the deposition, release, or revision date. If the
    requested date is missing on a record, an earlier one is used as a
    fallback (revised -> released -> deposited). If no date is available
    at all, the complex is rejected.

    """

    def __init__(
        self,
        date: str,
        ref: Literal["deposited", "revised", "released"],
    ) -> None:
        """Initialize the filter.

        Parameters
        ----------
        date : str
            The maximum date (ISO format) of PDB entries to keep.
        ref : Literal["deposited", "revised", "released"]
            The reference date to use.

        Raises
        ------
        ValueError
            If `ref` is not one of the accepted values.

        """
        self.filter_date = datetime.fromisoformat(date)
        self.ref = ref

        if ref not in ["deposited", "revised", "released"]:
            # Implicit string concatenation — the original used commas
            # here, which made `msg` a tuple rather than a string.
            msg = (
                "Invalid reference date. Must be "
                "deposited, revised, or released"
            )
            raise ValueError(msg)

    def filter(self, record: Record) -> bool:
        """Filter a record based on its date.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record's date is on or before the cutoff.

        """
        structure = record.structure

        # Resolve the reference date, falling back to earlier dates
        # when the requested one is missing.
        if self.ref == "deposited":
            date = structure.deposited
        elif self.ref == "released":
            date = structure.released
            if not date:
                date = structure.deposited
        elif self.ref == "revised":
            date = structure.revised
            if not date and structure.released:
                date = structure.released
            elif not date:
                date = structure.deposited

        if date is None or date == "":
            return False

        date = datetime.fromisoformat(date)
        return date <= self.filter_date
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/filter.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ from boltz.data.types import Record
4
+
5
+
6
class DynamicFilter(ABC):
    """Abstract base class for record-level (dynamic) data filters."""

    @abstractmethod
    def filter(self, record: Record) -> bool:
        """Decide whether a data record passes this filter.

        Parameters
        ----------
        record : Record
            The object to consider filtering in / out.

        Returns
        -------
        bool
            True if the data passes the filter, False otherwise.

        """
        raise NotImplementedError
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/max_residues.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class MaxResiduesFilter(DynamicFilter):
    """A filter that filters records by their total residue count."""

    def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_residues : int
            The minimum number of residues allowed.
        max_residues : int
            The maximum number of residues allowed.

        """
        self.min_residues = min_residues
        self.max_residues = max_residues

    def filter(self, record: Record) -> bool:
        """Filter records based on their total residue count.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the summed residue count over all chains lies within
            [min_residues, max_residues].

        """
        num_residues = sum(chain.num_residues for chain in record.chains)
        return self.min_residues <= num_residues <= self.max_residues
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/resolution.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class ResolutionFilter(DynamicFilter):
    """A filter that filters complexes based on their resolution."""

    def __init__(self, resolution: float = 9.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        resolution : float, optional
            The maximum allowed resolution (in Angstroms, lower is better).

        """
        self.resolution = resolution

    def filter(self, record: Record) -> bool:
        """Filter complexes based on their resolution.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the structure's resolution is within the allowed maximum.

        """
        return record.structure.resolution <= self.resolution
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/size.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from boltz.data.types import Record
2
+ from boltz.data.filter.dynamic.filter import DynamicFilter
3
+
4
+
5
class SizeFilter(DynamicFilter):
    """A filter that filters structures based on their size."""

    def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_chains : int
            The minimum number of valid chains required.
        max_chains : int
            The maximum number of total chains allowed.

        """
        self.min_chains = min_chains
        self.max_chains = max_chains

    def filter(self, record: Record) -> bool:
        """Filter structures based on their chain counts.

        The total chain count must not exceed ``max_chains``, and at
        least ``min_chains`` of the record's chains must be valid.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True if the record passes the size constraints.

        """
        num_chains = record.structure.num_chains
        num_valid = sum(1 for chain in record.chains if chain.valid)
        return num_chains <= self.max_chains and num_valid >= self.min_chains
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/subset.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from boltz.data.types import Record
4
+ from boltz.data.filter.dynamic.filter import DynamicFilter
5
+
6
+
7
class SubsetFilter(DynamicFilter):
    """Filter a data record based on a subset of the data."""

    def __init__(self, subset: str, reverse: bool = False) -> None:
        """Initialize the filter.

        Parameters
        ----------
        subset : str
            Path to a file listing the subset of record ids to keep,
            one per line.
        reverse : bool, optional
            If True, keep records NOT in the subset instead.

        """
        # Read the id list once; comparisons are case-insensitive.
        with Path(subset).open("r") as f:
            entries = f.read().splitlines()

        self.subset = {s.lower() for s in entries}
        self.reverse = reverse

    def filter(self, record: Record) -> bool:
        """Filter a data record by subset membership.

        Parameters
        ----------
        record : Record
            The object to consider filtering in / out.

        Returns
        -------
        bool
            True if the data passes the filter, False otherwise.

        """
        if self.reverse:
            return record.id.lower() not in self.subset
        else:  # noqa: RET505
            return record.id.lower() in self.subset
protify/FastPLMs/boltz/src/boltz/data/filter/static/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/filter/static/filter.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+
5
+ from boltz.data.types import Structure
6
+
7
+
8
class StaticFilter(ABC):
    """Abstract base class for chain-level (static) structure filters."""

    @abstractmethod
    def filter(self, structure: Structure) -> np.ndarray:
        """Filter chains in a structure.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        raise NotImplementedError
protify/FastPLMs/boltz/src/boltz/data/filter/static/ligand.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from boltz.data import const
4
+ from boltz.data.filter.static.filter import StaticFilter
5
+ from boltz.data.types import Structure
6
+
7
+
8
class ExcludedLigands(StaticFilter):
    """Filter out non-polymer chains containing excluded ligands."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Mark non-polymer chains with excluded ligands as invalid.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        for idx, chain in enumerate(structure.chains):
            # Only non-polymer (ligand) chains are candidates.
            if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]:
                continue

            first = chain["res_idx"]
            residues = structure.residues[first : first + chain["res_num"]]
            if any(res["name"] in const.ligand_exclusion for res in residues):
                keep[idx] = 0

        return keep
protify/FastPLMs/boltz/src/boltz/data/filter/static/polymer.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ from sklearn.neighbors import KDTree
6
+
7
+ from boltz.data import const
8
+ from boltz.data.filter.static.filter import StaticFilter
9
+ from boltz.data.types import Structure
10
+
11
+
12
class MinimumLengthFilter(StaticFilter):
    """Filter polymers based on their length.

    Both bounds are checked against the number of *resolved* residues.
    (NOTE(review): an earlier docstring claimed the maximum used the full
    sequence length, but the code has always compared the resolved count
    against both bounds — the doc now matches the code.)

    """

    def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_len : int, optional
            The minimum allowed number of resolved residues.
        max_len : int, optional
            The maximum allowed number of resolved residues.

        """
        self._min = min_len
        self._max = max_len

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter chains based on their length.

        Non-polymer chains always pass.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        for i, chain in enumerate(structure.chains):
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                continue

            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]
            resolved = residues["is_present"].sum()

            if (resolved < self._min) or (resolved > self._max):
                valid[i] = 0

        return valid
63
+
64
+
65
class UnknownFilter(StaticFilter):
    """Filter out polymer chains whose residues are all unknown."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Invalidate chains consisting entirely of unknown residues.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        # Unknown-residue token id per polymer type.
        unk_toks = {
            const.chain_type_ids["PROTEIN"]: const.unk_token_ids["PROTEIN"],
            const.chain_type_ids["DNA"]: const.unk_token_ids["DNA"],
            const.chain_type_ids["RNA"]: const.unk_token_ids["RNA"],
        }

        keep = np.ones(len(structure.chains), dtype=bool)
        for idx, chain in enumerate(structure.chains):
            # Ligand chains have no notion of unknown residues here.
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                continue

            first = chain["res_idx"]
            residues = structure.residues[first : first + chain["res_num"]]

            if np.all(residues["res_type"] == unk_toks[chain["mol_type"]]):
                keep[idx] = 0

        return keep
102
+
103
+
104
class ConsecutiveCA(StaticFilter):
    """Filter proteins with consecutive CA atoms above a threshold."""

    def __init__(self, max_dist: float = 10.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        max_dist : float, optional
            The maximum allowed distance between consecutive C-alphas.
            (Annotation fixed: the default is a float, not an int.)

        """
        self._max_dist = max_dist

    def filter(self, structure: Structure) -> np.ndarray:
        """Invalidate protein chains with a too-large CA-CA gap.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        # Remove chain if consecutive CA atoms are above threshold
        for i, chain in enumerate(structure.chains):
            # Skip non-protein chains
            if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
                continue

            # Get residues
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]

            # Get c-alphas (the per-residue center atom)
            ca_ids = residues["atom_center"]
            ca_atoms = structure.atoms[ca_ids]

            res_valid = residues["is_present"]
            ca_valid = ca_atoms["is_present"] & res_valid
            ca_coords = ca_atoms["coords"]

            # Distances between consecutive CAs, then keep only pairs where
            # both endpoints are resolved before thresholding.
            dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)
            dist = dist > self._max_dist
            dist = dist[ca_valid[1:] & ca_valid[:-1]]

            # Remove the chain if any valid pair is above threshold
            if np.any(dist):
                valid[i] = 0

        return valid
163
+
164
+
165
+ @dataclass(frozen=True)
166
+ class Clash:
167
+ """A clash between two chains."""
168
+
169
+ chain: int
170
+ other: int
171
+ num_atoms: int
172
+ num_clashes: int
173
+
174
+
175
class ClashingChainsFilter(StaticFilter):
    """A filter that filters clashing chains.

    Clashing chains are defined as those with >30% of atoms within
    1.7 Å of an atom in another chain. If two chains clash with each
    other, the chain with the greater fraction of clashing atoms is
    removed; ties go to the chain with fewer total atoms, and then to
    the chain with the larger chain id.

    """

    def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
        """Initialize the filter.

        Parameters
        ----------
        dist : float, optional
            The maximum distance for a clash.
        freq : float, optional
            The maximum allowed frequency of clashes.

        """
        self._dist = dist
        self._freq = freq

    def filter(self, structure: Structure) -> np.ndarray:  # noqa: PLR0912, C901
        """Filter out clashing chains.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        num_chains = len(structure.chains)
        if num_chains < 2:  # noqa: PLR2004
            return np.ones(num_chains, dtype=bool)

        # Detect clashes over every unordered chain pair.
        clashes: list[Clash] = []
        for i, j in itertools.combinations(range(num_chains), 2):
            c1 = structure.chains[i]
            c2 = structure.chains[j]

            # Slice out each chain's present atoms.
            c1_start = c1["atom_idx"]
            c2_start = c2["atom_idx"]
            atoms1 = structure.atoms[c1_start : c1_start + c1["atom_num"]]
            atoms2 = structure.atoms[c2_start : c2_start + c2["atom_num"]]
            atoms1 = atoms1[atoms1["is_present"]]
            atoms2 = atoms2[atoms2["is_present"]]

            # Nothing to compare if either chain is fully unresolved.
            if len(atoms1) == 0 or len(atoms2) == 0:
                continue

            # Count atoms within clash distance using a KD-tree radius
            # query rather than a full distance matrix.
            tree = KDTree(atoms1["coords"], metric="euclidean")
            query = tree.query_radius(atoms2["coords"], self._dist)

            c2_clashes = sum(len(neighbors) > 0 for neighbors in query)
            c1_clashes = len(set(itertools.chain.from_iterable(query)))

            # Record a Clash for each chain that exceeds the frequency.
            if (c1_clashes / len(atoms1)) > self._freq:
                clashes.append(Clash(i, j, len(atoms1), c1_clashes))
            if (c2_clashes / len(atoms2)) > self._freq:
                clashes.append(Clash(j, i, len(atoms2), c2_clashes))

        # Resolve which chain of each clashing pair to drop.
        removed = set()
        ids_to_clash = {(c.chain, c.other): c for c in clashes}

        for clash in clashes:
            # Skip pairs already settled by an earlier decision.
            if clash.chain in removed or clash.other in removed:
                continue

            mutual = ids_to_clash.get((clash.other, clash.chain))
            if mutual is not None:
                # Mutual clash: drop by frequency, then size, then id.
                freq_a = clash.num_clashes / clash.num_atoms
                freq_b = mutual.num_clashes / mutual.num_atoms
                if freq_a > freq_b:
                    removed.add(clash.chain)
                elif freq_a < freq_b:
                    removed.add(clash.other)
                elif clash.num_atoms < mutual.num_atoms:
                    removed.add(clash.chain)
                elif clash.num_atoms > mutual.num_atoms:
                    removed.add(clash.other)
                else:
                    removed.add(max(clash.chain, clash.other))
            else:
                # One-sided clash: drop the offending chain.
                removed.add(clash.chain)

        # Build the boolean keep-mask.
        valid = np.ones(len(structure.chains), dtype=bool)
        for i in removed:
            valid[i] = 0

        return valid
protify/FastPLMs/boltz/src/boltz/data/module/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/module/inference.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.utils.data import DataLoader
9
+
10
+ from boltz.data import const
11
+ from boltz.data.feature.featurizer import BoltzFeaturizer
12
+ from boltz.data.pad import pad_to_max
13
+ from boltz.data.tokenize.boltz import BoltzTokenizer
14
+ from boltz.data.types import (
15
+ MSA,
16
+ Connection,
17
+ Input,
18
+ Manifest,
19
+ Record,
20
+ ResidueConstraints,
21
+ Structure,
22
+ )
23
+
24
+
25
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
) -> Input:
    """Load the structure, MSAs and optional constraints for one record.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.
    constraints_dir : Optional[Path]
        The path to the constraints directory, if any.

    Returns
    -------
    Input
        The loaded input.

    """
    # Read the raw structure arrays and wrap them in a Structure
    arrays = np.load(target_dir / f"{record.id}.npz")
    structure = Structure(
        atoms=arrays["atoms"],
        bonds=arrays["bonds"],
        residues=arrays["residues"],
        chains=arrays["chains"],
        connections=arrays["connections"].astype(Connection),
        interfaces=arrays["interfaces"],
        mask=arrays["mask"],
    )

    # Gather the MSA for every chain that has one (-1 marks "no MSA")
    msas = {
        chain.chain_id: MSA(**np.load(msa_dir / f"{chain.msa_id}.npz"))
        for chain in record.chains
        if chain.msa_id != -1
    }

    # Residue constraints are optional
    if constraints_dir is None:
        residue_constraints = None
    else:
        constraint_path = constraints_dir / f"{record.id}.npz"
        residue_constraints = ResidueConstraints.load(constraint_path)

    return Input(structure, msas, record, residue_constraints)
75
+
76
+
77
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate a list of per-sample feature dicts into a single batch.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    Dict[str, Tensor]
        The collated data.

    """
    # Keys whose values are kept as plain python lists rather than being
    # stacked into a batch tensor.
    keep_as_list = {
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
    }

    batch = {}
    for key in data[0]:
        values = [item[key] for item in data]
        if key not in keep_as_list:
            # Stack directly when shapes agree, otherwise pad to a
            # common shape first.
            first_shape = values[0].shape
            if all(v.shape == first_shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                values, _ = pad_to_max(values, 0)
        batch[key] = values

    return batch
119
+
120
+
121
class PredictionDataset(torch.utils.data.Dataset):
    """Prediction dataset for Boltz inference.

    Yields featurized model inputs, one manifest record per item. When a
    record cannot be processed, the error is logged and the first record
    is substituted so a single bad input does not abort the whole run.
    """

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        constraints_dir : Optional[Path]
            The path to the constraints directory, if any.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir
        self.tokenizer = BoltzTokenizer()
        self.featurizer = BoltzFeaturizer()

    def _skip_to_first(self, idx: int, error: Exception) -> dict:
        """Fall back to item 0 after a processing failure.

        The previous implementation recursed into ``__getitem__(0)``
        unconditionally, which loops until RecursionError when record 0
        is itself broken. Raise an explicit error in that case instead.

        Parameters
        ----------
        idx : int
            The index that failed to process.
        error : Exception
            The original failure, chained into the raised error.

        Returns
        -------
        dict
            The features of item 0.

        Raises
        ------
        RuntimeError
            If item 0 itself is the item that failed.

        """
        if idx == 0:
            msg = "Fallback record 0 could not be processed."
            raise RuntimeError(msg) from error
        return self.__getitem__(0)

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features.

        """
        # Get a sample from the dataset
        record = self.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(
                record,
                self.target_dir,
                self.msa_dir,
                self.constraints_dir,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")  # noqa: T201
            return self._skip_to_first(idx, e)

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self._skip_to_first(idx, e)

        # Inference specific options: only the first (binder, pocket)
        # constraint pair is used, when present.
        options = record.inference_options
        if options is None or len(options.pocket_constraints) == 0:
            binder, pocket = None, None
        else:
            binder, pocket = (
                options.pocket_constraints[0][0],
                options.pocket_constraints[0][1],
            )

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                symmetries={},
                compute_symmetries=False,
                inference_binder=binder,
                inference_pocket=pocket,
                compute_constraint_features=True,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self._skip_to_first(idx, e)

        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return len(self.manifest.records)
224
+
225
+
226
class BoltzInferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        num_workers : int
            The number of dataloader workers.
        constraints_dir : Optional[Path]
            The path to the constraints directory, if any.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir
        self.num_workers = num_workers

    def predict_dataloader(self) -> DataLoader:
        """Build the prediction dataloader.

        Returns
        -------
        DataLoader
            A one-record-per-batch loader over the manifest, in order.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
        )
        loader_kwargs = {
            "batch_size": 1,
            "num_workers": self.num_workers,
            "pin_memory": True,
            "shuffle": False,
            "collate_fn": collate,
        }
        return DataLoader(dataset, **loader_kwargs)

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        dict
            The batch with tensor entries moved to the device.

        """
        # These entries are collated as plain lists (see `collate`) and
        # therefore cannot be moved with a single `.to(device)` call.
        host_only = {
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
        }
        for key, value in batch.items():
            if key not in host_only:
                batch[key] = value.to(device)
        return batch
protify/FastPLMs/boltz/src/boltz/data/module/inferencev2.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.utils.data import DataLoader
10
+
11
+ from boltz.data import const
12
+ from boltz.data.crop.affinity import AffinityCropper
13
+ from boltz.data.feature.featurizerv2 import Boltz2Featurizer
14
+ from boltz.data.mol import load_canonicals, load_molecules
15
+ from boltz.data.pad import pad_to_max
16
+ from boltz.data.tokenize.boltz2 import Boltz2Tokenizer
17
+ from boltz.data.types import (
18
+ MSA,
19
+ Input,
20
+ Manifest,
21
+ Record,
22
+ ResidueConstraints,
23
+ StructureV2,
24
+ )
25
+
26
+
27
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
    template_dir: Optional[Path] = None,
    extra_mols_dir: Optional[Path] = None,
    affinity: bool = False,
) -> Input:
    """Load the structure, MSAs, templates and constraints for one record.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.
    constraints_dir : Optional[Path]
        The path to the constraints directory.
    template_dir : Optional[Path]
        The path to the template directory.
    extra_mols_dir : Optional[Path]
        The path to the extra molecules directory.
    affinity : bool
        Whether to load the affinity data.

    Returns
    -------
    Input
        The loaded input.

    """
    # The affinity workflow reads a pre-processed structure from a
    # per-record sub-directory; the default workflow reads it directly.
    if affinity:
        structure_path = target_dir / record.id / f"pre_affinity_{record.id}.npz"
    else:
        structure_path = target_dir / f"{record.id}.npz"
    structure = StructureV2.load(structure_path)

    # One MSA per chain; -1 marks chains without an MSA
    msas = {
        chain.chain_id: MSA.load(msa_dir / f"{chain.msa_id}.npz")
        for chain in record.chains
        if chain.msa_id != -1
    }

    # Templates are optional and keyed by template name
    templates = None
    if record.templates and template_dir is not None:
        templates = {}
        for info in record.templates:
            template_path = template_dir / f"{record.id}_{info.name}.npz"
            templates[info.name] = StructureV2.load(template_path)

    # Residue constraints are optional
    residue_constraints = None
    if constraints_dir is not None:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    # Extra molecules are optional and may be absent for this record
    extra_mols = {}
    if extra_mols_dir is not None:
        extra_mol_path = extra_mols_dir / f"{record.id}.pkl"
        if extra_mol_path.exists():
            with extra_mol_path.open("rb") as handle:
                extra_mols = pickle.load(handle)  # noqa: S301

    return Input(
        structure,
        msas,
        record=record,
        residue_constraints=residue_constraints,
        templates=templates,
        extra_mols=extra_mols,
    )
110
+
111
+
112
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate a list of per-sample feature dicts into a single batch.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    Dict[str, Tensor]
        The collated data.

    """
    # Keys whose values are kept as plain python lists rather than being
    # stacked into a batch tensor.
    keep_as_list = {
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
        "affinity_mw",
    }

    batch = {}
    for key in data[0]:
        values = [item[key] for item in data]
        if key not in keep_as_list:
            # Stack directly when shapes agree, otherwise pad to a
            # common shape first.
            first_shape = values[0].shape
            if all(v.shape == first_shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                values, _ = pad_to_max(values, 0)
        batch[key] = values

    return batch
155
+
156
+
157
class PredictionDataset(torch.utils.data.Dataset):
    """Prediction dataset for Boltz2 inference.

    Produces one featurized input per manifest record. Records that fail
    a processing stage are logged and replaced by record 0.
    """

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the moldir.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            Method override forwarded to the featurizer — TODO confirm
            accepted values against Boltz2Featurizer.process.
        affinity : bool
            Whether to run the affinity workflow (adds cropping).

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.tokenizer = Boltz2Tokenizer()
        self.featurizer = Boltz2Featurizer()
        # Canonical CCD molecules are loaded once and shared by all items
        self.canonicals = load_canonicals(self.mol_dir)
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity
        # The cropper is only needed (and only constructed) for affinity
        if self.affinity:
            self.cropper = AffinityCropper()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features.

        """
        # Get record
        record = self.manifest.records[idx]

        # Load the input data.
        # NOTE(review): unlike the stages below, this load is not wrapped
        # in a try/except, so a missing or corrupt file raises here.
        input_data = load_input(
            record=record,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            affinity=self.affinity,
        )

        # Tokenize structure.
        # NOTE(review): every failure path below falls back to item 0 by
        # recursion; if record 0 itself fails this recurses forever.
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(  # noqa: T201
                f"Tokenizer failed on {record.id} with error {e}. Skipping."
            )
            return self.__getitem__(0)

        # For affinity, crop the tokenized structure to a fixed budget
        if self.affinity:
            try:
                tokenized = self.cropper.crop(
                    tokenized,
                    max_tokens=256,
                    max_atoms=2048,
                )
            except Exception as e:  # noqa: BLE001
                print(f"Cropper failed on {record.id} with error {e}. Skipping.")  # noqa: T201
                return self.__getitem__(0)

        # Collect conformers: canonical molecules and per-record extras
        # first, then fetch any residue names still missing from mol_dir.
        try:
            molecules = {}
            molecules.update(self.canonicals)
            molecules.update(input_data.extra_mols)
            mol_names = set(tokenized.tokens["res_name"].tolist())
            mol_names = mol_names - set(molecules.keys())
            molecules.update(load_molecules(self.mol_dir, mol_names))
        except Exception as e:  # noqa: BLE001
            print(f"Molecule loading failed for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Inference specific options
        options = record.inference_options
        if options is None:
            pocket_constraints, contact_constraints = None, None
        else:
            pocket_constraints, contact_constraints = (
                options.pocket_constraints,
                options.contact_constraints,
            )

        # Fixed seed so featurization is deterministic across runs
        seed = 42
        random = np.random.default_rng(seed)

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                molecules=molecules,
                random=random,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                single_sequence_prop=0.0,
                compute_frames=True,
                inference_pocket_constraints=pocket_constraints,
                inference_contact_constraints=contact_constraints,
                compute_constraint_features=True,
                override_method=self.override_method,
                compute_affinity=self.affinity,
            )
        except Exception as e:  # noqa: BLE001
            import traceback

            traceback.print_exc()
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Attach the record so downstream writers know what was predicted
        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return len(self.manifest.records)
315
+
316
+
317
class Boltz2InferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz2 inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the moldir.
        num_workers : int
            The number of workers to use.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            The method to override.
        affinity : bool
            Whether to run the affinity workflow.

        """
        super().__init__()
        self.num_workers = num_workers
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity

    def predict_dataloader(self) -> DataLoader:
        """Build the prediction dataloader.

        Returns
        -------
        DataLoader
            A one-record-per-batch loader over the manifest, in order.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            mol_dir=self.mol_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            override_method=self.override_method,
            affinity=self.affinity,
        )
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        dict
            The batch with tensor entries moved to the device.

        """
        # Entries collated as plain lists (see `collate`) stay on the host
        for key in batch:
            if key not in [
                "all_coords",
                "all_resolved_mask",
                "crop_to_all_atom_map",
                "chain_symmetries",
                "amino_acids_symmetries",
                "ligand_symmetries",
                "record",
                "affinity_mw",
            ]:
                batch[key] = batch[key].to(device)
        return batch
protify/FastPLMs/boltz/src/boltz/data/module/training.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.utils.data import DataLoader
10
+
11
+ from boltz.data.crop.cropper import Cropper
12
+ from boltz.data.feature.featurizer import BoltzFeaturizer
13
+ from boltz.data.feature.symmetry import get_symmetries
14
+ from boltz.data.filter.dynamic.filter import DynamicFilter
15
+ from boltz.data.pad import pad_to_max
16
+ from boltz.data.sample.sampler import Sample, Sampler
17
+ from boltz.data.tokenize.tokenizer import Tokenizer
18
+ from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure
19
+
20
+
21
@dataclass
class DatasetConfig:
    """Dataset configuration."""

    # Directory with the processed structures for this dataset
    target_dir: str
    # Directory with the processed MSAs
    msa_dir: str
    # Sampling probability of this dataset relative to the others
    prob: float
    # Strategy that yields training samples from the manifest records
    sampler: Sampler
    # Strategy that crops a tokenized structure to the size budget
    cropper: Cropper
    # Optional dataset-specific dynamic filters
    filters: Optional[list] = None
    # Optional split name — presumably train/val; TODO confirm usage
    split: Optional[str] = None
    # Optional explicit manifest path overriding the default location
    manifest_path: Optional[str] = None
33
+
34
+
35
@dataclass
class DataConfig:
    """Data configuration."""

    # Per-dataset configurations and globally applied dynamic filters
    datasets: list[DatasetConfig]
    filters: list[DynamicFilter]
    # Featurization / tokenization components shared by all datasets
    featurizer: BoltzFeaturizer
    tokenizer: Tokenizer
    # Size budgets for a training example
    max_atoms: int
    max_tokens: int
    max_seqs: int
    # Epoch / dataloader settings
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    random_seed: int
    pin_memory: bool
    # Path to precomputed symmetry data — TODO confirm file format
    symmetries: str
    # Atom-attention windowing and distance binning parameters,
    # forwarded to the featurizer
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    # If set, restrict each dataset to its first `overfit` records
    overfit: Optional[int] = None
    # Padding behaviour for tokens / atoms / MSA sequences
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    # Whether validation samples are cropped like training samples
    crop_validation: bool = False
    # Whether symmetry features are returned for train / val examples
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    # Pocket-conditioning sampling parameters for train / val
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    val_batch_size: int = 1
    # Whether constraint features are computed by the featurizer
    compute_constraint_features: bool = False
69
+
70
+
71
@dataclass
class Dataset:
    """Data holder."""

    # Directory containing the processed structures
    target_dir: Path
    # Directory containing the processed MSAs
    msa_dir: Path
    # Manifest of the records in this dataset
    manifest: Manifest
    # Sampling probability relative to the other datasets
    prob: float
    # Sampling / cropping strategies for this dataset
    sampler: Sampler
    cropper: Cropper
    # Shared tokenizer / featurizer instances
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer
83
+
84
+
85
def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure arrays from the per-record npz file
    structure = np.load(target_dir / "structures" / f"{record.id}.npz")

    # Backward compatibility: older processed files lack the per-chain
    # "cyclic_period" field, so rebuild the structured array with it.
    chains = structure["chains"]
    if "cyclic_period" not in chains.dtype.names:
        # Extend the structured dtype with the missing int32 field
        new_dtype = chains.dtype.descr + [("cyclic_period", "i4")]
        new_chains = np.empty(chains.shape, dtype=new_dtype)
        # Copy over existing fields
        for name in chains.dtype.names:
            new_chains[name] = chains[name]
        # Default of 0 — presumably means "not cyclic"; TODO confirm
        new_chains["cyclic_period"] = 0
        # Replace old chains array with new one
        chains = new_chains

    structure = Structure(
        atoms=structure["atoms"],
        bonds=structure["bonds"],
        residues=structure["residues"],
        chains=chains,  # chains var accounting for missing cyclic_period
        connections=structure["connections"].astype(Connection),
        interfaces=structure["interfaces"],
        mask=structure["mask"],
    )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any (-1 or "" marks "no MSA")
        if msa_id != -1 and msa_id != "":
            msa = np.load(msa_dir / f"{msa_id}.npz")
            msas[chain.chain_id] = MSA(**msa)

    return Input(structure, msas)
142
+
143
+
144
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate a list of per-sample feature dicts into a single batch.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # Keys whose values are kept as plain python lists rather than being
    # stacked into a batch tensor.
    keep_as_list = {
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
    }

    batch = {}
    for key in data[0]:
        values = [item[key] for item in data]
        if key not in keep_as_list:
            # Stack directly when shapes agree, otherwise pad to a
            # common shape first.
            first_shape = values[0].shape
            if all(v.shape == first_shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                values, _ = pad_to_max(values, 0)
        batch[key] = values

    return batch
185
+
186
+
187
+ class TrainingDataset(torch.utils.data.Dataset):
188
+ """Base iterable dataset."""
189
+
190
+ def __init__(
191
+ self,
192
+ datasets: list[Dataset],
193
+ samples_per_epoch: int,
194
+ symmetries: dict,
195
+ max_atoms: int,
196
+ max_tokens: int,
197
+ max_seqs: int,
198
+ pad_to_max_atoms: bool = False,
199
+ pad_to_max_tokens: bool = False,
200
+ pad_to_max_seqs: bool = False,
201
+ atoms_per_window_queries: int = 32,
202
+ min_dist: float = 2.0,
203
+ max_dist: float = 22.0,
204
+ num_bins: int = 64,
205
+ overfit: Optional[int] = None,
206
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
207
+ binder_pocket_cutoff: Optional[float] = 6.0,
208
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
209
+ return_symmetries: Optional[bool] = False,
210
+ compute_constraint_features: bool = False,
211
+ ) -> None:
212
+ """Initialize the training dataset."""
213
+ super().__init__()
214
+ self.datasets = datasets
215
+ self.probs = [d.prob for d in datasets]
216
+ self.samples_per_epoch = samples_per_epoch
217
+ self.symmetries = symmetries
218
+ self.max_tokens = max_tokens
219
+ self.max_seqs = max_seqs
220
+ self.max_atoms = max_atoms
221
+ self.pad_to_max_tokens = pad_to_max_tokens
222
+ self.pad_to_max_atoms = pad_to_max_atoms
223
+ self.pad_to_max_seqs = pad_to_max_seqs
224
+ self.atoms_per_window_queries = atoms_per_window_queries
225
+ self.min_dist = min_dist
226
+ self.max_dist = max_dist
227
+ self.num_bins = num_bins
228
+ self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
229
+ self.binder_pocket_cutoff = binder_pocket_cutoff
230
+ self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
231
+ self.return_symmetries = return_symmetries
232
+ self.compute_constraint_features = compute_constraint_features
233
+ self.samples = []
234
+ for dataset in datasets:
235
+ records = dataset.manifest.records
236
+ if overfit is not None:
237
+ records = records[:overfit]
238
+ iterator = dataset.sampler.sample(records, np.random)
239
+ self.samples.append(iterator)
240
+
241
+ def __getitem__(self, idx: int) -> dict[str, Tensor]:
242
+ """Get an item from the dataset.
243
+
244
+ Parameters
245
+ ----------
246
+ idx : int
247
+ The data index.
248
+
249
+ Returns
250
+ -------
251
+ dict[str, Tensor]
252
+ The sampled data features.
253
+
254
+ """
255
+ # Pick a random dataset
256
+ dataset_idx = np.random.choice(
257
+ len(self.datasets),
258
+ p=self.probs,
259
+ )
260
+ dataset = self.datasets[dataset_idx]
261
+
262
+ # Get a sample from the dataset
263
+ sample: Sample = next(self.samples[dataset_idx])
264
+
265
+ # Get the structure
266
+ try:
267
+ input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
268
+ except Exception as e:
269
+ print(
270
+ f"Failed to load input for {sample.record.id} with error {e}. Skipping."
271
+ )
272
+ return self.__getitem__(idx)
273
+
274
+ # Tokenize structure
275
+ try:
276
+ tokenized = dataset.tokenizer.tokenize(input_data)
277
+ except Exception as e:
278
+ print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
279
+ return self.__getitem__(idx)
280
+
281
+ # Compute crop
282
+ try:
283
+ if self.max_tokens is not None:
284
+ tokenized = dataset.cropper.crop(
285
+ tokenized,
286
+ max_atoms=self.max_atoms,
287
+ max_tokens=self.max_tokens,
288
+ random=np.random,
289
+ chain_id=sample.chain_id,
290
+ interface_id=sample.interface_id,
291
+ )
292
+ except Exception as e:
293
+ print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
294
+ return self.__getitem__(idx)
295
+
296
+ # Check if there are tokens
297
+ if len(tokenized.tokens) == 0:
298
+ msg = "No tokens in cropped structure."
299
+ raise ValueError(msg)
300
+
301
+ # Compute features
302
+ try:
303
+ features = dataset.featurizer.process(
304
+ tokenized,
305
+ training=True,
306
+ max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
307
+ max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
308
+ max_seqs=self.max_seqs,
309
+ pad_to_max_seqs=self.pad_to_max_seqs,
310
+ symmetries=self.symmetries,
311
+ atoms_per_window_queries=self.atoms_per_window_queries,
312
+ min_dist=self.min_dist,
313
+ max_dist=self.max_dist,
314
+ num_bins=self.num_bins,
315
+ compute_symmetries=self.return_symmetries,
316
+ binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
317
+ binder_pocket_cutoff=self.binder_pocket_cutoff,
318
+ binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
319
+ compute_constraint_features=self.compute_constraint_features,
320
+ )
321
+ except Exception as e:
322
+ print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
323
+ return self.__getitem__(idx)
324
+
325
+ return features
326
+
327
    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        # The dataset is virtual: its length is the configured number of
        # samples drawn per epoch, not the number of underlying records.
        return self.samples_per_epoch
337
+
338
+
339
class ValidationDataset(torch.utils.data.Dataset):
    """Map-style dataset over the concatenated validation records.

    Indexes sequentially into the records of each sub-dataset, in the
    order the sub-datasets were given.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the validation dataset."""
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        # NOTE(review): truthiness check — overfit=0 selects the seeded RNG
        # like overfit=None, unlike the `is not None` checks used elsewhere.
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.compute_constraint_features = compute_constraint_features

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick dataset based on idx: walk the sub-datasets, subtracting each
        # one's (possibly overfit-truncated) size until idx falls inside one.
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure.
        # NOTE(review): all failure paths below fall back to item 0; if
        # item 0 itself is broken this recurses until RecursionError.
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop (validation is only cropped when explicitly enabled)
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens — unlike the other failures this one is
        # deliberately fatal rather than retried.
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            # Only pad when cropping, otherwise validation keeps native sizes.
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        # Sum of per-dataset record counts, each capped by `overfit` if set.
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length
490
+
491
+
492
class BoltzTrainingDataModule(pl.LightningDataModule):
    """DataModule for boltz.

    Loads each configured dataset's manifest, splits/filters its records,
    and wraps the results in a TrainingDataset / ValidationDataset pair.
    """

    def __init__(self, cfg: DataConfig) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        cfg : DataConfig
            The data configuration.

        """
        super().__init__()
        self.cfg = cfg

        assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."

        # Load symmetries
        symmetries = get_symmetries(cfg.symmetries)

        # Load datasets
        train: list[Dataset] = []
        val: list[Dataset] = []

        for data_config in cfg.datasets:
            # Set target_dir
            target_dir = Path(data_config.target_dir)
            msa_dir = Path(data_config.msa_dir)

            # Load manifest (explicit path wins over the default location)
            if data_config.manifest_path is not None:
                path = Path(data_config.manifest_path)
            else:
                path = target_dir / "manifest.json"
            manifest: Manifest = Manifest.load(path)

            # Split records if given: ids listed in the split file (matched
            # case-insensitively) go to validation, the rest to training.
            if data_config.split is not None:
                with Path(data_config.split).open("r") as f:
                    split = {x.lower() for x in f.read().splitlines()}

                train_records = []
                val_records = []
                for record in manifest.records:
                    if record.id.lower() in split:
                        val_records.append(record)
                    else:
                        train_records.append(record)
            else:
                train_records = manifest.records
                val_records = []

            # Filter training records with the global filters
            train_records = [
                record
                for record in train_records
                if all(f.filter(record) for f in cfg.filters)
            ]
            # Apply dataset-specific filters, if any
            if data_config.filters is not None:
                train_records = [
                    record
                    for record in train_records
                    if all(f.filter(record) for f in data_config.filters)
                ]

            # Create train dataset
            train_manifest = Manifest(train_records)
            train.append(
                Dataset(
                    target_dir,
                    msa_dir,
                    train_manifest,
                    data_config.prob,
                    data_config.sampler,
                    data_config.cropper,
                    cfg.tokenizer,
                    cfg.featurizer,
                )
            )

            # Create validation dataset
            if val_records:
                val_manifest = Manifest(val_records)
                val.append(
                    Dataset(
                        target_dir,
                        msa_dir,
                        val_manifest,
                        data_config.prob,
                        data_config.sampler,
                        data_config.cropper,
                        cfg.tokenizer,
                        cfg.featurizer,
                    )
                )

        # Print dataset sizes
        for dataset in train:
            dataset: Dataset
            print(f"Training dataset size: {len(dataset.manifest.records)}")

        for dataset in val:
            dataset: Dataset
            print(f"Validation dataset size: {len(dataset.manifest.records)}")

        # Create wrapper datasets
        self._train_set = TrainingDataset(
            datasets=train,
            samples_per_epoch=cfg.samples_per_epoch,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
            return_symmetries=cfg.return_train_symmetries,
            compute_constraint_features=cfg.compute_constraint_features,
        )
        # When overfitting, validate on the training datasets themselves.
        self._val_set = ValidationDataset(
            datasets=train if cfg.overfit is not None else val,
            seed=cfg.random_seed,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            crop_validation=cfg.crop_validation,
            return_symmetries=cfg.return_val_symmetries,
            binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            compute_constraint_features=cfg.compute_constraint_features,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'.

        """
        # All loading is done eagerly in __init__, so nothing to do here.
        return

    def train_dataloader(self) -> DataLoader:
        """Get the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.

        """
        # Shuffling is unnecessary: TrainingDataset samples randomly itself.
        return DataLoader(
            self._train_set,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.

        """
        return DataLoader(
            self._val_set,
            batch_size=self.cfg.val_batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )
protify/FastPLMs/boltz/src/boltz/data/module/trainingv2.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.utils.data import DataLoader
10
+
11
+ from boltz.data.crop.cropper import Cropper
12
+ from boltz.data.feature.featurizer import BoltzFeaturizer
13
+ from boltz.data.feature.symmetry import get_symmetries
14
+ from boltz.data.filter.dynamic.filter import DynamicFilter
15
+ from boltz.data.pad import pad_to_max
16
+ from boltz.data.sample.sampler import Sample, Sampler
17
+ from boltz.data.tokenize.tokenizer import Tokenizer
18
+ from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure
19
+
20
+
21
@dataclass
class DatasetConfig:
    """Dataset configuration.

    Describes one source dataset: where its processed targets and MSAs
    live, its relative sampling weight, and how to sample/crop it.
    """

    # Directory with processed structures (and the default manifest location).
    target_dir: str
    # Directory containing per-chain MSA .npz files.
    msa_dir: str
    # Sampling probability of this dataset relative to the others.
    prob: float
    sampler: Sampler
    cropper: Cropper
    # Optional dataset-specific record filters, applied after the global ones.
    filters: Optional[list] = None
    # Optional path to a file of record ids (one per line) forming the val split.
    split: Optional[str] = None
    # Optional explicit manifest path; defaults to <target_dir>/manifest.json.
    manifest_path: Optional[str] = None
33
+
34
+
35
@dataclass
class DataConfig:
    """Data configuration shared across all datasets in the data module."""

    # Per-dataset configurations.
    datasets: list[DatasetConfig]
    # Global record filters applied to every dataset's training records.
    filters: list[DynamicFilter]
    featurizer: BoltzFeaturizer
    tokenizer: Tokenizer
    # Size limits for cropping/padding.
    max_atoms: int
    max_tokens: int
    max_seqs: int
    # Number of samples drawn per training epoch (virtual epoch length).
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    random_seed: int
    pin_memory: bool
    # Path passed to get_symmetries to load symmetry data.
    symmetries: str
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    # If set, restrict each dataset to its first `overfit` records.
    overfit: Optional[int] = None
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    # Whether to crop (and thus optionally pad) validation samples.
    crop_validation: bool = False
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    # Must be 1: the data module asserts this in __init__.
    val_batch_size: int = 1
68
+
69
+
70
@dataclass
class Dataset:
    """Data holder.

    Bundles a resolved (filtered/split) manifest with the directories and
    processing components needed to load and featurize its records.
    """

    target_dir: Path
    msa_dir: Path
    manifest: Manifest
    # Relative sampling probability used by TrainingDataset.
    prob: float
    sampler: Sampler
    cropper: Cropper
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer
82
+
83
+
84
def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Read the serialized structure arrays and wrap them in a Structure.
    arrays = np.load(target_dir / "structures" / f"{record.id}.npz")
    structure = Structure(
        atoms=arrays["atoms"],
        bonds=arrays["bonds"],
        residues=arrays["residues"],
        chains=arrays["chains"],
        connections=arrays["connections"].astype(Connection),
        interfaces=arrays["interfaces"],
        mask=arrays["mask"],
    )

    # Load the MSA for each chain that has one (-1 / "" mark "no MSA").
    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        if msa_id not in (-1, ""):
            msa_arrays = np.load(msa_dir / f"{msa_id}.npz")
            msas[chain.chain_id] = MSA(**msa_arrays)

    return Input(structure, msas)
123
+
124
+
125
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # These keys carry per-sample structures of varying size/shape, so they
    # are passed through as plain lists rather than stacked into a batch.
    passthrough = {
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
    }

    collated = {}
    for key in data[0]:
        values = [item[key] for item in data]

        if key in passthrough:
            collated[key] = values
            continue

        # Stack when shapes agree, otherwise pad to the common maximum.
        first_shape = values[0].shape
        if all(v.shape == first_shape for v in values):
            collated[key] = torch.stack(values, dim=0)
        else:
            padded, _ = pad_to_max(values, 0)
            collated[key] = padded

    return collated
166
+
167
+
168
class TrainingDataset(torch.utils.data.Dataset):
    """Virtual training dataset that samples across multiple sub-datasets.

    Each __getitem__ draws a random sub-dataset (weighted by its `prob`),
    pulls the next sample from that dataset's sampler, then loads, crops,
    and featurizes it. Samples that fail any processing step are skipped
    and a new one is drawn.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        samples_per_epoch: int,
        symmetries: dict,
        max_atoms: int,
        max_tokens: int,
        max_seqs: int,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        return_symmetries: Optional[bool] = False,
    ) -> None:
        """Initialize the training dataset."""
        super().__init__()
        self.datasets = datasets
        # Sampling weights, one per sub-dataset (assumed to sum to 1).
        self.probs = [d.prob for d in datasets]
        self.samples_per_epoch = samples_per_epoch
        self.symmetries = symmetries
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.max_atoms = max_atoms
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
        self.return_symmetries = return_symmetries
        # One infinite sample iterator per sub-dataset.
        self.samples = []
        for dataset in datasets:
            records = dataset.manifest.records
            if overfit is not None:
                records = records[:overfit]
            iterator = dataset.sampler.sample(records, np.random)
            self.samples.append(iterator)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Resample until a draw survives every processing step. The previous
        # implementation recursed (`return self.__getitem__(idx)`) on each
        # failure, which could hit the recursion limit when many consecutive
        # samples fail; a loop performs the same retries without that risk.
        while True:
            # Pick a random dataset
            dataset_idx = np.random.choice(
                len(self.datasets),
                p=self.probs,
            )
            dataset = self.datasets[dataset_idx]

            # Get a sample from the dataset
            sample: Sample = next(self.samples[dataset_idx])

            # Get the structure
            try:
                input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
            except Exception as e:
                print(
                    f"Failed to load input for {sample.record.id} with error {e}. Skipping."
                )
                continue

            # Tokenize structure
            try:
                tokenized = dataset.tokenizer.tokenize(input_data)
            except Exception as e:
                print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
                continue

            # Compute crop
            try:
                if self.max_tokens is not None:
                    tokenized = dataset.cropper.crop(
                        tokenized,
                        max_atoms=self.max_atoms,
                        max_tokens=self.max_tokens,
                        random=np.random,
                        chain_id=sample.chain_id,
                        interface_id=sample.interface_id,
                    )
            except Exception as e:
                print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
                continue

            # Check if there are tokens; this is deliberately fatal
            # rather than retried.
            if len(tokenized.tokens) == 0:
                msg = "No tokens in cropped structure."
                raise ValueError(msg)

            # Compute features
            try:
                features = dataset.featurizer.process(
                    tokenized,
                    training=True,
                    max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
                    max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
                    max_seqs=self.max_seqs,
                    pad_to_max_seqs=self.pad_to_max_seqs,
                    symmetries=self.symmetries,
                    atoms_per_window_queries=self.atoms_per_window_queries,
                    min_dist=self.min_dist,
                    max_dist=self.max_dist,
                    num_bins=self.num_bins,
                    compute_symmetries=self.return_symmetries,
                    binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                    binder_pocket_cutoff=self.binder_pocket_cutoff,
                    binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
                )
            except Exception as e:
                print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
                continue

            return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        # Virtual length: the configured number of samples drawn per epoch.
        return self.samples_per_epoch
315
+
316
+
317
class ValidationDataset(torch.utils.data.Dataset):
    """Map-style dataset over the concatenated validation records.

    Indexes sequentially into the records of each sub-dataset, in the
    order the sub-datasets were given.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
    ) -> None:
        """Initialize the validation dataset."""
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        # NOTE(review): truthiness check — overfit=0 selects the seeded RNG
        # like overfit=None, unlike the `is not None` checks used elsewhere.
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick dataset based on idx: walk the sub-datasets, subtracting each
        # one's (possibly overfit-truncated) size until idx falls inside one.
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure.
        # NOTE(review): all failure paths below fall back to item 0; if
        # item 0 itself is broken this recurses until RecursionError.
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop (validation is only cropped when explicitly enabled)
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens — unlike the other failures this one is
        # deliberately fatal rather than retried.
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            # Only pad when cropping, otherwise validation keeps native sizes.
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        # Sum of per-dataset record counts, each capped by `overfit` if set.
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length
465
+
466
+
467
class BoltzTrainingDataModule(pl.LightningDataModule):
    """DataModule for boltz.

    Loads each configured dataset's manifest, splits/filters its records,
    and wraps the results in a TrainingDataset / ValidationDataset pair.
    """

    def __init__(self, cfg: DataConfig) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        cfg : DataConfig
            The data configuration.

        """
        super().__init__()
        self.cfg = cfg

        assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."

        # Load symmetries
        symmetries = get_symmetries(cfg.symmetries)

        # Load datasets
        train: list[Dataset] = []
        val: list[Dataset] = []

        for data_config in cfg.datasets:
            # Set target_dir
            target_dir = Path(data_config.target_dir)
            msa_dir = Path(data_config.msa_dir)

            # Load manifest (explicit path wins over the default location)
            if data_config.manifest_path is not None:
                path = Path(data_config.manifest_path)
            else:
                path = target_dir / "manifest.json"
            manifest: Manifest = Manifest.load(path)

            # Split records if given: ids listed in the split file (matched
            # case-insensitively) go to validation, the rest to training.
            if data_config.split is not None:
                with Path(data_config.split).open("r") as f:
                    split = {x.lower() for x in f.read().splitlines()}

                train_records = []
                val_records = []
                for record in manifest.records:
                    if record.id.lower() in split:
                        val_records.append(record)
                    else:
                        train_records.append(record)
            else:
                train_records = manifest.records
                val_records = []

            # Filter training records with the global filters
            train_records = [
                record
                for record in train_records
                if all(f.filter(record) for f in cfg.filters)
            ]
            # Apply dataset-specific filters, if any
            if data_config.filters is not None:
                train_records = [
                    record
                    for record in train_records
                    if all(f.filter(record) for f in data_config.filters)
                ]

            # Create train dataset
            train_manifest = Manifest(train_records)
            train.append(
                Dataset(
                    target_dir,
                    msa_dir,
                    train_manifest,
                    data_config.prob,
                    data_config.sampler,
                    data_config.cropper,
                    cfg.tokenizer,
                    cfg.featurizer,
                )
            )

            # Create validation dataset
            if val_records:
                val_manifest = Manifest(val_records)
                val.append(
                    Dataset(
                        target_dir,
                        msa_dir,
                        val_manifest,
                        data_config.prob,
                        data_config.sampler,
                        data_config.cropper,
                        cfg.tokenizer,
                        cfg.featurizer,
                    )
                )

        # Print dataset sizes
        for dataset in train:
            dataset: Dataset
            print(f"Training dataset size: {len(dataset.manifest.records)}")

        for dataset in val:
            dataset: Dataset
            print(f"Validation dataset size: {len(dataset.manifest.records)}")

        # Create wrapper datasets
        self._train_set = TrainingDataset(
            datasets=train,
            samples_per_epoch=cfg.samples_per_epoch,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
            return_symmetries=cfg.return_train_symmetries,
        )
        # When overfitting, validate on the training datasets themselves.
        self._val_set = ValidationDataset(
            datasets=train if cfg.overfit is not None else val,
            seed=cfg.random_seed,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            crop_validation=cfg.crop_validation,
            return_symmetries=cfg.return_val_symmetries,
            binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'.

        """
        # All loading is done eagerly in __init__, so nothing to do here.
        return

    def train_dataloader(self) -> DataLoader:
        """Get the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.

        """
        # Shuffling is unnecessary: TrainingDataset samples randomly itself.
        return DataLoader(
            self._train_set,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.

        """
        return DataLoader(
            self._val_set,
            batch_size=self.cfg.val_batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )
+ )
protify/FastPLMs/boltz/src/boltz/data/mol.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import pickle
3
+ import random
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from rdkit.Chem import Mol
9
+ from tqdm import tqdm
10
+
11
+ from boltz.data import const
12
+ from boltz.data.pad import pad_dim
13
+ from boltz.model.loss.confidence import lddt_dist
14
+
15
+
16
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, Mol]:
    """Load pickled CCD molecules by name from a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.
    molecules : list[str]
        The molecule (CCD component) names to load.

    Returns
    -------
    dict[str, Mol]
        Mapping from molecule name to the unpickled molecule.

    Raises
    ------
    ValueError
        If any requested component has no pickle file on disk.

    """
    base = Path(moldir)
    loaded: dict[str, Mol] = {}
    for name in molecules:
        mol_path = base / f"{name}.pkl"
        if not mol_path.exists():
            raise ValueError(f"CCD component {name} not found!")
        # NOTE: pickles come from the project's own preprocessing pipeline,
        # hence loading is considered trusted here.
        with mol_path.open("rb") as handle:
            loaded[name] = pickle.load(handle)  # noqa: S301
    return loaded
40
+
41
+
42
def load_canonicals(moldir: str) -> dict[str, Mol]:
    """Load the canonical-token molecules from a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The loaded canonical molecules, keyed by name.

    """
    # Canonical residue names are defined centrally in the const module.
    canonical_names = const.canonical_tokens
    return load_molecules(moldir, canonical_names)
57
+
58
+
59
def load_all_molecules(moldir: str) -> dict[str, Mol]:
    """Load every pickled molecule found in a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory; every ``*.pkl`` file in it
        is loaded, keyed by its file stem.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules.

    """
    # Fix: the previous docstring documented a nonexistent `molecules`
    # parameter — this function loads *all* pickles in the directory.
    loaded_mols = {}
    files = list(Path(moldir).glob("*.pkl"))
    for path in tqdm(files, total=len(files), desc="Loading molecules", leave=False):
        # Pickles are produced by the project's own preprocessing pipeline.
        with path.open("rb") as f:
            loaded_mols[path.stem] = pickle.load(f)  # noqa: S301
    return loaded_mols
82
+
83
+
84
def get_symmetries(mols: dict[str, Mol]) -> dict:  # noqa: PLR0912
    """Create a dictionary of ligand symmetry/geometry metadata per molecule.

    Parameters
    ----------
    mols : dict[str, Mol]
        Molecules keyed by CCD name. Each molecule is expected to carry
        hex-encoded pickled properties ("symmetries", and optionally the
        "pb_*", chiral, stereo, and ring index properties) set by the
        preprocessing pipeline.

    Returns
    -------
    dict
        For each molecule name, a 16-tuple of: symmetry permutations,
        atom names, bounds edge index, lower/upper bounds, bond/angle
        masks, chiral index/mask/orientations, stereo-bond
        index/mask/orientations, aromatic 5/6-ring indices, and planar
        double-bond indices. Molecules whose properties fail to decode
        are silently omitted.

    """
    symmetries = {}
    for key, mol in mols.items():
        try:
            # All properties are stored as hex-encoded pickles on the RDKit
            # molecule; decoding is trusted (produced by our own pipeline).
            sym = pickle.loads(bytes.fromhex(mol.GetProp("symmetries")))  # noqa: S301

            # Pairwise distance-bounds metadata; fall back to empty arrays
            # when the molecule was processed without it.
            if mol.HasProp("pb_edge_index"):
                edge_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("pb_edge_index"))
                ).astype(np.int64)  # noqa: S301
                lower_bounds = pickle.loads(
                    bytes.fromhex(mol.GetProp("pb_lower_bounds"))
                )  # noqa: S301
                upper_bounds = pickle.loads(
                    bytes.fromhex(mol.GetProp("pb_upper_bounds"))
                )  # noqa: S301
                bond_mask = pickle.loads(bytes.fromhex(mol.GetProp("pb_bond_mask")))  # noqa: S301
                angle_mask = pickle.loads(bytes.fromhex(mol.GetProp("pb_angle_mask")))  # noqa: S301
            else:
                edge_index = np.empty((2, 0), dtype=np.int64)
                lower_bounds = np.array([], dtype=np.float32)
                upper_bounds = np.array([], dtype=np.float32)
                bond_mask = np.array([], dtype=np.float32)
                angle_mask = np.array([], dtype=np.float32)

            # Chirality checks (4 atoms per chiral center).
            if mol.HasProp("chiral_atom_index"):
                chiral_atom_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("chiral_atom_index"))
                ).astype(np.int64)
                chiral_check_mask = pickle.loads(
                    bytes.fromhex(mol.GetProp("chiral_check_mask"))
                ).astype(np.int64)
                chiral_atom_orientations = pickle.loads(
                    bytes.fromhex(mol.GetProp("chiral_atom_orientations"))
                )
            else:
                chiral_atom_index = np.empty((4, 0), dtype=np.int64)
                chiral_check_mask = np.array([], dtype=bool)
                chiral_atom_orientations = np.array([], dtype=bool)

            # Stereo-bond checks (4 atoms per stereo bond).
            if mol.HasProp("stereo_bond_index"):
                stereo_bond_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("stereo_bond_index"))
                ).astype(np.int64)
                stereo_check_mask = pickle.loads(
                    bytes.fromhex(mol.GetProp("stereo_check_mask"))
                ).astype(np.int64)
                stereo_bond_orientations = pickle.loads(
                    bytes.fromhex(mol.GetProp("stereo_bond_orientations"))
                )
            else:
                stereo_bond_index = np.empty((4, 0), dtype=np.int64)
                stereo_check_mask = np.array([], dtype=bool)
                stereo_bond_orientations = np.array([], dtype=bool)

            # Planarity checks: aromatic rings and double bonds.
            if mol.HasProp("aromatic_5_ring_index"):
                aromatic_5_ring_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("aromatic_5_ring_index"))
                ).astype(np.int64)
            else:
                aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
            if mol.HasProp("aromatic_6_ring_index"):
                aromatic_6_ring_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("aromatic_6_ring_index"))
                ).astype(np.int64)
            else:
                aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
            if mol.HasProp("planar_double_bond_index"):
                planar_double_bond_index = pickle.loads(
                    bytes.fromhex(mol.GetProp("planar_double_bond_index"))
                ).astype(np.int64)
            else:
                planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

            atom_names = [atom.GetProp("name") for atom in mol.GetAtoms()]
            symmetries[key] = (
                sym,
                atom_names,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            )
        except Exception as e:  # noqa: BLE001, PERF203, S110
            # NOTE(review): any decode failure silently drops the molecule
            # from the result — consider at least logging the key.
            pass

    return symmetries
194
+
195
+
196
def compute_symmetry_idx_dictionary(data):
    """Assign start indices to chains/tokens and gather all atom coordinates.

    Mutates each chain (``start_idx`` = global atom offset) and each token
    (``start_idx`` = offset within its chain), and returns the flat list of
    ``[x, y, z]`` coordinates for every atom, in traversal order.
    """
    total_count = 0
    all_coords = []
    for chain in data.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            # Token offset is relative to the start of its chain.
            token.start_idx = total_count - chain.start_idx
            for atom in token.atoms:
                all_coords.append([atom.coords.x, atom.coords.y, atom.coords.z])
            total_count += len(token.atoms)
    return all_coords
209
+
210
+
211
def get_current_idx_list(data):
    """Return global atom indices for every token inside the crop.

    Only atoms of tokens whose chain and token are both flagged ``in_crop``
    are included; indices are ``chain.start_idx + token.start_idx + k``.
    """
    idx = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if not token.in_crop:
                continue
            base = chain.start_idx + token.start_idx
            idx.extend(range(base, base + len(token.atoms)))
    return idx
224
+
225
+
226
def all_different_after_swap(l):  # noqa: E741
    """Return True if every swap in `l` targets a distinct chain.

    Each element is a swap tuple whose last entry is the target chain
    index; the combination is valid only if no target repeats.
    """
    seen = set()
    for swap in l:
        target = swap[-1]
        if target in seen:
            return False
        seen.add(target)
    return True
229
+
230
+
231
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve chain/residue/ligand symmetries to maximize lDDT vs. prediction.

    Greedily picks, for one batch element, the symmetry-equivalent ground-truth
    coordinate assignment (chain swaps first, then per-residue/ligand atom
    swaps) that best matches the predicted coordinates under the lDDT metric.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates; assumes shape (1, num_atoms_padded, 3) —
        TODO confirm against caller.
    feats : dict
        Feature dict holding per-batch "all_coords", "all_resolved_mask",
        "crop_to_all_atom_map", "chain_swaps", "amino_acids_symmetries",
        and "ligand_symmetries".
    index_batch : int
        Which element of the batch to resolve.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        The symmetry-resolved true coordinates (padded to coords.shape[1])
        and the corresponding resolved-atom mask with a leading batch dim.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_swaps"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Predicted pairwise distances over the cropped atoms only.
    dmat_predicted = torch.cdist(
        coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
    )

    # Check best symmetry on chain swap
    best_true_coords = all_coords[:, crop_to_all_atom_map].clone()
    best_true_resolved_mask = all_resolved_mask[crop_to_all_atom_map].clone()
    best_lddt = -1.0
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        # Apply every chain swap of this combination to the full-atom arrays.
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Pair mask: both atoms resolved, excluding self-pairs.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        # Require >3 resolved atoms so the lDDT is meaningful.
        if lddt > best_lddt and torch.sum(true_resolved_mask) > 3:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        best_lddt_improvement = 0.0

        # Collect all atom indices touched by any swap of this residue/ligand,
        # so lDDT is only recomputed on the affected sub-distance matrix.
        indices = set()
        for c in symmetric_amino_or_lig:
            for i, j in c:
                indices.add(i)
        indices = sorted(list(indices))
        indices = torch.from_numpy(np.asarray(indices)).to(true_coords.device).long()
        pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
        sub_dmat_pred = torch.cdist(
            coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
        )

        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            # Masks over (all atoms) x (subset), removing self-pairs within
            # the subset rows.
            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            sub_true_pair_lddt[indices] = (
                sub_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            lddt, total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )
            new_lddt, new_total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )

            lddt_improvement = new_lddt - lddt

            # Keep the single best-improving swap for this residue/ligand.
            if lddt_improvement > best_lddt_improvement:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask
                best_lddt_improvement = lddt_improvement

        # greedily update best coordinates after each amino acid
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recomputing alignment: pad back to the full (padded) atom dimension.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    return true_coords, true_resolved_mask.unsqueeze(0)
362
+
363
+
364
def compute_single_distogram_loss(pred, target, mask):
    """Masked cross-entropy between predicted distogram logits and targets.

    Parameters
    ----------
    pred : torch.Tensor
        Distogram logits; the last dimension indexes distance bins.
    target : torch.Tensor
        Target bin distribution, same shape as ``pred``.
    mask : torch.Tensor
        Pairwise mask over the two trailing token dimensions.

    Returns
    -------
    torch.Tensor
        Scalar mean loss over the batch.
    """
    log_probs = torch.nn.functional.log_softmax(pred, dim=-1)
    errors = -torch.sum(target * log_probs, dim=-1)
    # Small epsilon keeps the division safe when the mask is empty.
    denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
    per_row = torch.sum(errors * mask, dim=-1) / denom[..., None]
    batch_loss = torch.sum(per_row, dim=-1)
    return torch.mean(batch_loss)
377
+
378
+
379
def minimum_lddt_symmetry_dist(
    pred_distogram: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve ligand atom symmetries against the predicted distogram.

    Greedily permutes symmetry-equivalent ligand atoms in the target
    distogram (and coordinates) so the distogram loss against the
    prediction is minimized. Mutates ``feats["disto_target"]`` and
    ``feats["coords"]`` in place for the given batch element and returns
    nothing.

    Note: for now only ligand symmetries are resolved.
    """
    disto_target = feats["disto_target"][index_batch]
    mask = feats["token_disto_mask"][index_batch]
    # Pairwise mask without self-pairs.
    mask = mask[None, :] * mask[:, None]
    mask = mask * (1 - torch.eye(mask.shape[1])).to(disto_target)

    coords = feats["coords"][index_batch]

    ligand_symmetries = feats["ligand_symmetries"][index_batch]
    # Map each atom to its (one-hot encoded) token index.
    atom_to_token_map = feats["atom_to_token"][index_batch].argmax(dim=-1)

    # atom symmetries, resolved greedily without recomputing alignment
    for symmetric_amino_or_lig in ligand_symmetries:
        best_c, best_disto, best_loss_improvement = None, None, 0.0
        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_disto_target = disto_target.clone()
            indices = []

            # fix the distogram by replacing first the columns then the rows
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[:, atom_to_token_map[i]] = disto_temp[
                    :, atom_to_token_map[j]
                ]
                indices.append(atom_to_token_map[i].item())
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[atom_to_token_map[i], :] = disto_temp[
                    atom_to_token_map[j], :
                ]

            indices = (
                torch.from_numpy(np.asarray(indices)).to(disto_target.device).long()
            )

            # Only the swapped token columns affect the loss difference.
            pred_distogram_subset = pred_distogram[:, indices]
            disto_target_subset = disto_target[:, indices]
            new_disto_target_subset = new_disto_target[:, indices]
            mask_subset = mask[:, indices]

            loss = compute_single_distogram_loss(
                pred_distogram_subset, disto_target_subset, mask_subset
            )
            new_loss = compute_single_distogram_loss(
                pred_distogram_subset, new_disto_target_subset, mask_subset
            )
            # Scale by subset size so improvements are comparable across swaps.
            loss_improvement = (loss - new_loss) * len(indices)

            if loss_improvement > best_loss_improvement:
                best_c = c
                best_disto = new_disto_target
                best_loss_improvement = loss_improvement

        # greedily update best coordinates after each ligand
        if best_loss_improvement > 0:
            disto_target = best_disto.clone()
            old_coords = coords.clone()
            # Apply the winning atom swap to the coordinates as well.
            for i, j in best_c:
                coords[:, i] = old_coords[:, j]

    # update features to be used in diffusion and in distogram loss
    feats["disto_target"][index_batch] = disto_target
    feats["coords"][index_batch] = coords
    return
450
+
451
+
452
def compute_all_coords_mask(structure):
    """Gather coordinates, crop mask, and resolved mask for every atom.

    Also assigns ``start_idx`` offsets on each chain (global atom offset)
    and token (offset within its chain), mirroring
    ``compute_symmetry_idx_dictionary``.

    Parameters
    ----------
    structure
        An object with ``chains`` -> ``tokens`` -> ``atoms``; tokens carry
        ``in_crop`` / ``is_present`` flags and atoms carry ``coords``.

    Returns
    -------
    tuple[list, list, list]
        Per-atom ``[x, y, z]`` coords, crop-membership flags, and
        resolved flags, all in traversal order.

    """
    # Fix: removed a dead `if len(a) != len(b): pass` no-op branch and the
    # unused enumerate indices from the original.
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            all_coords.extend(
                [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
            )
            # Both masks are constant per token, repeated per atom.
            all_coords_crop_mask.extend([token.in_crop] * len(token.atoms))
            all_resolved_mask.extend([token.is_present] * len(token.atoms))
            total_count += len(token.atoms)
    return all_coords, all_coords_crop_mask, all_resolved_mask
475
+
476
+
477
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Compute chain-swap symmetry features for a cropped structure.

    Identifies chains with identical entity_id and atom count (which may be
    swapped in the ground truth), enumerates bounded combinations of such
    swaps, and builds atom-index mappings between the crop and the full
    structure.

    Parameters
    ----------
    cropped
        A cropped structure wrapper exposing ``structure`` and ``tokens``.
    max_n_symmetries : int
        Upper bound on the number of swap combinations returned.

    Returns
    -------
    dict
        Features: "all_coords", "all_resolved_mask", "crop_to_all_atom_map",
        "chain_symmetries", "connections_edge_index", "chain_swaps".

    """
    # get all coordinates and resolved mask
    structure = cropped.structure
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx, atom_num = (
            chain["atom_idx"],  # Global index of first atom in the chain
            chain["atom_num"],  # Number of atoms in the chain
        )

        # compute coordinates and resolved mask
        resolved_mask = structure.atoms["is_present"][
            atom_idx : atom_idx + atom_num
        ]  # Whether each atom in the chain is actually resolved

        # ensemble_atom_starts = [structure.ensemble[idx]["atom_coord_idx"] for idx in cropped.ensemble_ref_idxs]
        # coords = np.array(
        #     [structure.coords[ensemble_atom_start + atom_idx: ensemble_atom_start + atom_idx + atom_num]["coords"] for
        #     ensemble_atom_start in ensemble_atom_starts])

        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        # A chain is "in crop" if any cropped token belongs to it.
        in_crop = False
        for token in cropped.tokens:
            if token["asym_id"] == chain["asym_id"]:
                in_crop = True
                break

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(in_crop)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    all_coords = np.concatenate(all_coords, axis=0)
    # Compute backmapping from token to all coords
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        # Translate the token's global atom index into the concatenated
        # per-chain coordinate array built above.
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))
    crop_to_all_atom_map = np.concatenate(crop_to_all_atom_map, axis=0)

    # Compute the connections edge index for covalent bonds
    all_atom_to_crop_map = np.zeros(all_coords.shape[0], dtype=np.int64)
    all_atom_to_crop_map[crop_to_all_atom_map.astype(np.int64)] = np.arange(
        crop_to_all_atom_map.shape[0]
    )
    connections_edge_index = []
    for connection in structure.bonds:
        # Skip intra-residue bonds; only cross-residue/chain links matter.
        if (connection["chain_1"] == connection["chain_2"]) and (
            connection["res_1"] == connection["res_2"]
        ):
            continue
        connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
    if len(connections_edge_index) > 0:
        connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
        connections_edge_index = all_atom_to_crop_map[connections_edge_index]
    else:
        connections_edge_index = np.empty((2, 0))

    # Compute the symmetries between chains
    symmetries = []
    swaps = []
    for i, chain in enumerate(structure.chains):
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]

        if chain_in_crop[i]:
            # Chains of the same entity and same size are interchangeable.
            possible_swaps = []
            for j, chain2 in enumerate(structure.chains):
                start2 = chain_atom_idx[j]
                end2 = start2 + chain_atom_num[j]
                if (
                    chain["entity_id"] == chain2["entity_id"]
                    and end - start == end2 - start2
                ):
                    possible_swaps.append((start, end, start2, end2, i, j))
            swaps.append(possible_swaps)

        # Group chains into symmetry classes by entity and size.
        found = False
        for symmetry_idx, symmetry in enumerate(symmetries):
            j = symmetry[0][0]
            chain2 = structure.chains[j]
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                symmetries[symmetry_idx].append(
                    (i, start, end, chain_in_crop[i], chain["mol_type"])
                )
                found = True
        if not found:
            symmetries.append([(i, start, end, chain_in_crop[i], chain["mol_type"])])

    combinations = itertools.product(*swaps)
    # to avoid combinatorial explosion, bound the number of combinations even considered
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # filter for all chains getting a different assignment
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    if len(combinations) == 0:
        combinations.append([])

    # Drop symmetry classes with no chain in the crop (iterate backwards
    # because we pop by index).
    for i in range(len(symmetries) - 1, -1, -1):
        if not any(chain[3] for chain in symmetries[i]):
            symmetries.pop(i)

    features = {}
    features["all_coords"] = torch.Tensor(all_coords)  # axis=1 with ensemble

    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(crop_to_all_atom_map)
    features["chain_symmetries"] = symmetries
    features["connections_edge_index"] = torch.tensor(connections_edge_index)
    features["chain_swaps"] = combinations

    return features
616
+
617
+
618
def get_amino_acids_symmetries(cropped):
    """Collect atom-swap symmetries for standard residues in the crop.

    For each token whose residue type has reference symmetries in
    ``const.ref_symmetries``, shifts the per-residue atom-pair swaps by
    the token's offset within the crop.

    Returns
    -------
    dict
        ``{"amino_acids_symmetries": swaps}`` where each entry is the
        list of index-pair swap lists for one residue.
    """
    swaps = []
    crop_offset = 0
    for token in cropped.tokens:
        res_name = const.tokens[token["res_type"]]
        symmetries = const.ref_symmetries.get(res_name, [])
        if len(symmetries) > 0:
            residue_swaps = [
                [(i + crop_offset, j + crop_offset) for i, j in sym]
                for sym in symmetries
            ]
            swaps.append(residue_swaps)
        crop_offset += token["atom_num"]

    return {"amino_acids_symmetries": swaps}
636
+
637
+
638
def slice_valid_index(index, ccd_to_valid_id_array, args=None):
    """Remap CCD atom indices to crop-local indices, dropping invalid columns.

    ``ccd_to_valid_id_array`` maps CCD atom positions to crop-local ids,
    with NaN marking atoms absent from the crop; any column of ``index``
    that touches a NaN entry is removed. Optional ``args`` arrays are
    sliced with the same column mask (returned lazily as a generator).
    """
    remapped = ccd_to_valid_id_array[index]
    keep = (~np.isnan(remapped)).all(axis=0)
    remapped = remapped[:, keep]
    if args is None:
        return remapped
    sliced_args = (extra[keep] for extra in args)
    return remapped, sliced_args
646
+
647
+
648
def get_ligand_symmetries(cropped, symmetries, return_physical_metrics=False):
    """Compute symmetry (and optionally geometry-check) features for ligands.

    For every ligand / non-standard residue in the crop, intersects the CCD
    reference symmetries with the atoms actually present, remaps the CCD
    atom indices to crop-local indices, and (optionally) collects bounds,
    chirality, stereo, and planarity index arrays for physical metrics.

    Parameters
    ----------
    cropped
        Cropped structure wrapper exposing ``structure`` and ``tokens``.
    symmetries : dict
        Output of :func:`get_symmetries`: per CCD name, the 16-tuple of
        symmetry and geometry metadata.
    return_physical_metrics : bool
        When True, also return the concatenated geometry-check tensors.

    Returns
    -------
    dict
        Always contains "ligand_symmetries"; with physical metrics, also
        the "ligand_*" edge/chiral/stereo/ring tensors.

    """
    # Compute ligand and non-standard amino-acids symmetries
    structure = cropped.structure

    added_molecules = {}
    index_mols = []
    atom_count = 0

    for token in cropped.tokens:
        # check if molecule is already added by identifying it through asym_id and res_idx
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            # Accumulate atom counts for multi-token molecules.
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # get the molecule type and indices
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        mol_name = structure.residues[residue_idx]["name"]
        atom_idx = structure.residues[residue_idx]["atom_idx"]
        mol_atom_names = structure.atoms[
            atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
        ]["name"]
        # Standard residues are handled by get_amino_acids_symmetries.
        if mol_name not in const.ref_symmetries:
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    # for each molecule, get the symmetries
    molecule_symmetries = []
    all_edge_index = []
    all_lower_bounds, all_upper_bounds = [], []
    all_bond_mask, all_angle_mask = [], []
    all_chiral_atom_index, all_chiral_check_mask, all_chiral_atom_orientations = (
        [],
        [],
        [],
    )
    all_stereo_bond_index, all_stereo_check_mask, all_stereo_bond_orientations = (
        [],
        [],
        [],
    )
    (
        all_aromatic_5_ring_index,
        all_aromatic_6_ring_index,
        all_planar_double_bond_index,
    ) = (
        [],
        [],
        [],
    )
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if not mol_name in symmetries:
            continue
        else:
            swaps = []
            (
                syms_ccd,
                mol_atom_names_ccd,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            ) = symmetries[mol_name]
            # Get indices of mol_atom_names_ccd that are in mol_atom_names
            ccd_to_valid_ids = {
                mol_atom_names_ccd.index(name): i
                for i, name in enumerate(mol_atom_names)
            }
            # NaN marks CCD atoms missing from this structure; used by
            # slice_valid_index to drop geometry checks touching them.
            ccd_to_valid_id_array = np.array(
                [
                    float("nan") if i not in ccd_to_valid_ids else ccd_to_valid_ids[i]
                    for i in range(len(mol_atom_names_ccd))
                ]
            )
            ccd_valid_ids = set(ccd_to_valid_ids.keys())
            syms = []
            # Get syms: keep only permutations that stay within the atoms
            # present in this structure, remapped to structure-local ids.
            for sym_ccd in syms_ccd:
                sym_dict = {}
                bool_add = True
                for i, j in enumerate(sym_ccd):
                    if i in ccd_valid_ids:
                        if j in ccd_valid_ids:
                            i_true = ccd_to_valid_ids[i]
                            j_true = ccd_to_valid_ids[j]
                            sym_dict[i_true] = j_true
                        else:
                            # Permutation maps a present atom onto a missing
                            # one — the whole permutation is unusable.
                            bool_add = False
                            break
                if bool_add:
                    syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])
            for sym in syms:
                if len(sym) != added_molecules[mol_id]:
                    raise Exception(
                        f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                    )
                # assert (
                #     len(sym) == added_molecules[mol_id]
                # ), f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                # Keep only the non-identity part of the permutation,
                # shifted to crop-global atom indices.
                sym_new_idx = []
                for i, j in enumerate(sym):
                    if i != int(j):
                        sym_new_idx.append((i + start_mol, int(j) + start_mol))
                if len(sym_new_idx) > 0:
                    swaps.append(sym_new_idx)

            if len(swaps) > 0:
                molecule_symmetries.append(swaps)

            if return_physical_metrics:
                edge_index, (lower_bounds, upper_bounds, bond_mask, angle_mask) = (
                    slice_valid_index(
                        edge_index,
                        ccd_to_valid_id_array,
                        (lower_bounds, upper_bounds, bond_mask, angle_mask),
                    )
                )
                all_edge_index.append(edge_index + start_mol)
                all_lower_bounds.append(lower_bounds)
                all_upper_bounds.append(upper_bounds)
                all_bond_mask.append(bond_mask)
                all_angle_mask.append(angle_mask)

                chiral_atom_index, (chiral_check_mask, chiral_atom_orientations) = (
                    slice_valid_index(
                        chiral_atom_index,
                        ccd_to_valid_id_array,
                        (chiral_check_mask, chiral_atom_orientations),
                    )
                )
                all_chiral_atom_index.append(chiral_atom_index + start_mol)
                all_chiral_check_mask.append(chiral_check_mask)
                all_chiral_atom_orientations.append(chiral_atom_orientations)

                stereo_bond_index, (stereo_check_mask, stereo_bond_orientations) = (
                    slice_valid_index(
                        stereo_bond_index,
                        ccd_to_valid_id_array,
                        (stereo_check_mask, stereo_bond_orientations),
                    )
                )
                all_stereo_bond_index.append(stereo_bond_index + start_mol)
                all_stereo_check_mask.append(stereo_check_mask)
                all_stereo_bond_orientations.append(stereo_bond_orientations)

                aromatic_5_ring_index = slice_valid_index(
                    aromatic_5_ring_index, ccd_to_valid_id_array
                )
                aromatic_6_ring_index = slice_valid_index(
                    aromatic_6_ring_index, ccd_to_valid_id_array
                )
                planar_double_bond_index = slice_valid_index(
                    planar_double_bond_index, ccd_to_valid_id_array
                )
                all_aromatic_5_ring_index.append(aromatic_5_ring_index + start_mol)
                all_aromatic_6_ring_index.append(aromatic_6_ring_index + start_mol)
                all_planar_double_bond_index.append(
                    planar_double_bond_index + start_mol
                )

    if return_physical_metrics:
        if len(all_edge_index) > 0:
            all_edge_index = np.concatenate(all_edge_index, axis=1)
            all_lower_bounds = np.concatenate(all_lower_bounds, axis=0)
            all_upper_bounds = np.concatenate(all_upper_bounds, axis=0)
            all_bond_mask = np.concatenate(all_bond_mask, axis=0)
            all_angle_mask = np.concatenate(all_angle_mask, axis=0)

            all_chiral_atom_index = np.concatenate(all_chiral_atom_index, axis=1)
            all_chiral_check_mask = np.concatenate(all_chiral_check_mask, axis=0)
            all_chiral_atom_orientations = np.concatenate(
                all_chiral_atom_orientations, axis=0
            )

            all_stereo_bond_index = np.concatenate(all_stereo_bond_index, axis=1)
            all_stereo_check_mask = np.concatenate(all_stereo_check_mask, axis=0)
            all_stereo_bond_orientations = np.concatenate(
                all_stereo_bond_orientations, axis=0
            )

            all_aromatic_5_ring_index = np.concatenate(
                all_aromatic_5_ring_index, axis=1
            )
            all_aromatic_6_ring_index = np.concatenate(
                all_aromatic_6_ring_index, axis=1
            )
            # NOTE(review): planar double bonds are deliberately emptied here
            # (the collected values are discarded) — see inline TODO.
            all_planar_double_bond_index = np.empty(
                (6, 0), dtype=np.int64
            )  # TODO remove np.concatenate(all_planar_double_bond_index, axis=1)
        else:
            all_edge_index = np.empty((2, 0), dtype=np.int64)
            all_lower_bounds = np.array([], dtype=np.float32)
            all_upper_bounds = np.array([], dtype=np.float32)
            all_bond_mask = np.array([], dtype=bool)
            all_angle_mask = np.array([], dtype=bool)

            all_chiral_atom_index = np.empty((4, 0), dtype=np.int64)
            all_chiral_check_mask = np.array([], dtype=bool)
            all_chiral_atom_orientations = np.array([], dtype=bool)

            all_stereo_bond_index = np.empty((4, 0), dtype=np.int64)
            all_stereo_check_mask = np.array([], dtype=bool)
            all_stereo_bond_orientations = np.array([], dtype=bool)

            all_aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
            all_aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
            all_planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

        features = {
            "ligand_symmetries": molecule_symmetries,
            "ligand_edge_index": torch.tensor(all_edge_index).long(),
            "ligand_edge_lower_bounds": torch.tensor(all_lower_bounds),
            "ligand_edge_upper_bounds": torch.tensor(all_upper_bounds),
            "ligand_edge_bond_mask": torch.tensor(all_bond_mask),
            "ligand_edge_angle_mask": torch.tensor(all_angle_mask),
            "ligand_chiral_atom_index": torch.tensor(all_chiral_atom_index).long(),
            "ligand_chiral_check_mask": torch.tensor(all_chiral_check_mask),
            "ligand_chiral_atom_orientations": torch.tensor(
                all_chiral_atom_orientations
            ),
            "ligand_stereo_bond_index": torch.tensor(all_stereo_bond_index).long(),
            "ligand_stereo_check_mask": torch.tensor(all_stereo_check_mask),
            "ligand_stereo_bond_orientations": torch.tensor(
                all_stereo_bond_orientations
            ),
            "ligand_aromatic_5_ring_index": torch.tensor(
                all_aromatic_5_ring_index
            ).long(),
            "ligand_aromatic_6_ring_index": torch.tensor(
                all_aromatic_6_ring_index
            ).long(),
            "ligand_planar_double_bond_index": torch.tensor(
                all_planar_double_bond_index
            ).long(),
        }
    else:
        features = {
            "ligand_symmetries": molecule_symmetries,
        }
    return features
protify/FastPLMs/boltz/src/boltz/data/msa/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/msa/mmseqs2.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py
2
+
3
+ import logging
4
+ import os
5
+ import random
6
+ import tarfile
7
+ import time
8
+ from typing import Optional, Union, Dict
9
+
10
+ import requests
11
+ from requests.auth import HTTPBasicAuth
12
+ from tqdm import tqdm
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ TQDM_BAR_FORMAT = (
17
+ "{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"
18
+ )
19
+
20
+
21
def run_mmseqs2(  # noqa: PLR0912, D103, C901, PLR0915
    x: Union[str, list[str]],
    prefix: str = "tmp",
    use_env: bool = True,
    use_filter: bool = True,
    use_pairing: bool = False,
    pairing_strategy: str = "greedy",
    host_url: str = "https://api.colabfold.com",
    msa_server_username: Optional[str] = None,
    msa_server_password: Optional[str] = None,
    auth_headers: Optional[Dict[str, str]] = None,
) -> list[str]:
    """Compute MSAs by querying an MMseqs2 (ColabFold) API server.

    Parameters
    ----------
    x : Union[str, list[str]]
        A single query sequence or a list of query sequences.
    prefix : str
        Prefix for the local working directory holding downloaded results.
    use_env : bool
        Whether to also search the environmental databases.
    use_filter : bool
        Whether to apply the server-side diversity filter.
    use_pairing : bool
        Whether to request paired alignments instead of per-sequence MSAs.
    pairing_strategy : str
        "greedy" (default) or "complete" pairing.
    host_url : str
        Base URL of the MMseqs2 API server.
    msa_server_username : str, optional
        Username for basic auth (mutually exclusive with ``auth_headers``).
    msa_server_password : str, optional
        Password for basic auth (mutually exclusive with ``auth_headers``).
    auth_headers : dict, optional
        Extra HTTP headers (e.g. an API key) used for authentication.

    Returns
    -------
    list[str]
        One a3m-formatted MSA string per input sequence, in input order.
        (The previous ``tuple[list[str], list[str]]`` annotation was wrong.)

    """
    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    # Validate mutually exclusive authentication methods. The original
    # condition also OR-ed in `auth_headers`, which is redundant with
    # `has_header_auth` (it is True exactly when auth_headers is not None
    # and truthy, a subset of has_header_auth).
    has_basic_auth = bool(msa_server_username and msa_server_password)
    has_header_auth = auth_headers is not None
    if has_basic_auth and has_header_auth:
        raise ValueError(
            "Cannot use both basic authentication (username/password) and header/API key authentication. "
            "Please use only one authentication method."
        )

    # Set header agent as boltz
    headers = {"User-Agent": "boltz"}

    # Set up authentication
    auth = None
    if has_basic_auth:
        auth = HTTPBasicAuth(msa_server_username, msa_server_password)
        logger.debug(f"MMSeqs2 server authentication: using basic auth for user '{msa_server_username}'")
    elif has_header_auth:
        headers.update(auth_headers)
        logger.debug("MMSeqs2 server authentication: using header-based authentication")
    else:
        logger.debug("MMSeqs2 server authentication: no credentials provided")

    logger.debug(f"Connecting to MMSeqs2 server at: {host_url}")
    logger.debug(f"Using endpoint: {submission_endpoint}")
    logger.debug(f"Pairing strategy: {pairing_strategy}")
    logger.debug(f"Use environment databases: {use_env}")
    logger.debug(f"Use filtering: {use_filter}")

    def request_with_retries(send_request, description):
        # Shared retry loop for all server calls (the original duplicated
        # this verbatim in submit/status/download). Transient errors are
        # retried with a fixed 5s back-off; raise after the 5th failure to
        # match the "(n/5)" message (the original raised on the 6th).
        error_count = 0
        while True:
            try:
                # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
                # "good practice to set connect timeouts to slightly larger than a multiple of 3"
                res = send_request()
                logger.debug(f"{description} response status: {res.status_code}")
                return res
            except Exception as e:
                error_count += 1
                logger.warning(
                    f"Error while fetching result from MSA server. Retrying... ({error_count}/5)"
                )
                logger.warning(f"Error: {e}")
                if error_count >= 5:
                    raise Exception(
                        "Too many failed attempts for the MSA generation request."
                    )
                time.sleep(5)

    def parse_json(res):
        # The server occasionally replies with a non-JSON error page.
        try:
            return res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            return {"status": "ERROR"}

    def submit(seqs, mode, N=101):
        # Build a FASTA query whose record ids (N, N+1, ...) encode the
        # position of each unique sequence.
        query = "".join(f">{N + i}\n{seq}\n" for i, seq in enumerate(seqs))
        logger.debug(f"Submitting MSA request to {host_url}/{submission_endpoint}")
        res = request_with_retries(
            lambda: requests.post(
                f"{host_url}/{submission_endpoint}",
                data={"q": query, "mode": mode},
                timeout=6.02,
                headers=headers,
                auth=auth,
            ),
            "MSA submission",
        )
        return parse_json(res)

    def status(ID):
        logger.debug(f"Checking MSA job status for ID: {ID}")
        res = request_with_retries(
            lambda: requests.get(
                f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth
            ),
            "MSA status check",
        )
        return parse_json(res)

    def download(ID, path):
        logger.debug(f"Downloading MSA results for ID: {ID}")
        res = request_with_retries(
            lambda: requests.get(
                f"{host_url}/result/download/{ID}",
                timeout=6.02,
                headers=headers,
                auth=auth,
            ),
            "MSA download",
        )
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        mode = ""
        # greedy is default, complete was the previous behavior
        if pairing_strategy == "greedy":
            mode = "pairgreedy"
        elif pairing_strategy == "complete":
            mode = "paircomplete"
        if use_env:
            mode = mode + "-env"

    # define path (makedirs avoids the isdir/mkdir race of the original)
    path = f"{prefix}_{mode}"
    os.makedirs(path, exist_ok=True)

    # call mmseqs2 api
    tar_gz_file = f"{path}/out.tar.gz"
    N, REDO = 101, True

    # Deduplicate while preserving first-seen order, then map each input
    # sequence to its submitted record id in O(1). The original computed
    # Ms with list.index, which is quadratic in the number of sequences.
    seqs_unique = list(dict.fromkeys(seqs))
    seq_to_idx = {seq: i for i, seq in enumerate(seqs_unique)}
    Ms = [N + seq_to_idx[seq] for seq in seqs]

    # lets do it!
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        " input is a valid protein sequence. If error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

                if out["status"] == "MAINTENANCE":
                    msg = (
                        "MMseqs2 API is undergoing maintenance. "
                        "Please try again in a few minutes."
                    )
                    raise Exception(msg)

                # wait for job to finish
                ID, TIME = out["id"], 0
                logger.debug(f"MSA job submitted successfully with ID: {ID}")
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    logger.debug(f"MSA job completed successfully for ID: {ID}")
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        " input is a valid protein sequence. If error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        # NOTE(review): extractall trusts the server-supplied archive;
        # member paths are not sanitized here.
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # Gather a3m lines, keyed by the numeric record id assigned at submit
    # time. NUL bytes separate results from concatenated databases.
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        with open(a3m_file, "r") as fh:  # was a leaked bare open() per file
            for line in fh:
                if len(line) > 0:
                    if "\x00" in line:
                        line = line.replace("\x00", "")
                        update_M = True
                    if line.startswith(">") and update_M:
                        M = int(line[1:].rstrip())
                        update_M = False
                    if M not in a3m_lines:
                        a3m_lines[M] = []
                    a3m_lines[M].append(line)

    # Re-expand from unique record ids back to the original input order.
    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
    return a3m_lines
protify/FastPLMs/boltz/src/boltz/data/pad.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.nn.functional import pad
4
+
5
+
6
def pad_dim(data: Tensor, dim: int, pad_len: int, value: float = 0) -> Tensor:
    """Pad a tensor at the end of a single dimension.

    Parameters
    ----------
    data : Tensor
        The input tensor.
    dim : int
        The dimension to pad.
    pad_len : int
        The number of entries to append along ``dim``.
        (Previously annotated ``float``; ``torch.nn.functional.pad``
        requires an integer length.)
    value : float, optional
        The value to pad with.

    Returns
    -------
    Tensor
        The padded tensor. The input tensor itself is returned unchanged
        when ``pad_len`` is 0.

    """
    if pad_len == 0:
        return data

    # F.pad expects (before, after) pairs starting from the LAST dimension,
    # so the slot for ``dim`` is counted from the end of the spec. Only the
    # "after" slot of the target dimension is set; everything else stays 0.
    total_dims = len(data.shape)
    padding = [0] * (2 * (total_dims - dim))
    padding[2 * (total_dims - 1 - dim) + 1] = pad_len
    return pad(data, tuple(padding), value=value)
33
+
34
+
35
def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
    """Stack tensors after padding every dimension to the batch maximum.

    Parameters
    ----------
    data : list[Tensor]
        list of tensors to pad.
    value : float
        The value to use for padding.

    Returns
    -------
    Tensor
        The stacked, padded tensor.
    Tensor
        A mask of ones over the original entries and zeros over the
        padding. When no padding is needed (all shapes equal, or the
        items are strings), the integer ``0`` is returned instead of
        a mask, matching the historical behavior.

    """
    # Strings cannot be padded or stacked; pass them through.
    if isinstance(data[0], str):
        return data, 0

    # Fast path: identical shapes need no padding at all.
    shapes = [tuple(d.shape) for d in data]
    if len(set(shapes)) == 1:
        return torch.stack(data, dim=0), 0

    # Per-dimension maximum across the batch.
    num_dims = len(shapes[0])
    max_dims = [max(shape[i] for shape in shapes) for i in range(num_dims)]

    padded, masks = [], []
    for d in data:
        # Build the F.pad spec: (before, after) pairs, last dim first.
        spec = []
        for i in reversed(range(num_dims)):
            spec.extend((0, max_dims[i] - d.shape[i]))
        spec = tuple(spec)
        padded.append(pad(d, spec, value=value))
        masks.append(pad(torch.ones_like(d), spec, value=0))

    return torch.stack(padded, dim=0), torch.stack(masks, dim=0)
protify/FastPLMs/boltz/src/boltz/data/parse/__init__.py ADDED
File without changes
protify/FastPLMs/boltz/src/boltz/data/parse/a3m.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ from pathlib import Path
3
+ from typing import Optional, TextIO
4
+
5
+ import numpy as np
6
+
7
+ from boltz.data import const
8
+ from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
9
+
10
+
11
def _parse_a3m(  # noqa: C901
    lines: TextIO,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an MSA file.

    Parameters
    ----------
    lines : TextIO
        The lines of the MSA file.
    taxonomy : dict[str, str], optional
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    visited = set()
    sequences = []
    deletions = []
    residues = []

    # Default taxonomy id. Fixes a NameError in the original: if a sequence
    # line appeared before any ">" header, taxonomy_id was read unbound.
    taxonomy_id = -1

    seq_idx = 0
    for line in lines:
        line: str
        line = line.strip()  # noqa: PLW2901
        if not line or line.startswith("#"):
            continue

        # Get taxonomy, if annotated
        if line.startswith(">"):
            header = line.split()[0]
            if taxonomy and header.startswith(">UniRef100"):
                uniref_id = header.split("_")[1]
                taxonomy_id = taxonomy.get(uniref_id)
                if taxonomy_id is None:
                    taxonomy_id = -1
            else:
                taxonomy_id = -1
            continue

        # Skip if duplicate sequence (compared gap-free and upper-cased)
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence. Lowercase letters are insertions relative to
        # the query: they are counted and stored as deletions attached to
        # the next aligned position instead of becoming residues.
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        # Each sequence references its residues/deletions as [start, end)
        # slices into the flat arrays accumulated below.
        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    msa = MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
    return msa
102
+
103
+
104
def parse_a3m(
    path: Path,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Parse an A3M file, transparently handling gzip compression.

    Parameters
    ----------
    path : Path
        The path to the a3m(.gz) file.
    taxonomy : dict[str, str], optional
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Pick a text-mode handle based on the extension, then hand the
    # line stream to the shared parser.
    if path.suffix == ".gz":
        handle = gzip.open(str(path), "rt")
    else:
        handle = path.open("r")

    with handle as f:
        return _parse_a3m(f, taxonomy, max_seqs)
protify/FastPLMs/boltz/src/boltz/data/parse/csv.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from boltz.data import const
8
+ from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
9
+
10
+
11
def parse_csv(
    path: Path,
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process a CSV-formatted MSA file.

    The file must contain exactly the columns ``sequence`` and ``key``,
    where ``key`` optionally carries a taxonomy/pairing identifier.
    (The original docstring incorrectly said "A3M file".)

    Parameters
    ----------
    path : Path
        The path to the CSV file.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Read file
    data = pd.read_csv(path)

    # Check columns
    if tuple(sorted(data.columns)) != ("key", "sequence"):
        msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
        raise ValueError(msg)

    visited = set()
    sequences = []
    deletions = []
    residues = []

    seq_idx = 0
    for line, key in zip(data["sequence"], data["key"]):
        line: str
        line = line.strip()  # noqa: PLW2901
        if not line:
            continue

        # Get taxonomy, if annotated. pd.isna covers both NaN and None,
        # replacing the fragile str(key) != "nan" comparison.
        taxonomy_id = -1
        if not pd.isna(key) and key != "":
            taxonomy_id = key

        # Skip if duplicate sequence (compared gap-free and upper-cased)
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence. Lowercase letters are insertions relative to
        # the query: counted as deletions, not stored as residues.
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        # Each sequence references its residues/deletions as [start, end)
        # slices into the flat arrays accumulated below.
        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    msa = MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
    return msa
protify/FastPLMs/boltz/src/boltz/data/parse/fasta.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Mapping
2
+ from pathlib import Path
3
+
4
+ from Bio import SeqIO
5
+ from rdkit.Chem.rdchem import Mol
6
+
7
+ from boltz.data.parse.yaml import parse_boltz_schema
8
+ from boltz.data.types import Target
9
+
10
+
11
def parse_fasta(  # noqa: C901, PLR0912
    path: Path,
    ccd: Mapping[str, Mol],
    mol_dir: Path,
    boltz2: bool = False,
) -> Target:
    """Parse a fasta file.

    The name of the fasta file is used as the name of this job.
    We rely on the fasta record id to determine the entity type.

    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    SEQUENCE
    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    ...

    Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
    and CHAIN_ID is the chain identifier, which should be unique.
    The MSA_ID is optional and should only be used on proteins.

    Parameters
    ----------
    path : Path
        Path to the fasta file.
    ccd : Mapping[str, Mol]
        Dictionary of CCD components.
    mol_dir : Path
        Path to the directory containing the molecules.
    boltz2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    """
    # Read fasta file
    with path.open("r") as f:
        records = list(SeqIO.parse(f, "fasta"))

    # Make sure all records have a chain id and entity.
    for seq_record in records:
        if "|" not in seq_record.id:
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        header = seq_record.id.split("|")
        assert len(header) >= 2, f"Invalid record id: {seq_record.id}"

        chain_id, entity_type = header[:2]
        # Check emptiness first: in the original, the dedicated "Empty
        # entity type" message was unreachable because the membership
        # check below rejected "" before it was reached.
        if chain_id == "":
            msg = "Empty chain id in input fasta!"
            raise ValueError(msg)
        if entity_type == "":
            msg = "Empty entity type in input fasta!"
            raise ValueError(msg)
        if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)

    # Convert to yaml format
    sequences = []
    for seq_record in records:
        # Get chain id, entity type and sequence
        header = seq_record.id.split("|")
        chain_id, entity_type = header[:2]
        if len(header) == 3 and header[2] != "":
            assert entity_type.lower() == "protein", (
                "MSA_ID is only allowed for proteins"
            )
            msa_id = header[2]
        else:
            msa_id = None

        # Normalize once; the branches below compare against the
        # uppercased value directly (the original re-called .upper()
        # redundantly for the CCD and SMILES branches).
        entity_type = entity_type.upper()
        seq = str(seq_record.seq)

        if entity_type == "PROTEIN":
            molecule = {
                "protein": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                    "msa": msa_id,
                },
            }
        elif entity_type == "RNA":
            molecule = {
                "rna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "DNA":
            molecule = {
                "dna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                }
            }
        elif entity_type == "CCD":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "ccd": seq,
                }
            }
        elif entity_type == "SMILES":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "smiles": seq,
                }
            }
        else:
            # Defensive: unreachable given the validation loop above, but
            # prevents an UnboundLocalError on `molecule` if the two loops
            # ever drift apart.
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)

        sequences.append(molecule)

    data = {
        "sequences": sequences,
        "bonds": [],
        "version": 1,
    }

    name = path.stem
    return parse_boltz_schema(name, data, ccd, mol_dir, boltz2)