Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- README.md +65 -0
- config.json +116 -0
- model.safetensors +3 -0
- packaged_probe_model.py +215 -0
- protify/FastPLMs/__init__.py +0 -0
- protify/FastPLMs/boltz/scripts/eval/aggregate_evals.py +753 -0
- protify/FastPLMs/boltz/scripts/eval/physcialsim_metrics.py +304 -0
- protify/FastPLMs/boltz/scripts/eval/run_evals.py +167 -0
- protify/FastPLMs/boltz/scripts/process/ccd.py +295 -0
- protify/FastPLMs/boltz/scripts/process/cluster.py +111 -0
- protify/FastPLMs/boltz/scripts/process/mmcif.py +1123 -0
- protify/FastPLMs/boltz/scripts/process/msa.py +130 -0
- protify/FastPLMs/boltz/scripts/process/rcsb.py +359 -0
- protify/FastPLMs/boltz/scripts/train/train.py +241 -0
- protify/FastPLMs/boltz/src/boltz/__init__.py +7 -0
- protify/FastPLMs/boltz/src/boltz/data/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/const.py +1184 -0
- protify/FastPLMs/boltz/src/boltz/data/crop/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/crop/affinity.py +164 -0
- protify/FastPLMs/boltz/src/boltz/data/crop/boltz.py +296 -0
- protify/FastPLMs/boltz/src/boltz/data/crop/cropper.py +45 -0
- protify/FastPLMs/boltz/src/boltz/data/feature/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/feature/featurizer.py +1225 -0
- protify/FastPLMs/boltz/src/boltz/data/feature/featurizerv2.py +2354 -0
- protify/FastPLMs/boltz/src/boltz/data/feature/symmetry.py +602 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/date.py +76 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/filter.py +24 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/max_residues.py +37 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/resolution.py +34 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/size.py +38 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/subset.py +42 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/static/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/static/filter.py +26 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/static/ligand.py +37 -0
- protify/FastPLMs/boltz/src/boltz/data/filter/static/polymer.py +299 -0
- protify/FastPLMs/boltz/src/boltz/data/module/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/module/inference.py +310 -0
- protify/FastPLMs/boltz/src/boltz/data/module/inferencev2.py +433 -0
- protify/FastPLMs/boltz/src/boltz/data/module/training.py +687 -0
- protify/FastPLMs/boltz/src/boltz/data/module/trainingv2.py +660 -0
- protify/FastPLMs/boltz/src/boltz/data/mol.py +900 -0
- protify/FastPLMs/boltz/src/boltz/data/msa/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/msa/mmseqs2.py +286 -0
- protify/FastPLMs/boltz/src/boltz/data/pad.py +84 -0
- protify/FastPLMs/boltz/src/boltz/data/parse/__init__.py +0 -0
- protify/FastPLMs/boltz/src/boltz/data/parse/a3m.py +134 -0
- protify/FastPLMs/boltz/src/boltz/data/parse/csv.py +100 -0
- protify/FastPLMs/boltz/src/boltz/data/parse/fasta.py +138 -0
README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags: []
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# nikraf/OmniPath_2class_clustered-30_ESMC-600_2026-03-11-15-46_NQRV
|
| 7 |
+
|
| 8 |
+
Fine-tuned with Protify.
|
| 9 |
+
|
| 10 |
+
## About Protify
|
| 11 |
+
|
| 12 |
+
Protify is an open source platform designed to simplify and democratize workflows for chemical language models. With Protify, deep learning models can be trained to predict chemical properties without requiring extensive coding knowledge or computational resources.
|
| 13 |
+
|
| 14 |
+
### Why Protify?
|
| 15 |
+
|
| 16 |
+
- Benchmark multiple models efficiently.
|
| 17 |
+
- Flexible for all skill levels.
|
| 18 |
+
- Accessible computing with support for precomputed embeddings.
|
| 19 |
+
- Cost-effective workflows for training and evaluation.
|
| 20 |
+
|
| 21 |
+
## Training Run
|
| 22 |
+
|
| 23 |
+
- `dataset`: OmniPath_2class_clustered-30
|
| 24 |
+
- `model`: ESMC-600
|
| 25 |
+
- `run_id`: 2026-03-11-15-46_NQRV
|
| 26 |
+
- `task_type`: singlelabel
|
| 27 |
+
- `num_runs`: 1
|
| 28 |
+
|
| 29 |
+
## Dataset Statistics
|
| 30 |
+
|
| 31 |
+
- `train_size`: 102872
|
| 32 |
+
- `valid_size`: 18102
|
| 33 |
+
- `test_size`: 18074
|
| 34 |
+
|
| 35 |
+
## Validation Metrics
|
| 36 |
+
|
| 37 |
+
- `epoch`: 5.000000
|
| 38 |
+
- `eval_accuracy`: 0.789750
|
| 39 |
+
- `eval_f1`: 0.789330
|
| 40 |
+
- `eval_loss`: 0.445219
|
| 41 |
+
- `eval_mcc`: 0.581780
|
| 42 |
+
- `eval_model_preparation_time`: 0.000300
|
| 43 |
+
- `eval_pr_auc`: 0.884610
|
| 44 |
+
- `eval_precision`: 0.792040
|
| 45 |
+
- `eval_recall`: 0.789750
|
| 46 |
+
- `eval_roc_auc`: 0.880010
|
| 47 |
+
- `eval_runtime`: 21.260300
|
| 48 |
+
- `eval_samples_per_second`: 851.444000
|
| 49 |
+
- `eval_steps_per_second`: 13.311000
|
| 50 |
+
|
| 51 |
+
## Test Metrics
|
| 52 |
+
|
| 53 |
+
- `test_accuracy`: 0.779350
|
| 54 |
+
- `test_f1`: 0.778210
|
| 55 |
+
- `test_loss`: 0.455012
|
| 56 |
+
- `test_mcc`: 0.564560
|
| 57 |
+
- `test_model_preparation_time`: 0.000300
|
| 58 |
+
- `test_pr_auc`: 0.884200
|
| 59 |
+
- `test_precision`: 0.785240
|
| 60 |
+
- `test_recall`: 0.779350
|
| 61 |
+
- `test_roc_auc`: 0.874270
|
| 62 |
+
- `test_runtime`: 21.119900
|
| 63 |
+
- `test_samples_per_second`: 855.780000
|
| 64 |
+
- `test_steps_per_second`: 13.400000
|
| 65 |
+
- `training_time_seconds`: 1235.285100
|
config.json
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_token_ids": false,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"PackagedProbeModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "packaged_probe_model.PackagedProbeConfig",
|
| 8 |
+
"AutoModel": "packaged_probe_model.PackagedProbeModel"
|
| 9 |
+
},
|
| 10 |
+
"base_model_name": "ESMC-600",
|
| 11 |
+
"dtype": "float32",
|
| 12 |
+
"matrix_embed": true,
|
| 13 |
+
"model_type": "packaged_probe",
|
| 14 |
+
"pooling_types": [
|
| 15 |
+
"mean",
|
| 16 |
+
"var"
|
| 17 |
+
],
|
| 18 |
+
"ppi": true,
|
| 19 |
+
"probe_config": {
|
| 20 |
+
"_name_or_path": "",
|
| 21 |
+
"add_cross_attention": false,
|
| 22 |
+
"add_token_ids": false,
|
| 23 |
+
"architectures": [
|
| 24 |
+
"TransformerForSequenceClassification"
|
| 25 |
+
],
|
| 26 |
+
"bad_words_ids": null,
|
| 27 |
+
"begin_suppress_tokens": null,
|
| 28 |
+
"bos_token_id": null,
|
| 29 |
+
"chunk_size_feed_forward": 0,
|
| 30 |
+
"classifier_dropout": 0.2,
|
| 31 |
+
"classifier_size": 4096,
|
| 32 |
+
"cross_attention_hidden_size": null,
|
| 33 |
+
"decoder_start_token_id": null,
|
| 34 |
+
"diversity_penalty": 0.0,
|
| 35 |
+
"do_sample": false,
|
| 36 |
+
"dropout": 0.2,
|
| 37 |
+
"dtype": "float32",
|
| 38 |
+
"early_stopping": false,
|
| 39 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 40 |
+
"eos_token_id": null,
|
| 41 |
+
"exponential_decay_length_penalty": null,
|
| 42 |
+
"finetuning_task": null,
|
| 43 |
+
"forced_bos_token_id": null,
|
| 44 |
+
"forced_eos_token_id": null,
|
| 45 |
+
"hidden_size": 512,
|
| 46 |
+
"id2label": {
|
| 47 |
+
"0": "LABEL_0",
|
| 48 |
+
"1": "LABEL_1"
|
| 49 |
+
},
|
| 50 |
+
"input_size": 1152,
|
| 51 |
+
"is_decoder": false,
|
| 52 |
+
"is_encoder_decoder": false,
|
| 53 |
+
"label2id": {
|
| 54 |
+
"LABEL_0": 0,
|
| 55 |
+
"LABEL_1": 1
|
| 56 |
+
},
|
| 57 |
+
"length_penalty": 1.0,
|
| 58 |
+
"lora": false,
|
| 59 |
+
"lora_alpha": 32.0,
|
| 60 |
+
"lora_dropout": 0.01,
|
| 61 |
+
"lora_r": 8,
|
| 62 |
+
"max_length": 20,
|
| 63 |
+
"min_length": 0,
|
| 64 |
+
"model_type": "probe",
|
| 65 |
+
"n_heads": 4,
|
| 66 |
+
"n_layers": 1,
|
| 67 |
+
"no_repeat_ngram_size": 0,
|
| 68 |
+
"num_beam_groups": 1,
|
| 69 |
+
"num_beams": 1,
|
| 70 |
+
"num_return_sequences": 1,
|
| 71 |
+
"output_attentions": false,
|
| 72 |
+
"output_hidden_states": false,
|
| 73 |
+
"output_scores": false,
|
| 74 |
+
"pad_token_id": null,
|
| 75 |
+
"pooling_types": [
|
| 76 |
+
"mean",
|
| 77 |
+
"cls"
|
| 78 |
+
],
|
| 79 |
+
"pre_ln": true,
|
| 80 |
+
"prefix": null,
|
| 81 |
+
"probe_type": "transformer",
|
| 82 |
+
"problem_type": null,
|
| 83 |
+
"pruned_heads": {},
|
| 84 |
+
"remove_invalid_values": false,
|
| 85 |
+
"repetition_penalty": 1.0,
|
| 86 |
+
"return_dict": true,
|
| 87 |
+
"return_dict_in_generate": false,
|
| 88 |
+
"rotary": true,
|
| 89 |
+
"sep_token_id": null,
|
| 90 |
+
"sim_type": "dot",
|
| 91 |
+
"suppress_tokens": null,
|
| 92 |
+
"task_specific_params": null,
|
| 93 |
+
"task_type": "singlelabel",
|
| 94 |
+
"temperature": 1.0,
|
| 95 |
+
"tf_legacy_loss": false,
|
| 96 |
+
"tie_encoder_decoder": false,
|
| 97 |
+
"tie_word_embeddings": true,
|
| 98 |
+
"token_attention": false,
|
| 99 |
+
"tokenizer_class": null,
|
| 100 |
+
"tokenwise": false,
|
| 101 |
+
"top_k": 50,
|
| 102 |
+
"top_p": 1.0,
|
| 103 |
+
"torchscript": false,
|
| 104 |
+
"transformer_dropout": 0.1,
|
| 105 |
+
"transformer_hidden_size": 512,
|
| 106 |
+
"transformers_version": "4.57.6",
|
| 107 |
+
"typical_p": 1.0,
|
| 108 |
+
"use_bfloat16": false,
|
| 109 |
+
"use_bias": false
|
| 110 |
+
},
|
| 111 |
+
"probe_type": "transformer",
|
| 112 |
+
"sep_token_id": 2,
|
| 113 |
+
"task_type": "singlelabel",
|
| 114 |
+
"tokenwise": false,
|
| 115 |
+
"transformers_version": "4.57.6"
|
| 116 |
+
}
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec8ea16612d2975dad1abb1da8977591cbba6ff2b0566374755120e6e950bded
|
| 3 |
+
size 2331568712
|
packaged_probe_model.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
|
| 8 |
+
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from protify.base_models.supported_models import all_presets_with_paths
|
| 13 |
+
from protify.pooler import Pooler
|
| 14 |
+
from protify.probes.get_probe import rebuild_probe_from_saved_config
|
| 15 |
+
except ImportError:
|
| 16 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
+
candidate_paths = [
|
| 18 |
+
current_dir,
|
| 19 |
+
os.path.dirname(current_dir),
|
| 20 |
+
os.path.dirname(os.path.dirname(current_dir)),
|
| 21 |
+
os.path.join(current_dir, "src"),
|
| 22 |
+
]
|
| 23 |
+
for candidate in candidate_paths:
|
| 24 |
+
if os.path.isdir(candidate) and candidate not in sys.path:
|
| 25 |
+
sys.path.insert(0, candidate)
|
| 26 |
+
from protify.base_models.supported_models import all_presets_with_paths
|
| 27 |
+
from protify.pooler import Pooler
|
| 28 |
+
from protify.probes.get_probe import rebuild_probe_from_saved_config
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PackagedProbeConfig(PretrainedConfig):
    """Configuration for a packaged probe model: a frozen base PLM plus a
    trained probe head, with the pooling / task settings needed to rebuild
    both at load time.
    """

    model_type = "packaged_probe"

    def __init__(
        self,
        base_model_name: str = "",
        probe_type: str = "linear",
        probe_config: Optional[Dict[str, Any]] = None,
        tokenwise: bool = False,
        matrix_embed: bool = False,
        pooling_types: Optional[list[str]] = None,
        task_type: str = "singlelabel",
        num_labels: int = 2,
        ppi: bool = False,
        add_token_ids: bool = False,
        sep_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.probe_type = probe_type
        # Normalize the optional containers so downstream code never sees None.
        self.probe_config = probe_config if probe_config is not None else {}
        self.pooling_types = pooling_types if pooling_types is not None else ["mean"]
        self.tokenwise = tokenwise
        self.matrix_embed = matrix_embed
        self.task_type = task_type
        self.num_labels = num_labels
        self.ppi = ppi
        self.add_token_ids = add_token_ids
        self.sep_token_id = sep_token_id
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class PackagedProbeModel(PreTrainedModel):
    """Inference wrapper that runs a base PLM backbone and feeds its hidden
    states (pooled or token-level) into a trained probe head.
    """

    config_class = PackagedProbeConfig
    base_model_prefix = "backbone"
    all_tied_weights_keys = {}

    def __init__(
        self,
        config: PackagedProbeConfig,
        base_model: Optional[nn.Module] = None,
        probe: Optional[nn.Module] = None,
    ):
        super().__init__(config)
        self.config = config
        # Callers may inject pre-built modules; otherwise rebuild from config.
        self.backbone = base_model if base_model is not None else self._load_base_model()
        self.probe = probe if probe is not None else self._load_probe()
        self.pooler = Pooler(self.config.pooling_types)

    def _load_base_model(self) -> nn.Module:
        """Resolve the configured base-model name (preset or hub path) and load it."""
        model_path = self.config.base_model_name
        if model_path in all_presets_with_paths:
            model_path = all_presets_with_paths[model_path]
        backbone = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        backbone.eval()
        return backbone

    def _load_probe(self) -> nn.Module:
        """Rebuild the probe head from the serialized probe configuration."""
        return rebuild_probe_from_saved_config(
            probe_type=self.config.probe_type,
            tokenwise=self.config.tokenwise,
            probe_config=self.config.probe_config,
        )

    @staticmethod
    def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
        """Pull the last hidden states out of whatever the backbone returned.

        Backbones may return a tuple, a ModelOutput-style object, or a bare
        tensor; anything else is rejected.
        """
        if isinstance(backbone_output, tuple):
            return backbone_output[0]
        if hasattr(backbone_output, "last_hidden_state"):
            return backbone_output.last_hidden_state
        if isinstance(backbone_output, torch.Tensor):
            return backbone_output
        raise ValueError("Unsupported backbone output format for packaged probe model")

    @staticmethod
    def _extract_attentions(backbone_output: Any) -> Optional[torch.Tensor]:
        """Return the backbone's attention maps when present, else None."""
        return getattr(backbone_output, "attentions", None)

    def _build_ppi_segment_masks(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-partner attention masks for PPI inputs.

        Preferred path uses explicit ``token_type_ids`` (0 = partner A,
        1 = partner B); otherwise falls back to splitting at separator tokens,
        or at the sequence midpoint when no separator is found.
        """
        if token_type_ids is not None and torch.any(token_type_ids == 1):
            seg_a = ((token_type_ids == 0) & (attention_mask == 1)).long()
            seg_b = ((token_type_ids == 1) & (attention_mask == 1)).long()
            assert torch.all(seg_a.sum(dim=1) > 0), "PPI token_type_ids produced empty segment A"
            assert torch.all(seg_b.sum(dim=1) > 0), "PPI token_type_ids produced empty segment B"
            return seg_a, seg_b

        assert self.config.sep_token_id is not None, "sep_token_id is required for PPI fallback segmentation"
        batch_size, seq_len = input_ids.shape
        seg_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        seg_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)

        for row in range(batch_size):
            valid = torch.where(attention_mask[row] == 1)[0]
            seps = torch.where((input_ids[row] == self.config.sep_token_id) & (attention_mask[row] == 1))[0]
            if len(valid) == 0:
                continue

            if len(seps) >= 2:
                # Two separators: A runs through the first, B through the second.
                first = int(seps[0].item())
                second = int(seps[1].item())
                seg_a[row, : first + 1] = 1
                seg_b[row, first + 1 : second + 1] = 1
            elif len(seps) == 1:
                # One separator: B runs to the last attended position.
                first = int(seps[0].item())
                seg_a[row, : first + 1] = 1
                seg_b[row, first + 1 : int(valid[-1].item()) + 1] = 1
            else:
                # No separator: split the attended positions in half.
                mid = len(valid) // 2
                seg_a[row, valid[:mid]] = 1
                seg_b[row, valid[mid:]] = 1

        assert torch.all(seg_a.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment A"
        assert torch.all(seg_b.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment B"
        return seg_a, seg_b

    def _build_probe_inputs(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
        attentions: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Shape backbone hidden states into what the probe expects.

        Returns ``(embeddings, attention_mask_or_None)``: pooled per-partner
        concatenation for sequence-level PPI, raw token matrices for
        matrix/tokenwise probes, and a single pooled vector otherwise.
        """
        sequence_level = not (self.config.matrix_embed or self.config.tokenwise)

        if self.config.ppi and sequence_level:
            # Pool each interaction partner separately, then concatenate.
            seg_a, seg_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
            pooled_a = self.pooler(hidden_states, attention_mask=seg_a, attentions=attentions)
            pooled_b = self.pooler(hidden_states, attention_mask=seg_b, attentions=attentions)
            return torch.cat((pooled_a, pooled_b), dim=-1), None

        if not sequence_level:
            # Token-level probes consume the full hidden-state matrix.
            return hidden_states, attention_mask

        pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
        return pooled, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput | TokenClassifierOutput:
        """Run backbone + probe; returns the probe's classifier output."""
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # "parti" pooling consumes the backbone's attention maps, so request them.
        needs_attn = (
            "parti" in self.config.pooling_types
            and not self.config.matrix_embed
            and not self.config.tokenwise
        )
        backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
        if needs_attn:
            backbone_kwargs["output_attentions"] = True
        backbone_output = self.backbone(**backbone_kwargs)
        hidden_states = self._extract_hidden_states(backbone_output)
        attentions = self._extract_attentions(backbone_output)
        if needs_attn:
            assert attentions is not None, "parti pooling requires base model attentions"

        probe_embeddings, probe_mask = self._build_probe_inputs(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            attentions=attentions,
        )

        if self.config.probe_type == "linear":
            return self.probe(embeddings=probe_embeddings, labels=labels)

        if self.config.probe_type == "transformer":
            probe_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
            if probe_mask is not None:
                probe_kwargs["attention_mask"] = probe_mask
                if self.config.add_token_ids and token_type_ids is not None:
                    probe_kwargs["token_type_ids"] = token_type_ids
            return self.probe(**probe_kwargs)

        if self.config.probe_type in ["retrievalnet", "lyra"]:
            return self.probe(embeddings=probe_embeddings, attention_mask=probe_mask, labels=labels)

        raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")
|
protify/FastPLMs/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/scripts/eval/aggregate_evals.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
METRICS = ["lddt", "bb_lddt", "tm_score", "rmsd"]


def compute_af3_metrics(preds, evals, name):
    """Aggregate AF3 evaluation metrics over the five predicted samples.

    Args:
        preds: directory containing ``seed-1_sample-{i}/summary_confidences.json``
            confidence files for samples 0-4.
        evals: directory containing ``{name}_model_{i}.json`` and
            ``{name}_model_{i}_ligand.json`` evaluation files.
        name: target name used to locate the evaluation files.

    Returns:
        Mapping of metric name -> dict with ``oracle`` (best value across
        samples), ``average``, ``top1`` (value of the sample with the highest
        ranking score), and ``len`` (entity count for per-entity metrics,
        else 1).

    Raises:
        FileNotFoundError / KeyError: if expected files or fields are missing.
    """
    metrics = {}

    # Identify the "top-1" sample: the one with the highest ranking score.
    top_model = None
    top_confidence = -1000
    for model_id in range(5):
        confidence_file = (
            Path(preds) / f"seed-1_sample-{model_id}" / "summary_confidences.json"
        )
        with confidence_file.open("r") as f:
            confidence = json.load(f)["ranking_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Global structure metrics for this sample.
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if eval_data.get("dockq") is not None:
            # Filter unscored interfaces once, then derive both success rates.
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand metrics: unassigned ligands are scored as worst case
        # (0 for lddt_pli, 100 for rmsd).
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]]
            lddt_plis.extend(
                [0] * len(eval_data["lddt_pli"]["model_ligand_unassigned_reason"])
            )
            if not lddt_plis:
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            rmsds.extend(
                [100] * len(eval_data["rmsd"]["model_ligand_unassigned_reason"])
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(np.mean([x < 2.0 for x in rmsds]))
            metrics.setdefault("rmsd<5", []).append(np.mean([x < 5.0 for x in rmsds]))
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Oracle is the best sample per metric: lower is better only for raw rmsd.
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # NOTE(review): entity counts are taken from the first sample only;
        # presumably all samples score the same set of entities — confirm.
        if metric_name == "lddt_pli":
            count = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            count = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            count = metrics["len_dockq_"][0]
        else:
            count = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": count,
        }

    return results
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def compute_chai_metrics(preds, evals, name):
    """Aggregate per-model evaluation metrics for Chai-1 predictions.

    Parameters
    ----------
    preds : str or Path
        Directory holding the per-model Chai confidence files
        (``scores.model_idx_{i}.npz``).
    evals : str or Path
        Directory holding the per-model evaluation JSON files for ``name``.
    name : str
        Target name used to locate the evaluation files.

    Returns
    -------
    dict
        Mapping ``metric -> {"oracle", "average", "top1", "len"}``.
    """
    metrics = {}

    # Track the most confident of the 5 models. Start at -inf so a model is
    # always selected; the previous 0 init left top_model as None (and crashed
    # at v[top_model] below) whenever every confidence was <= 0.
    top_model = 0
    top_confidence = float("-inf")
    for model_id in range(5):
        # Load confidence file
        confidence_file = Path(preds) / f"scores.model_idx_{model_id}.npz"
        confidence_data = np.load(confidence_file)
        confidence = confidence_data["aggregate_score"].item()
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Load eval file
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if "dockq" in eval_data and eval_data["dockq"] is not None:
            # Drop null entries once, then derive both thresholds and the count.
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand-level evaluation file
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Every unassigned model ligand scores 0; only the count matters,
            # not the unassignment reason.
            lddt_plis.extend(
                [0] * len(eval_data["lddt_pli"]["model_ligand_unassigned_reason"])
            )
            if not lddt_plis:
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Unassigned ligands are penalized with a large placeholder RMSD
            # so they count as failures under both thresholds below.
            rmsds.extend(
                [100] * len(eval_data["rmsd"]["model_ligand_unassigned_reason"])
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(np.mean([x < 2.0 for x in rmsds]))
            metrics.setdefault("rmsd<5", []).append(np.mean([x < 5.0 for x in rmsds]))
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Oracle = best value over the 5 models (min for raw rmsd, max otherwise).
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # Attach the per-target sample count so callers can verify that every
        # tool evaluated the same set of ligands/interfaces.
        if metric_name == "lddt_pli":
            length = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            length = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            length = metrics["len_dockq_"][0]
        else:
            length = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": length,
        }

    return results
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def compute_boltz_metrics(preds, evals, name):
    """Aggregate per-model evaluation metrics for Boltz predictions.

    Parameters
    ----------
    preds : str or Path
        Directory holding the per-model Boltz confidence files
        (``confidence_{target}_model_{i}.json``).
    evals : str or Path
        Directory holding the per-model evaluation JSON files for ``name``.
    name : str
        Target name used to locate the evaluation files.

    Returns
    -------
    dict
        Mapping ``metric -> {"oracle", "average", "top1", "len"}``.
    """
    metrics = {}

    # Track the most confident of the 5 models. Start at -inf so a model is
    # always selected; the previous 0 init left top_model as None (and crashed
    # at v[top_model] below) whenever every confidence was <= 0.
    top_model = 0
    top_confidence = float("-inf")
    for model_id in range(5):
        # Load confidence file
        confidence_file = (
            Path(preds) / f"confidence_{Path(preds).name}_model_{model_id}.json"
        )
        with confidence_file.open("r") as f:
            confidence_data = json.load(f)
        confidence = confidence_data["confidence_score"]
        if confidence > top_confidence:
            top_model = model_id
            top_confidence = confidence

        # Load eval file
        eval_file = Path(evals) / f"{name}_model_{model_id}.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        for metric_name in METRICS:
            if metric_name in eval_data:
                metrics.setdefault(metric_name, []).append(eval_data[metric_name])

        if "dockq" in eval_data and eval_data["dockq"] is not None:
            # Drop null entries once, then derive both thresholds and the count.
            dockq_scores = [v for v in eval_data["dockq"] if v is not None]
            metrics.setdefault("dockq_>0.23", []).append(
                np.mean([float(v > 0.23) for v in dockq_scores])
            )
            metrics.setdefault("dockq_>0.49", []).append(
                np.mean([float(v > 0.49) for v in dockq_scores])
            )
            metrics.setdefault("len_dockq_", []).append(len(dockq_scores))

        # Ligand-level evaluation file
        eval_file = Path(evals) / f"{name}_model_{model_id}_ligand.json"
        with eval_file.open("r") as f:
            eval_data = json.load(f)
        if "lddt_pli" in eval_data:
            lddt_plis = [
                x["score"] for x in eval_data["lddt_pli"]["assigned_scores"]
            ]
            # Every unassigned model ligand scores 0; only the count matters,
            # not the unassignment reason.
            lddt_plis.extend(
                [0] * len(eval_data["lddt_pli"]["model_ligand_unassigned_reason"])
            )
            if not lddt_plis:
                continue
            metrics.setdefault("lddt_pli", []).append(np.mean(lddt_plis))
            metrics.setdefault("len_lddt_pli", []).append(len(lddt_plis))

        if "rmsd" in eval_data:
            rmsds = [x["score"] for x in eval_data["rmsd"]["assigned_scores"]]
            # Unassigned ligands are penalized with a large placeholder RMSD
            # so they count as failures under both thresholds below.
            rmsds.extend(
                [100] * len(eval_data["rmsd"]["model_ligand_unassigned_reason"])
            )
            if not rmsds:
                continue
            metrics.setdefault("rmsd<2", []).append(np.mean([x < 2.0 for x in rmsds]))
            metrics.setdefault("rmsd<5", []).append(np.mean([x < 5.0 for x in rmsds]))
            metrics.setdefault("len_rmsd", []).append(len(rmsds))

    # Oracle = best value over the 5 models (min for raw rmsd, max otherwise).
    oracle = {k: min(v) if k == "rmsd" else max(v) for k, v in metrics.items()}
    avg = {k: sum(v) / len(v) for k, v in metrics.items()}
    top1 = {k: v[top_model] for k, v in metrics.items()}

    results = {}
    for metric_name in metrics:
        if metric_name.startswith("len_"):
            continue
        # Attach the per-target sample count so callers can verify that every
        # tool evaluated the same set of ligands/interfaces.
        if metric_name == "lddt_pli":
            length = metrics["len_lddt_pli"][0]
        elif metric_name in ("rmsd<2", "rmsd<5"):
            length = metrics["len_rmsd"][0]
        elif metric_name in ("dockq_>0.23", "dockq_>0.49"):
            length = metrics["len_dockq_"][0]
        else:
            length = 1
        results[metric_name] = {
            "oracle": oracle[metric_name],
            "average": avg[metric_name],
            "top1": top1[metric_name],
            "len": length,
        }

    return results
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def eval_models(
    chai_preds,
    chai_evals,
    af3_preds,
    af3_evals,
    boltz_preds,
    boltz_evals,
    boltz_preds_x,
    boltz_evals_x,
):
    """Evaluate all four tools on their common targets.

    Parameters are prediction/evaluation directories per tool. Returns a
    long-format DataFrame with columns ``tool``, ``target``, ``metric``,
    ``value``, keeping only (target, metric) pairs where every tool produced
    the metric over the same number of samples.
    """

    def _list_preds(folder):
        # Map lowercase target name -> prediction path, skipping hidden files.
        return {
            p.name.lower(): p
            for p in Path(folder).iterdir()
            if not p.name.lower().startswith(".")
        }

    def _safe_compute(fn, label, pred_path, eval_path, target):
        # Run one metric computation; on failure log and return None so the
        # caller can skip this target (mirrors the original best-effort skip).
        try:
            return fn(pred_path, eval_path, target)
        except Exception as e:
            import traceback

            traceback.print_exc()
            print(f"Error evaluating {label} {target}: {e}")
            return None

    # Load preds and make sure we have predictions for all models
    chai_preds_names = _list_preds(chai_preds)
    af3_preds_names = _list_preds(af3_preds)
    boltz_preds_names = _list_preds(boltz_preds)
    boltz_preds_names_x = _list_preds(boltz_preds_x)

    print("Chai preds", len(chai_preds_names))
    print("Af3 preds", len(af3_preds_names))
    print("Boltz preds", len(boltz_preds_names))
    print("Boltzx preds", len(boltz_preds_names_x))

    common = (
        set(chai_preds_names)
        & set(af3_preds_names)
        & set(boltz_preds_names)
        & set(boltz_preds_names_x)
    )

    # Remove examples in the validation set
    keys_to_remove = ["t1133", "h1134", "r1134s1", "t1134s2", "t1121", "t1123", "t1159"]
    common.difference_update(keys_to_remove)
    print("Common", len(common))

    # Create a dataframe with the following schema:
    # tool, name, metric, oracle, average, top1
    results = []
    for name in tqdm(common):
        af3_results = _safe_compute(
            compute_af3_metrics, "AF3", af3_preds_names[name], af3_evals, name
        )
        if af3_results is None:
            continue
        chai_results = _safe_compute(
            compute_chai_metrics, "Chai", chai_preds_names[name], chai_evals, name
        )
        if chai_results is None:
            continue
        boltz_results = _safe_compute(
            compute_boltz_metrics, "Boltz", boltz_preds_names[name], boltz_evals, name
        )
        if boltz_results is None:
            continue
        boltz_results_x = _safe_compute(
            compute_boltz_metrics,
            "Boltzx",
            boltz_preds_names_x[name],
            boltz_evals_x,
            name,
        )
        if boltz_results_x is None:
            continue

        # (display label, per-tool metrics) in the output row order.
        tool_results = [
            ("AF3", af3_results),
            ("Chai-1", chai_results),
            ("Boltz-1", boltz_results),
            ("Boltz-1x", boltz_results_x),
        ]
        for metric_name in af3_results:
            # BUG FIX: the original guard omitted boltz_results_x, so a metric
            # missing only from the Boltz-1x results raised KeyError below.
            if not (
                metric_name in chai_results
                and metric_name in boltz_results
                and metric_name in boltz_results_x
            ):
                print(
                    "Missing metric",
                    name,
                    metric_name,
                    metric_name in chai_results,
                    metric_name in boltz_results,
                    metric_name in boltz_results_x,
                )
                continue

            # Only keep metrics computed over the same sample count everywhere,
            # so the averages are comparable across tools.
            ref_len = af3_results[metric_name]["len"]
            if not all(
                res[metric_name]["len"] == ref_len for _, res in tool_results[1:]
            ):
                print(
                    "Different lengths",
                    name,
                    metric_name,
                    af3_results[metric_name]["len"],
                    chai_results[metric_name]["len"],
                    boltz_results[metric_name]["len"],
                    boltz_results_x[metric_name]["len"],
                )
                continue

            for label, res in tool_results:
                for mode, suffix in (("oracle", "oracle"), ("top1", "top-1")):
                    results.append(
                        {
                            "tool": f"{label} {suffix}",
                            "target": name,
                            "metric": metric_name,
                            "value": res[metric_name][mode],
                        }
                    )

    # Write the results to a file, ensure we only keep the target & metrics where we have all tools
    df = pd.DataFrame(results)
    return df
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
def eval_validity_checks(df):
    """Convert per-model physical-validity checks into long-format rows.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain columns ``tool`` (one of af3/chai/boltz1/boltz1x),
        ``model_idx``, ``pdb_id`` and ``valid``.

    Returns
    -------
    pd.DataFrame
        Rows with columns ``tool``, ``target``, ``metric``, ``value`` for
        both the top-1 model (model_idx == 0) and the oracle (best over all
        models), with metric fixed to "physical validity".
    """
    # Display names per tool; suffixed below with " top-1" / " oracle".
    base_names = {
        "af3": "AF3",
        "chai": "Chai-1",
        "boltz1": "Boltz-1",
        "boltz1x": "Boltz-1x",
    }

    # Top-1: validity of the first (most confident) model only.
    top1 = df[df["model_idx"] == 0]
    top1 = top1[["tool", "pdb_id", "valid"]]
    top1["tool"] = top1["tool"].apply(lambda t: f"{base_names[t]} top-1")
    top1 = top1.rename(columns={"pdb_id": "target", "valid": "value"})
    top1["metric"] = "physical validity"
    top1["target"] = top1["target"].apply(lambda t: t.lower())
    top1 = top1[["tool", "target", "metric", "value"]]

    # Oracle: a target is valid if any of its models passes the checks.
    # (The original also selected the unused model_idx column here.)
    oracle = df[["tool", "pdb_id", "valid"]]
    oracle = oracle.groupby(["tool", "pdb_id"])["valid"].max().reset_index()
    oracle = oracle.rename(columns={"pdb_id": "target", "valid": "value"})
    oracle["tool"] = oracle["tool"].apply(lambda t: f"{base_names[t]} oracle")
    oracle["metric"] = "physical validity"
    oracle = oracle[["tool", "target", "metric", "value"]]
    oracle["target"] = oracle["target"].apply(lambda t: t.lower())

    out = pd.concat([top1, oracle])
    return out
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
def bootstrap_ci(series, n_boot=1000, alpha=0.05):
    """Mean of ``series`` with a bootstrap confidence interval.

    Parameters
    ----------
    series : array-like
        Sample values (pandas Series, list, or ndarray) — generalized from
        the original pandas-only implementation.
    n_boot : int
        Number of bootstrap resamples.
    alpha : float
        Significance level; the interval covers ``1 - alpha`` (default 95%).

    Returns
    -------
    tuple
        ``(mean, lower, upper)``.
    """
    values = np.asarray(series, dtype=float)
    n = len(values)

    # Vectorized resampling: a single (n_boot, n) index draw replaces the
    # Python loop of pandas .sample() calls (same distribution, much faster).
    idx = np.random.randint(0, n, size=(n_boot, n))
    boot_means = values[idx].mean(axis=1)

    mean_val = values.mean()
    lower = np.percentile(boot_means, 100 * alpha / 2)
    upper = np.percentile(boot_means, 100 * (1 - alpha / 2))
    return mean_val, lower, upper
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def plot_data(desired_tools, desired_metrics, df, dataset, filename):
    """Plot mean metric values per tool with 95% bootstrap CIs.

    Parameters
    ----------
    desired_tools : list[str]
        Tool labels (values of ``df["tool"]``) to include.
    desired_metrics : list[str]
        Metric names to include, in display order.
    df : pd.DataFrame
        Long-format results with columns ``tool``, ``metric``, ``value``.
    dataset : str
        Dataset name used in the plot title.
    filename : str
        Path of the figure file to write.
    """
    filtered_df = df[
        df["tool"].isin(desired_tools) & df["metric"].isin(desired_metrics)
    ]

    # Apply bootstrap to each (tool, metric) group
    boot_stats = filtered_df.groupby(["tool", "metric"])["value"].apply(bootstrap_ci)

    # boot_stats is a Series of (mean, lower, upper) tuples; expand to columns.
    boot_stats = boot_stats.apply(pd.Series)
    boot_stats.columns = ["mean", "lower", "upper"]

    # Unstack to one column per tool, metrics in display order.
    # (Renamed from "plot_data", which shadowed this function's own name.)
    mean_data = boot_stats["mean"].unstack("tool").reindex(desired_metrics)
    lower_data = boot_stats["lower"].unstack("tool").reindex(desired_metrics)
    upper_data = boot_stats["upper"].unstack("tool").reindex(desired_metrics)

    # Fixed display order of the tools (bar order within each metric group).
    tool_order = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    mean_data = mean_data[tool_order]
    lower_data = lower_data[tool_order]
    upper_data = upper_data[tool_order]

    # Rename metrics for display
    renaming = {
        "lddt_pli": "Mean LDDT-PLI",
        "rmsd<2": "L-RMSD < 2A",
        "lddt": "Mean LDDT",
        "dockq_>0.23": "DockQ > 0.23",
        "physical validity": "Physical Validity",
    }
    mean_data = mean_data.rename(index=renaming)
    lower_data = lower_data.rename(index=renaming)
    upper_data = upper_data.rename(index=renaming)
    mean_vals = mean_data.values

    # One color per entry of tool_order.
    tool_colors = [
        "#994C00",  # AF3 oracle
        "#FFB55A",  # AF3 top-1
        "#931652",  # Chai-1 oracle
        "#FC8AD9",  # Chai-1 top-1
        "#188F52",  # Boltz-1 oracle
        "#86E935",  # Boltz-1 top-1
        "#004D80",  # Boltz-1x oracle
        "#55C2FF",  # Boltz-1x top-1
    ]

    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(mean_data.index))
    bar_spacing = 0.015
    total_width = 0.7
    # Adjust width to account for the spacing
    width = (total_width - (len(tool_order) - 1) * bar_spacing) / len(tool_order)

    for i, tool in enumerate(tool_order):
        # Each subsequent bar moves over by width + bar_spacing
        offsets = x - (total_width - width) / 2 + i * (width + bar_spacing)
        # Extract the means and asymmetric error bars for this tool
        tool_means = mean_data[tool].values
        yerr_lower = mean_vals[:, i] - lower_data.values[:, i]
        yerr_upper = upper_data.values[:, i] - mean_vals[:, i]
        tool_yerr = np.vstack([yerr_lower, yerr_upper])

        ax.bar(
            offsets,
            tool_means,
            width=width,
            color=tool_colors[i],
            label=tool,
            yerr=tool_yerr,
            capsize=2,
            error_kw={"elinewidth": 0.75},
        )

    ax.set_xticks(x)
    ax.set_xticklabels(mean_data.index, rotation=0)
    ax.set_ylabel("Value")
    ax.set_title(f"Performances on {dataset} with 95% CI (Bootstrap)")

    plt.tight_layout()
    ax.legend(loc="lower center", bbox_to_anchor=(0.5, 0.85), ncols=4, frameon=False)

    plt.savefig(filename)
    plt.show()
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def _evaluate_split(
    eval_folder, output_folder, split_dir, suffix, title, desired_tools, desired_metrics
):
    """Run the full evaluation pipeline for one dataset split.

    ``split_dir`` names the outputs/evals subdirectory (e.g. "test",
    "casp15"); ``suffix`` names the per-split artifacts
    (``physical_checks_{suffix}.csv``, ``results_{suffix}.csv``,
    ``plot_{suffix}.pdf``); ``title`` is the plot title.
    """
    chai_preds = eval_folder + f"outputs/{split_dir}/chai"
    chai_evals = eval_folder + f"evals/{split_dir}/chai"

    af3_preds = eval_folder + f"outputs/{split_dir}/af3"
    af3_evals = eval_folder + f"evals/{split_dir}/af3"

    boltz_preds = eval_folder + f"outputs/{split_dir}/boltz/predictions"
    boltz_evals = eval_folder + f"evals/{split_dir}/boltz"

    boltz_preds_x = eval_folder + f"outputs/{split_dir}/boltzx/predictions"
    boltz_evals_x = eval_folder + f"evals/{split_dir}/boltzx"

    # Physical-validity checks are precomputed per split.
    validity_checks = eval_folder + f"physical_checks_{suffix}.csv"
    df_validity_checks = eval_validity_checks(pd.read_csv(validity_checks))

    df = eval_models(
        chai_preds,
        chai_evals,
        af3_preds,
        af3_evals,
        boltz_preds,
        boltz_evals,
        boltz_preds_x,
        boltz_evals_x,
    )

    df = pd.concat([df, df_validity_checks]).reset_index(drop=True)
    df.to_csv(output_folder + f"results_{suffix}.csv", index=False)

    plot_data(
        desired_tools, desired_metrics, df, title, output_folder + f"plot_{suffix}.pdf"
    )


def main():
    """Evaluate all tools on the PDB test set and CASP15 (the two split
    sections were previously duplicated inline; see ``_evaluate_split``)."""
    eval_folder = "../../boltz_results_final/"
    output_folder = "../../boltz_results_final/"

    desired_tools = [
        "AF3 oracle",
        "AF3 top-1",
        "Chai-1 oracle",
        "Chai-1 top-1",
        "Boltz-1 oracle",
        "Boltz-1 top-1",
        "Boltz-1x oracle",
        "Boltz-1x top-1",
    ]
    desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2", "physical validity"]

    # Eval the test set
    _evaluate_split(
        eval_folder, output_folder, "test", "test", "PDB Test",
        desired_tools, desired_metrics,
    )

    # Eval CASP15
    _evaluate_split(
        eval_folder, output_folder, "casp15", "casp", "CASP15",
        desired_tools, desired_metrics,
    )
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()
|
protify/FastPLMs/boltz/scripts/eval/physcialsim_metrics.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pickle
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from boltz.data.mol import load_molecules
|
| 10 |
+
from boltz.data import const
|
| 11 |
+
from boltz.data.parse.mmcif_with_constraints import parse_mmcif
|
| 12 |
+
from multiprocessing import Pool
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def compute_torsion_angles(coords, torsion_index):
    """Compute signed dihedral angles, in radians, for atom quadruples.

    ``torsion_index`` holds four rows of atom indices (i, j, k, l); the
    returned value is the torsion about the j-k axis for each quadruple,
    with the sign taken from the orientation of the two plane normals
    relative to that axis.
    """
    p_i = coords[..., torsion_index[0], :]
    p_j = coords[..., torsion_index[1], :]
    p_k = coords[..., torsion_index[2], :]
    p_l = coords[..., torsion_index[3], :]

    b_ij = p_i - p_j
    b_kj = p_k - p_j
    b_kl = p_k - p_l

    # Normals of the (i, j, k) and (j, k, l) planes.
    normal_a = np.cross(b_ij, b_kj, axis=-1)
    normal_b = np.cross(b_kj, b_kl, axis=-1)
    norm_a = np.linalg.norm(normal_a, axis=-1)
    norm_b = np.linalg.norm(normal_b, axis=-1)

    # Sign of the dihedral from the triple product of the axis with the
    # cross of the two normals.
    sign = np.sign(np.sum(b_kj * np.cross(normal_a, normal_b, axis=-1), axis=-1))

    # Clip the cosine slightly inside [-1, 1] to keep arccos finite under
    # floating-point roundoff.
    cos_phi = np.clip(
        np.sum(normal_a * normal_b, axis=-1) / (norm_a * norm_b),
        -1 + 1e-8,
        1 - 1e-8,
    )
    return sign * np.arccos(cos_phi)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def check_ligand_distance_geometry(
    structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
    """Count distance-geometry violations for ligands in a structure.

    Compares inter-atom distances against RDKit-derived lower/upper bounds,
    with a relative tolerance (``*_buffer``) per constraint class.

    Parameters
    ----------
    structure
        Parsed structure exposing ``coords["coords"]`` and ``chains``.
    constraints
        Object exposing ``rdkit_bounds_constraints`` with keys
        ``atom_idxs``, ``is_bond``, ``is_angle``, ``lower_bound``,
        ``upper_bound``.
    bond_buffer, angle_buffer, clash_buffer : float
        Relative slack applied to the bounds for bonds, angles, and
        non-neighbor (clash) pairs respectively.

    Returns
    -------
    dict
        Violation and total counts per constraint class, plus the number of
        NONPOLYMER (ligand) chains.
    """
    coords = structure.coords["coords"]
    rdkit_bounds_constraints = constraints.rdkit_bounds_constraints
    # astype() already returns a copy, so the original .copy() calls were
    # redundant and are dropped.
    pair_index = rdkit_bounds_constraints["atom_idxs"].astype(np.int64).T
    bond_mask = rdkit_bounds_constraints["is_bond"].astype(bool)
    angle_mask = rdkit_bounds_constraints["is_angle"].astype(bool)
    upper_bounds = rdkit_bounds_constraints["upper_bound"].astype(np.float32)
    lower_bounds = rdkit_bounds_constraints["lower_bound"].astype(np.float32)

    dists = np.linalg.norm(coords[pair_index[0]] - coords[pair_index[1]], axis=-1)

    # Use explicit boolean operators; the original relied on "+" and "*"
    # acting as OR/AND on bool arrays, which is unidiomatic and fragile.
    bond_length_violations = (
        dists[bond_mask] <= lower_bounds[bond_mask] * (1.0 - bond_buffer)
    ) | (dists[bond_mask] >= upper_bounds[bond_mask] * (1.0 + bond_buffer))
    bond_angle_violations = (
        dists[angle_mask] <= lower_bounds[angle_mask] * (1.0 - angle_buffer)
    ) | (dists[angle_mask] >= upper_bounds[angle_mask] * (1.0 + angle_buffer))

    # Pairs that are neither bonded nor in an angle: only their lower bound
    # (steric clash) is checked.
    non_neighbor_mask = ~(bond_mask | angle_mask)
    internal_clash_violations = dists[non_neighbor_mask] <= lower_bounds[
        non_neighbor_mask
    ] * (1.0 - clash_buffer)

    num_ligands = sum(
        int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
        for chain in structure.chains
    )
    return {
        "num_ligands": num_ligands,
        "num_bond_length_violations": bond_length_violations.sum(),
        "num_bonds": bond_mask.sum(),
        "num_bond_angle_violations": bond_angle_violations.sum(),
        "num_angles": angle_mask.sum(),
        "num_internal_clash_violations": internal_clash_violations.sum(),
        "num_non_neighbors": non_neighbor_mask.sum(),
    }
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def check_ligand_stereochemistry(structure, constraints):
    """Count chirality and double-bond stereo violations.

    Predicted orientations are derived from torsion angles over the
    reference atom quadruples and compared against the expected
    R/S (``is_r``) and E/Z (``is_e``) labels in the constraints.
    """
    coords = structure.coords["coords"]

    chiral = constraints.chiral_atom_constraints
    stereo = constraints.stereo_bond_constraints

    # Keep only the reference quadruple per chiral center.
    chiral_ref = chiral["is_reference"]
    chiral_index = chiral["atom_idxs"].T[:, chiral_ref]
    expected_is_r = chiral["is_r"][chiral_ref]
    # Positive torsion <=> R configuration.
    observed_is_r = compute_torsion_angles(coords, chiral_index) > 0
    chiral_violations = observed_is_r != expected_is_r

    # Keep only the reference quadruple per stereo double bond.
    bond_ref = stereo["is_reference"]
    bond_index = stereo["atom_idxs"].T[:, bond_ref]
    expected_is_e = stereo["is_e"][bond_ref]
    # |torsion| > 90 degrees <=> E (trans) configuration.
    observed_is_e = np.abs(compute_torsion_angles(coords, bond_index)) > np.pi / 2
    bond_violations = observed_is_e != expected_is_e

    return {
        "num_chiral_atom_violations": chiral_violations.sum(),
        "num_chiral_atoms": chiral_index.shape[1],
        "num_stereo_bond_violations": bond_violations.sum(),
        "num_stereo_bonds": bond_index.shape[1],
    }
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _planarity_violations(coords, atom_index, buffer):
    """Return a boolean array: True where a group of atoms is NOT planar.

    For each index group, fits a best plane via SVD (the last right singular
    vector of the centered coordinates is the plane normal) and flags the
    group when any atom lies at distance >= ``buffer`` from that plane.
    """
    group_coords = coords[atom_index, :]
    centered = group_coords - group_coords.mean(axis=-2, keepdims=True)
    normals = np.linalg.svd(centered)[2][..., -1, :, None]
    dists = np.abs((centered @ normals).squeeze(axis=-1))
    return np.any(dists >= buffer, axis=-1)


def check_ligand_flatness(structure, constraints, buffer=0.25):
    """Count planarity violations for aromatic rings and double bonds.

    BUG FIX: the 5-ring branch previously used ``np.all(dists <= buffer)``,
    which counted *compliant* (planar) rings as violations; it now uses the
    same ``np.any(dists >= buffer)`` criterion as the 6-ring and double-bond
    branches.

    Parameters
    ----------
    structure
        Parsed structure; only ``structure.coords["coords"]`` is read.
    constraints
        Residue constraints with ``planar_ring_5_constraints``,
        ``planar_ring_6_constraints`` and ``planar_bond_constraints``.
    buffer : float
        Maximum allowed out-of-plane distance (Angstrom, presumably —
        TODO confirm units against the featurizer).

    Returns
    -------
    dict
        Violation and total counts for 5-rings, 6-rings and double bonds.
    """
    coords = structure.coords["coords"]

    ring_5_violations = _planarity_violations(
        coords, constraints.planar_ring_5_constraints["atom_idxs"], buffer
    )
    ring_6_violations = _planarity_violations(
        coords, constraints.planar_ring_6_constraints["atom_idxs"], buffer
    )
    bond_violations = _planarity_violations(
        coords, constraints.planar_bond_constraints["atom_idxs"], buffer
    )

    return {
        "num_planar_5_ring_violations": ring_5_violations.sum(),
        "num_planar_5_rings": ring_5_violations.shape[0],
        "num_planar_6_ring_violations": ring_6_violations.sum(),
        "num_planar_6_rings": ring_6_violations.shape[0],
        "num_planar_double_bond_violations": bond_violations.sum(),
        "num_planar_double_bonds": bond_violations.shape[0],
    }
def check_steric_clash(structure, molecules, buffer=0.25):
    """Count inter-chain steric clashes, bucketed by chain-type pair.

    Two chains clash when any cross-chain atom pair is closer than the sum
    of the two atoms' van der Waals radii scaled by ``(1 - buffer)``.
    Single-atom chains and covalently connected chain pairs are skipped.

    FIX: removed a dead statement (``np.array([a.GetAtomicNum() ...])``)
    whose result was discarded and which read ``token_atoms_ref`` leaked
    from the preceding loop — it would NameError on a structure with no
    residues.

    Parameters
    ----------
    structure
        Parsed structure with ``chains``, ``residues``, ``atoms``, ``bonds``
        and ``coords`` record arrays.
    molecules
        Mapping from residue name to an RDKit molecule, used to look up the
        atomic numbers for vdW radii.
    buffer : float
        Fractional tolerance applied to the vdW contact distance.

    Returns
    -------
    dict
        Pair and clash counters keyed ``sym_<type>`` for same-entity pairs
        and ``asym_<type_i>_<type_j>`` otherwise.
    """
    # Initialize every chain-type bucket so all keys exist even when unused.
    result = {}
    for type_i in const.chain_types:
        out_type_i = type_i.lower()
        out_type_i = out_type_i if out_type_i != "nonpolymer" else "ligand"
        result[f"num_chain_pairs_sym_{out_type_i}"] = 0
        result[f"num_chain_clashes_sym_{out_type_i}"] = 0
        for type_j in const.chain_types:
            out_type_j = type_j.lower()
            out_type_j = out_type_j if out_type_j != "nonpolymer" else "ligand"
            result[f"num_chain_pairs_asym_{out_type_i}_{out_type_j}"] = 0
            result[f"num_chain_clashes_asym_{out_type_i}_{out_type_j}"] = 0

    # Chain pairs joined by an explicit inter-chain bond are exempt.
    connected_chains = set()
    for bond in structure.bonds:
        if bond["chain_1"] != bond["chain_2"]:
            connected_chains.add(tuple(sorted((bond["chain_1"], bond["chain_2"]))))

    # Per-atom vdW radii, looked up via each residue's reference molecule.
    vdw_radii = []
    for res in structure.residues:
        mol = molecules[res["name"]]
        token_atoms = structure.atoms[
            res["atom_idx"] : res["atom_idx"] + res["atom_num"]
        ]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
        token_atoms_ref = [atom_name_to_ref[a["name"]] for a in token_atoms]
        vdw_radii.extend(
            [const.vdw_radii[a.GetAtomicNum() - 1] for a in token_atoms_ref]
        )
    vdw_radii = np.array(vdw_radii, dtype=np.float32)

    for i, chain_i in enumerate(structure.chains):
        for j, chain_j in enumerate(structure.chains):
            # Visit each unordered pair once; skip single-atom chains and
            # covalently connected pairs (i < j so (i, j) is already sorted).
            if (
                chain_i["atom_num"] == 1
                or chain_j["atom_num"] == 1
                or j <= i
                or (i, j) in connected_chains
            ):
                continue
            coords_i = structure.coords["coords"][
                chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
            ]
            coords_j = structure.coords["coords"][
                chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
            ]
            dists = np.linalg.norm(coords_i[:, None, :] - coords_j[None, :, :], axis=-1)
            radii_i = vdw_radii[
                chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
            ]
            radii_j = vdw_radii[
                chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
            ]
            radii_sum = radii_i[:, None] + radii_j[None, :]
            is_clashing = np.any(dists < radii_sum * (1.00 - buffer))
            type_i = const.chain_types[chain_i["mol_type"]].lower()
            type_j = const.chain_types[chain_j["mol_type"]].lower()
            type_i = type_i if type_i != "nonpolymer" else "ligand"
            type_j = type_j if type_j != "nonpolymer" else "ligand"
            # Same entity with equal atom counts -> symmetric (homo) pair.
            is_symmetric = (
                chain_i["entity_id"] == chain_j["entity_id"]
                and chain_i["atom_num"] == chain_j["atom_num"]
            )
            if is_symmetric:
                key = "sym_" + type_i
            else:
                key = "asym_" + type_i + "_" + type_j
            result["num_chain_pairs_" + key] += 1
            result["num_chain_clashes_" + key] += int(is_clashing)
    return result
| 218 |
+
|
| 219 |
+
# --- Script configuration -------------------------------------------------
# Local cache produced by the CCD processing script: a pickled dict of RDKit
# molecules plus a directory of per-molecule pickles.
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)

# Prediction output directories for each evaluated tool (cluster paths).
boltz1_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
    "/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)

# Only evaluate targets for which every tool produced predictions.
boltz1_pdb_ids = set(os.listdir(boltz1_dir))
boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
chai_pdb_ids = set(os.listdir(chai_dir))
af3_pdb_ids = set([pdb_id for pdb_id in os.listdir(af3_dir)])
common_pdb_ids = boltz1_pdb_ids & boltz1x_pdb_ids & chai_pdb_ids & af3_pdb_ids

tools = ["boltz1", "boltz1x", "chai", "af3"]
# Number of predicted models per target per tool.
num_samples = 5
|
| 248 |
+
def process_fn(key):
    """Parse one predicted structure and run all physical-plausibility checks.

    Parameters
    ----------
    key : tuple
        ``(tool, pdb_id, model_idx)`` identifying one predicted model; the
        tool name selects the directory layout of its prediction output.

    Returns
    -------
    dict
        Identification fields plus the metrics returned by the four
        ``check_*`` functions.

    Raises
    ------
    ValueError
        If ``tool`` is not one of the supported prediction tools.
    """
    tool, pdb_id, model_idx = key
    if tool == "boltz1":
        cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "boltz1x":
        cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "chai":
        cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
    elif tool == "af3":
        cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
    else:
        # Previously an unknown tool fell through to an UnboundLocalError.
        msg = f"Unknown tool: {tool}"
        raise ValueError(msg)

    parsed_structure = parse_mmcif(
        cif_path,
        ccd,
        moldir,
    )
    structure = parsed_structure.data
    constraints = parsed_structure.residue_constraints

    record = {
        "tool": tool,
        "pdb_id": pdb_id,
        "model_idx": model_idx,
    }
    record.update(check_ligand_distance_geometry(structure, constraints))
    record.update(check_ligand_stereochemistry(structure, constraints))
    record.update(check_ligand_flatness(structure, constraints))
    record.update(check_steric_clash(structure, molecules=ccd))
    return record
|
| 278 |
+
|
| 279 |
+
# Build one work item per (tool, target, model) combination.
keys = []
for tool in tools:
    for pdb_id in common_pdb_ids:
        for model_idx in range(num_samples):
            keys.append((tool, pdb_id, model_idx))

# One serial warm-up run before the pool — presumably to surface setup
# errors early; its return value is discarded. TODO confirm intent.
process_fn(keys[0])
records = []
with Pool(48) as p:
    with tqdm(total=len(keys)) as pbar:
        for record in p.imap_unordered(process_fn, keys):
            records.append(record)
            pbar.update(1)
df = pd.DataFrame.from_records(records)

# Aggregate per-model summary columns from the counter columns.
df["num_chain_clashes_all"] = df[
    [key for key in df.columns if "chain_clash" in key]
].sum(axis=1)
df["num_pairs_all"] = df[[key for key in df.columns if "chain_pair" in key]].sum(axis=1)
df["clash_free"] = df["num_chain_clashes_all"] == 0
df["valid_ligand"] = (
    df[[key for key in df.columns if "violation" in key]].sum(axis=1) == 0
)
df["valid"] = (df["clash_free"]) & (df["valid_ligand"])

df.to_csv("physical_checks_test.csv")
protify/FastPLMs/boltz/scripts/eval/run_evals.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import concurrent.futures
|
| 3 |
+
import subprocess
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# Shell template for OpenStructure's polymer comparison, run inside docker.
# Placeholders: {model_file}, {reference_file}, {output_path}, {mount}.
OST_COMPARE_STRUCTURE = r"""
#!/bin/bash
# https://openstructure.org/docs/2.7/actions/#ost-compare-structures

IMAGE_NAME=openstructure-0.2.8

command="compare-structures \
-m {model_file} \
-r {reference_file} \
--fault-tolerant \
--min-pep-length 4 \
--min-nuc-length 4 \
-o {output_path} \
--lddt --bb-lddt --qs-score --dockq \
--ics --ips --rigid-scores --patch-scores --tm-score"

sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
"""


# Shell template for OpenStructure's ligand comparison; same placeholders.
OST_COMPARE_LIGAND = r"""
#!/bin/bash
# https://openstructure.org/docs/2.7/actions/#ost-compare-structures

IMAGE_NAME=openstructure-0.2.8

command="compare-ligand-structures \
-m {model_file} \
-r {reference_file} \
--fault-tolerant \
--lddt-pli --rmsd \
--substructure-match \
-o {output_path}"

sudo docker run -u $(id -u):$(id -g) --rm --volume {mount}:{mount} $IMAGE_NAME $command
"""
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _run_ost(
    template: str,
    pred: Path,
    reference: Path,
    out_path: Path,
    mount: str,
    executable: str,
) -> None:
    """Render an OST docker command from *template* and run it via the shell.

    Output is captured and discarded; failures are not raised (check=False),
    matching the original best-effort behavior.
    """
    subprocess.run(
        template.format(
            model_file=str(pred),
            reference_file=str(reference),
            output_path=str(out_path),
            mount=mount,
        ),
        shell=True,  # noqa: S602
        check=False,
        executable=executable,
        capture_output=True,
    )


def evaluate_structure(
    name: str,
    pred: Path,
    reference: Path,
    outdir: str,
    mount: str,
    executable: str = "/bin/bash",
) -> None:
    """Evaluate polymer and ligand metrics for one predicted structure.

    Writes ``{name}.json`` (polymer) and ``{name}_ligand.json`` (ligand)
    into *outdir*; each step is skipped when its output already exists.
    """
    # Evaluate polymer metrics
    out_path = Path(outdir) / f"{name}.json"
    if out_path.exists():
        print(  # noqa: T201
            f"Skipping recomputation of {name} as protein json file already exists"
        )
    else:
        _run_ost(OST_COMPARE_STRUCTURE, pred, reference, out_path, mount, executable)

    # Evaluate ligand metrics
    out_path = Path(outdir) / f"{name}_ligand.json"
    if out_path.exists():
        print(f"Skipping recomputation of {name} as ligand json file already exists")  # noqa: T201
    else:
        _run_ost(OST_COMPARE_LIGAND, pred, reference, out_path, mount, executable)
|
| 94 |
+
|
| 95 |
+
def main(args):
    """Run OpenStructure evaluations for every prediction folder in args.data.

    The first (name, model) pair is evaluated synchronously so the docker
    image is pulled once before the thread pool fans out.

    Raises
    ------
    ValueError
        If ``args.format`` or ``args.testset`` is not a recognized value.
    """
    # Map each prediction folder by its lower-cased stem.
    files = list(args.data.iterdir())
    names = {f.stem.lower(): f for f in files}

    # Create the output directory
    args.outdir.mkdir(parents=True, exist_ok=True)

    first_item = True
    with concurrent.futures.ThreadPoolExecutor(args.max_workers) as executor:
        futures = []
        for name, folder in names.items():
            for model_id in range(5):
                # Locate the predicted model file for this tool's layout.
                if args.format == "af3":
                    pred_path = folder / f"seed-1_sample-{model_id}" / "model.cif"
                elif args.format == "chai":
                    pred_path = folder / f"pred.model_idx_{model_id}.cif"
                elif args.format == "boltz":
                    name_file = (
                        f"{name[0].upper()}{name[1:]}"
                        if args.testset == "casp"
                        else name.lower()
                    )
                    pred_path = folder / f"{name_file}_model_{model_id}.cif"
                else:
                    # Previously fell through to an UnboundLocalError.
                    msg = f"Unknown format: {args.format}"
                    raise ValueError(msg)

                # Locate the matching reference structure.
                if args.testset == "casp":
                    ref_path = args.pdb / f"{name[0].upper()}{name[1:]}.cif"
                elif args.testset == "test":
                    ref_path = args.pdb / f"{name.lower()}.cif.gz"
                else:
                    msg = f"Unknown testset: {args.testset}"
                    raise ValueError(msg)

                if first_item:
                    # Evaluate the first item in the first prediction
                    # Ensures that the docker image is downloaded
                    evaluate_structure(
                        name=f"{name}_model_{model_id}",
                        pred=str(pred_path),
                        reference=str(ref_path),
                        outdir=str(args.outdir),
                        mount=args.mount,
                        executable=args.executable,
                    )
                    first_item = False
                else:
                    future = executor.submit(
                        evaluate_structure,
                        name=f"{name}_model_{model_id}",
                        pred=str(pred_path),
                        reference=str(ref_path),
                        outdir=str(args.outdir),
                        mount=args.mount,
                        executable=args.executable,
                    )
                    futures.append(future)

        # Wait for all tasks to complete
        with tqdm(total=len(futures)) as pbar:
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
    # CLI: positional data/pdb/outdir paths plus tool-format options.
    cli = argparse.ArgumentParser()
    cli.add_argument("data", type=Path)
    cli.add_argument("pdb", type=Path)
    cli.add_argument("outdir", type=Path)
    cli.add_argument("--format", type=str, default="af3")
    cli.add_argument("--testset", type=str, default="casp")
    cli.add_argument("--mount", type=str)
    cli.add_argument("--executable", type=str, default="/bin/bash")
    main(cli.parse_args())
|
protify/FastPLMs/boltz/scripts/process/ccd.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compute conformers and symmetries for all the CCD molecules."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import multiprocessing
|
| 5 |
+
import pickle
|
| 6 |
+
import sys
|
| 7 |
+
from functools import partial
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import pandas as pd
|
| 11 |
+
import rdkit
|
| 12 |
+
from p_tqdm import p_uimap
|
| 13 |
+
from pdbeccdutils.core import ccd_reader
|
| 14 |
+
from pdbeccdutils.core.component import ConformerType
|
| 15 |
+
from rdkit import rdBase
|
| 16 |
+
from rdkit.Chem import AllChem
|
| 17 |
+
from rdkit.Chem.rdchem import Conformer, Mol
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_molecules(components: str) -> list[Mol]:
    """Load the CCD components file and tag each molecule with its CCD code.

    Parameters
    ----------
    components : str
        Path to the CCD components file.

    Returns
    -------
    list[Mol]
        RDKit molecules, each carrying a ``PDB_NAME`` property.

    """
    parsed: dict[str, ccd_reader.CCDReaderResult]
    parsed = ccd_reader.read_pdb_components_file(components)

    molecules = []
    for code, entry in parsed.items():
        rdkit_mol = entry.component.mol
        # Remember the CCD code on the molecule itself.
        rdkit_mol.SetProp("PDB_NAME", code)
        molecules.append(rdkit_mol)
    return molecules
| 45 |
+
|
| 46 |
+
def compute_3d(mol: Mol, version: str = "v3") -> bool:
    """Generate 3D coordinates in place using the ETKDG method.

    Adapted from `pdbeccdutils.core.component.Component`.

    Parameters
    ----------
    mol: Mol
        The RDKit molecule to embed (modified in place).
    version: str, optional
        The ETKDG version; anything other than "v3" falls back to v2.

    Returns
    -------
    bool
        True when a conformer was successfully embedded and tagged.

    """
    # "v2" and any unrecognized version both use ETKDGv2.
    params = AllChem.ETKDGv3() if version == "v3" else AllChem.ETKDGv2()
    params.clearConfs = False

    conf_id = -1
    try:
        conf_id = AllChem.EmbedMolecule(mol, params)
        AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)
    except RuntimeError:
        pass  # Force field issue here
    except ValueError:
        pass  # sanitization issue here

    if conf_id == -1:
        return False

    # Tag the new conformer so get_conformer() can find it later.
    conformer = mol.GetConformer(conf_id)
    conformer.SetProp("name", ConformerType.Computed.name)
    conformer.SetProp("coord_generation", f"ETKDG{version}")
    return True
| 92 |
+
|
| 93 |
+
def get_conformer(mol: Mol, c_type: ConformerType) -> Conformer:
|
| 94 |
+
"""Retrieve an rdkit object for a deemed conformer.
|
| 95 |
+
|
| 96 |
+
Taken from `pdbeccdutils.core.component.Component`.
|
| 97 |
+
|
| 98 |
+
Parameters
|
| 99 |
+
----------
|
| 100 |
+
mol: Mol
|
| 101 |
+
The molecule to process.
|
| 102 |
+
c_type: ConformerType
|
| 103 |
+
The conformer type to extract.
|
| 104 |
+
|
| 105 |
+
Returns
|
| 106 |
+
-------
|
| 107 |
+
Conformer
|
| 108 |
+
The desired conformer, if any.
|
| 109 |
+
|
| 110 |
+
Raises
|
| 111 |
+
------
|
| 112 |
+
ValueError
|
| 113 |
+
If there are no conformers of the given tyoe.
|
| 114 |
+
|
| 115 |
+
"""
|
| 116 |
+
for c in mol.GetConformers():
|
| 117 |
+
try:
|
| 118 |
+
if c.GetProp("name") == c_type.name:
|
| 119 |
+
return c
|
| 120 |
+
except KeyError: # noqa: PERF203
|
| 121 |
+
pass
|
| 122 |
+
|
| 123 |
+
msg = f"Conformer {c_type.name} does not exist."
|
| 124 |
+
raise ValueError(msg)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def compute_symmetries(mol: Mol) -> list[list[int]]:
    """Compute the self-symmetry permutations of a molecule.

    Hydrogens are removed first; atoms flagged as leaving atoms are excluded
    from the permutations. The result is also pickled and stored on the
    molecule as the hex-encoded "symmetries" property.

    Parameters
    ----------
    mol : Mol
        The molecule to process.

    Returns
    -------
    list[list[int]]
        The symmetries as a list of index permutations over the kept atoms.

    """
    mol = AllChem.RemoveHs(mol)

    # Map original atom index -> compact index, skipping leaving atoms.
    idx_map = {}
    for original_idx, atom in enumerate(mol.GetAtoms()):
        if int(atom.GetProp("leaving_atom")):
            continue
        idx_map[original_idx] = len(idx_map)

    # Self substructure matches enumerate all automorphisms; keep only those
    # that map the set of kept atoms onto itself.
    kept = set(idx_map.keys())
    permutations = []
    for match in mol.GetSubstructMatches(mol, uniquify=False):
        try:
            if {match[idx] for idx in idx_map} == kept:
                permutations.append(
                    [idx_map[idx] for idx in match if idx in idx_map]
                )
        except Exception:  # noqa: S110, PERF203, BLE001
            pass

    mol.SetProp("symmetries", pickle.dumps(permutations).hex())
    return permutations
| 168 |
+
|
| 169 |
+
def process(mol: Mol, output: str) -> tuple[str, str]:
    """Generate a conformer for one CCD component and pickle it to disk.

    Parameters
    ----------
    mol : Mol
        The molecule to process.
    output : str
        The directory to save the molecule pickle into.

    Returns
    -------
    str
        The CCD name of the component.
    str
        The conformer-generation outcome:
        "single", "computed", "ideal", or "failed".

    """
    name = mol.GetProp("PDB_NAME")

    if mol.GetNumAtoms() == 1:
        # Single atoms need no conformer generation.
        result = "single"
    else:
        try:
            if compute_3d(mol, version="v3"):
                # Verify the computed conformer is actually attached.
                _ = get_conformer(mol, ConformerType.Computed)
                result = "computed"
            else:
                # Fall back to the ideal CCD coordinates.
                _ = get_conformer(mol, ConformerType.Ideal)
                result = "ideal"
        except ValueError:
            # Neither a computed nor an ideal conformer is available.
            result = "failed"

    # Persist the (possibly embedded) molecule regardless of outcome.
    with (Path(output) / f"{name}.pkl").open("wb") as handle:
        pickle.dump(mol, handle)

    return name, result
| 217 |
+
|
| 218 |
+
def main(args: argparse.Namespace) -> None:
    """Process all CCD components: embed conformers, dump pickles and a summary.

    Reads the components file given by ``args.components``, writes one pickle
    per molecule under ``<outdir>/mols``, a ``results.csv`` with per-component
    outcomes, and a combined ``ccd.pkl`` dict of all successful molecules.
    """
    # Set property saving: keep all RDKit properties when pickling molecules.
    rdkit.Chem.SetDefaultPickleProperties(rdkit.Chem.PropertyPickleOptions.AllProps)

    # Load components
    print("Loading components")  # noqa: T201
    molecules = load_molecules(args.components)

    # Reset stdout and stderr, as pdbccdutils messes with them
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__

    # Disable rdkit warnings (blocker must stay alive for the suppression).
    blocker = rdBase.BlockLogs()  # noqa: F841

    # Setup processing function
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    mol_output = outdir / "mols"
    mol_output.mkdir(parents=True, exist_ok=True)
    process_fn = partial(process, output=str(mol_output))

    # Process the files in parallel
    print("Processing components")  # noqa: T201
    metadata = []

    # Check if we can run in parallel
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(molecules)))
    parallel = num_processes > 1

    if parallel:
        # p_uimap yields (name, result) tuples in completion order.
        for name, result in p_uimap(
            process_fn,
            molecules,
            num_cpus=num_processes,
        ):
            metadata.append({"name": name, "result": result})
    else:
        for mol in tqdm(molecules):
            name, result = process_fn(mol)
            metadata.append({"name": name, "result": result})

    # Load and group outputs: re-read each per-molecule pickle (workers wrote
    # them to disk) and collect everything that did not fail.
    molecules = {}
    for item in metadata:
        if item["result"] == "failed":
            continue

        # Load the mol file
        path = mol_output / f"{item['name']}.pkl"
        with path.open("rb") as f:
            mol = pickle.load(f)  # noqa: S301
        molecules[item["name"]] = mol

    # Dump metadata
    path = outdir / "results.csv"
    metadata = pd.DataFrame(metadata)
    metadata.to_csv(path)

    # Dump the components
    path = outdir / "ccd.pkl"
    with path.open("wb") as f:
        pickle.dump(molecules, f)
| 284 |
+
|
| 285 |
+
if __name__ == "__main__":
    # CLI for the CCD conformer-processing script.
    cli = argparse.ArgumentParser()
    cli.add_argument("--components", type=str)
    cli.add_argument("--outdir", type=str)
    cli.add_argument(
        "--num_processes",
        type=int,
        default=multiprocessing.cpu_count(),
    )
    main(cli.parse_args())
|
protify/FastPLMs/boltz/scripts/process/cluster.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Create a mapping from structure and chain ID to MSA indices."""
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
import pickle
|
| 7 |
+
import subprocess
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from Bio import SeqIO
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def hash_sequence(seq: str) -> str:
    """Return the hex SHA-256 digest of a sequence string."""
    digest = hashlib.sha256()
    digest.update(seq.encode())
    return digest.hexdigest()
| 18 |
+
|
| 19 |
+
def main(args: argparse.Namespace) -> None:
    """Create a sequence/ligand clustering and save it as JSON.

    Proteins are clustered with mmseqs at 40% identity; short sequences,
    nucleotide sequences and ligand CCD codes each become their own
    singleton cluster. The result maps item id -> cluster id.
    """
    # Set output directory
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Read all input sequences from the fasta file.
    with Path(args.sequences).open("r") as f:
        data = list(SeqIO.parse(f, "fasta"))

    proteins = set()
    shorts = set()
    nucleotides = set()

    # Separate the sequences into proteins, nucleotides and short sequences.
    # Short sequences cause a bug in the clustering, so they are separated.
    # NOTE(review): a genuine peptide composed only of A/C/G/T/N residues
    # would be classified as a nucleotide here — confirm this is acceptable.
    for seq in data:
        if set(str(seq.seq)).issubset({"A", "C", "G", "T", "U", "N"}):
            nucleotides.add(str(seq.seq).strip())
        elif len(str(seq.seq).strip()) < 10:  # noqa: PLR2004
            shorts.add(str(seq.seq).strip())
        else:
            proteins.add(str(seq.seq).strip())

    # Run mmseqs on the protein data; fasta headers are sequence hashes.
    proteins = [f">{hash_sequence(seq)}\n{seq}" for seq in proteins]
    with (outdir / "proteins.fasta").open("w") as f:
        f.write("\n".join(proteins))

    subprocess.run(
        f"{args.mmseqs} easy-cluster {outdir / 'proteins.fasta'} {outdir / 'clust_prot'} {outdir / 'tmp'} --min-seq-id 0.4",  # noqa: E501
        shell=True,  # noqa: S602
        check=True,
    )

    # Load protein clusters: tsv of (cluster representative, member) pairs.
    clustering_path = outdir / "clust_prot_cluster.tsv"
    protein_data = pd.read_csv(clustering_path, sep="\t", header=None)
    clusters = protein_data[0]
    items = protein_data[1]
    clustering = dict(zip(list(items), list(clusters)))

    # Each short sequence is given an id (its own singleton cluster)
    for short in shorts:
        short_id = hash_sequence(short)
        clustering[short_id] = short_id

    # Each unique rna sequence is given an id
    for nucl in nucleotides:
        nucl_id = hash_sequence(nucl)
        clustering[nucl_id] = nucl_id

    # Load ligand data (pickled CCD dictionary; keys are CCD codes).
    with Path(args.ccd).open("rb") as handle:
        ligand_data = pickle.load(handle)  # noqa: S301

    # Each unique ligand CCD is given an id
    for ccd_code in ligand_data:
        clustering[ccd_code] = ccd_code

    # Save clustering
    with (outdir / "clustering.json").open("w") as handle:
        json.dump(clustering, handle)
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
    # CLI for the clustering script. Help strings corrected: --sequences is
    # a mixed fasta (proteins and nucleotides), and --ccd is the pickled CCD
    # dictionary, not a fasta file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sequences",
        type=str,
        help="Path to the input sequences fasta.",
        required=True,
    )
    parser.add_argument(
        "--ccd",
        type=str,
        help="Path to the CCD components pickle.",
        required=True,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Output directory.",
        required=True,
    )
    parser.add_argument(
        "--mmseqs",
        type=str,
        help="Path to mmseqs program.",
        default="mmseqs",
    )
    args = parser.parse_args()
    main(args)
|
protify/FastPLMs/boltz/scripts/process/mmcif.py
ADDED
|
@@ -0,0 +1,1123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
from dataclasses import dataclass, replace
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import gemmi
|
| 6 |
+
import numpy as np
|
| 7 |
+
from rdkit import rdBase
|
| 8 |
+
from rdkit.Chem import AllChem
|
| 9 |
+
from rdkit.Chem.rdchem import Conformer, Mol
|
| 10 |
+
from sklearn.neighbors import KDTree
|
| 11 |
+
|
| 12 |
+
from boltz.data import const
|
| 13 |
+
from boltz.data.types import (
|
| 14 |
+
Atom,
|
| 15 |
+
Bond,
|
| 16 |
+
Chain,
|
| 17 |
+
Connection,
|
| 18 |
+
Interface,
|
| 19 |
+
Residue,
|
| 20 |
+
Structure,
|
| 21 |
+
StructureInfo,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
####################################################################################################
|
| 25 |
+
# DATACLASSES
|
| 26 |
+
####################################################################################################
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True, slots=True)
class ParsedAtom:
    """A parsed atom object.

    Intermediate, immutable representation of a single atom built while
    parsing an mmCIF structure.
    """

    name: str  # Atom name (e.g. "CA"), taken from the reference component
    element: int  # Atomic number
    charge: int  # Formal charge
    coords: tuple[float, float, float]  # Experimental coordinates; (0, 0, 0) when absent
    conformer: tuple[float, float, float]  # Reference-conformer coordinates
    is_present: bool  # Whether the atom was resolved in the structure
    chirality: int  # Chirality type id (see const.chirality_type_ids)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass(frozen=True, slots=True)
class ParsedBond:
    """A parsed bond object.

    Intra-residue bond between two atoms, indexed within the residue.
    """

    atom_1: int  # Index of the first atom (the smaller of the pair)
    atom_2: int  # Index of the second atom (the larger of the pair)
    type: int  # Bond type id (const.bond_type_ids; unknown-bond id if unmapped)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass(frozen=True, slots=True)
class ParsedResidue:
    """A parsed residue object.

    Intermediate representation of one residue (standard polymer residue
    or CCD ligand component) produced during mmCIF parsing.
    """

    name: str  # CCD component / residue name (e.g. "ALA" or a ligand code)
    type: int  # Token id (const.token_ids for standard residues, unknown id otherwise)
    idx: int  # Residue index within the parsed chain
    atoms: list[ParsedAtom]  # Atoms, in reference-component order
    bonds: list[ParsedBond]  # Intra-residue bonds (empty for standard residues)
    # NOTE(review): the parsers assign "<num><icode>" strings here (or None
    # for missing residues) — the Optional[int] annotation looks inaccurate; confirm.
    orig_idx: Optional[int]
    atom_center: int  # Index of the center atom (0 placeholder for ligands)
    atom_disto: int  # Index of the distogram atom (0 placeholder for ligands)
    is_standard: bool  # True for standard polymer residues
    is_present: bool  # False when the residue is missing from the structure
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@dataclass(frozen=True, slots=True)
class ParsedChain:
    """A parsed chain object.

    One polymer or ligand chain, together with its entity and residues.
    """

    name: str  # Chain / subchain identifier
    entity: str  # Entity name
    # NOTE(review): parse_polymer passes const.chain_type_ids[...] here,
    # which looks like an int id — the str annotation may be inaccurate; confirm.
    type: str
    residues: list[ParsedResidue]  # Parsed residues, in sequence order
    # NOTE(review): parse_polymer passes gemmi.one_letter_code(...), a str —
    # the list[str] annotation may be inaccurate; confirm.
    sequence: list[str]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass(frozen=True, slots=True)
class ParsedConnection:
    """A parsed connection object.

    A (covalent) link between two residues of two parsed chains.
    """

    chain_1: str  # Name of the first chain (subchain id)
    chain_2: str  # Name of the second chain (subchain id)
    residue_index_1: int  # Index into chain_1's residue list
    residue_index_2: int  # Index into chain_2's residue list
    # NOTE(review): annotated str despite the "index" name — presumably an
    # atom name; verify against the (partially visible) parse_connection.
    atom_index_1: str
    atom_index_2: str
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass(frozen=True, slots=True)
class ParsedStructure:
    """A parsed structure object.

    Bundles the converted structure arrays with their metadata and the
    covalent ligands found during parsing.
    """

    data: Structure  # Converted structure record
    info: StructureInfo  # Metadata (dates, resolution, method, ...)
    # NOTE(review): compute_covalent_ligands returns a set of subchain
    # *names* (strings) — the list[int] annotation may be inaccurate; confirm.
    covalents: list[int]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
####################################################################################################
|
| 100 |
+
# HELPERS
|
| 101 |
+
####################################################################################################
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_dates(block: "gemmi.cif.Block") -> tuple[str, str, str]:
    """Get the deposited, released, and last revision dates.

    Each lookup is attempted independently: the original version wrapped
    all three in a single ``contextlib.suppress`` block, so a missing
    deposition record silently left the release and revision dates empty
    even when they were present in the file.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The deposited date, or "" if unavailable.
    str
        The released date (first revision-history entry), or "".
    str
        The last revision date (last revision-history entry), or "".

    """
    deposited = "_pdbx_database_status.recvd_initial_deposition_date"
    revision = "_pdbx_audit_revision_history.revision_date"

    deposit_date = revision_date = release_date = ""

    # Parse each date on its own: a block may legitimately lack one
    # category while still providing the others.
    with contextlib.suppress(Exception):
        deposit_date = block.find([deposited])[0][0]
    with contextlib.suppress(Exception):
        # First revision-history row corresponds to the initial release.
        release_date = block.find([revision])[0][0]
    with contextlib.suppress(Exception):
        # Last revision-history row is the most recent update.
        revision_date = block.find([revision])[-1][0]

    return deposit_date, release_date, revision_date
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_resolution(block: gemmi.cif.Block) -> float:
    """Get the resolution from a gemmi structure.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    float
        The resolution, or 0.0 if no category yields a usable value.

    """
    # Try each category in order of preference; the first one that parses
    # cleanly wins. A missing category raises inside the suppress block,
    # and we simply fall through to the next key.
    candidate_keys = (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    )
    for key in candidate_keys:
        with contextlib.suppress(Exception):
            return float(block.find([key])[0].str(0))
    return 0.0
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def get_method(block: gemmi.cif.Block) -> str:
    """Get the method from a gemmi structure.

    Parameters
    ----------
    block : gemmi.cif.Block
        The block to process.

    Returns
    -------
    str
        The experimental method(s), lowercase and comma-separated, or ""
        if the category is absent.

    """
    method = ""
    with contextlib.suppress(Exception):
        # A structure may list several experimental methods; fold them
        # into one lowercase, comma-separated string.
        rows = block.find(["_exptl.method"])
        method = ",".join(row.str(0).lower() for row in rows)
    return method
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Convert an atom name to a standard format.

    Each character is encoded as ``ord(c) - 32`` and the result is
    zero-padded to four slots.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The converted atom name.

    """
    codes = [ord(char) - 32 for char in name.strip()]
    # Pad with zeros up to four entries (no-op for longer names).
    codes.extend([0] * (4 - len(codes)))
    return tuple(codes)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_unk_token(dtype: gemmi.PolymerType) -> str:
    """Get the unknown token for a given entity type.

    Parameters
    ----------
    dtype : gemmi.PolymerType
        The polymer type.

    Returns
    -------
    str
        The unknown token.

    Raises
    ------
    ValueError
        If the polymer type is not peptide, DNA, or RNA.

    """
    # Guard-clause dispatch over the three supported polymer types.
    if dtype == gemmi.PolymerType.PeptideL:
        return const.unk_token["PROTEIN"]
    if dtype == gemmi.PolymerType.Dna:
        return const.unk_token["DNA"]
    if dtype == gemmi.PolymerType.Rna:
        return const.unk_token["RNA"]

    msg = f"Unknown polymer type: {dtype}"
    raise ValueError(msg)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def get_conformer(mol: Mol) -> Conformer:
    """Retrieve an rdkit object for a deemed conformer.

    Inspired by `pdbeccdutils.core.component.Component`.

    Prefers a conformer named "Computed"; falls back to one named
    "Ideal". Conformers without a "name" property are skipped.

    Parameters
    ----------
    mol: Mol
        The molecule to process.

    Returns
    -------
    Conformer
        The desired conformer, if any.

    Raises
    ------
    ValueError
        If there are no conformers of the given type.

    """
    for wanted in ("Computed", "Ideal"):
        for candidate in mol.GetConformers():
            try:
                label = candidate.GetProp("name")
            except KeyError:
                # Unnamed conformer: not a match for either label.
                continue
            if label == wanted:
                return candidate

    msg = "Conformer does not exist."
    raise ValueError(msg)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def compute_covalent_ligands(
    connections: list[gemmi.Connection],
    subchain_map: dict[tuple[str, int], str],
    entities: dict[str, gemmi.Entity],
) -> set[str]:
    """Compute the covalent ligands from a list of connections.

    Parameters
    ----------
    connections: List[gemmi.Connection]
        The connections to process.
    subchain_map: dict[tuple[str, int], str]
        The mapping from chain, residue index to subchain name.
    entities: dict[str, gemmi.Entity]
        The entities in the structure.

    Returns
    -------
    set
        The covalent ligand subchains.

    """

    def subchain_of(partner):
        # Build the "<num><icode>" residue key and map it to a subchain.
        seqid = partner.res_id.seqid
        res_key = str(seqid.num) + str(seqid.icode).strip()
        return subchain_map[(partner.chain_name, res_key)]

    ligand_kinds = {"NonPolymer", "Branched"}
    covalent_chain_ids: set[str] = set()

    for conn in connections:
        # Only covalent links are relevant here.
        if conn.type.name != "Covale":
            continue

        sub_1 = subchain_of(conn.partner1)
        sub_2 = subchain_of(conn.partner2)

        # Keep only non-polymer or branched partners.
        for sub in (sub_1, sub_2):
            if entities[sub].entity_type.name in ligand_kinds:
                covalent_chain_ids.add(sub)

    return covalent_chain_ids
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def compute_interfaces(atom_data: np.ndarray, chain_data: np.ndarray) -> np.ndarray:
    """Compute the chain-chain interfaces from a gemmi structure.

    Two chains form an interface when any pair of their present atoms
    lies within ``const.atom_interface_cutoff``.

    Parameters
    ----------
    atom_data : np.ndarray
        Structured atom array with "coords" and "is_present" fields.
    chain_data : np.ndarray
        Structured chain array with an "atom_num" field.

    Returns
    -------
    np.ndarray
        Unique (min, max) chain-index pairs, with dtype ``Interface``.

    """
    # Assign each atom the index of the chain it belongs to.
    per_chain_counts = [chain["atom_num"] for chain in chain_data]
    atom_chain_ids = np.repeat(np.arange(len(chain_data)), per_chain_counts)

    # Keep only atoms actually resolved in the structure.
    present = atom_data["is_present"]
    positions = atom_data["coords"][present]
    atom_chain_ids = atom_chain_ids[present]

    # Radius query over all present atoms.
    tree = KDTree(positions, metric="euclidean")
    neighbors = tree.query_radius(positions, const.atom_interface_cutoff)

    # Collect unordered chain pairs that come within the cutoff,
    # normalized to (min, max) python-int tuples.
    contacts: set[tuple[int, int]] = set()
    for own_chain, hits in zip(atom_chain_ids, neighbors):
        for other_chain in np.unique(atom_chain_ids[hits]):
            if other_chain != own_chain:
                lo = int(min(own_chain, other_chain))
                hi = int(max(own_chain, other_chain))
                contacts.add((lo, hi))

    return np.array(list(contacts), dtype=Interface)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
####################################################################################################
|
| 369 |
+
# PARSING
|
| 370 |
+
####################################################################################################
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def parse_ccd_residue(  # noqa: PLR0915, C901
    name: str,
    components: dict[str, Mol],
    res_idx: int,
    gemmi_mol: Optional[gemmi.Residue] = None,
    is_covalent: bool = False,
) -> Optional[ParsedResidue]:
    """Parse an MMCIF ligand.

    Looks the component up in the preprocessed CCD dictionary, iterates
    its atoms in reference order, and fills in experimental coordinates
    (matched by atom name) where the residue is present in the PDB.

    Parameters
    ----------
    name: str
        The name of the molecule to parse.
    components : dict
        The preprocessed PDB components dictionary.
    res_idx : int
        The residue index.
    gemmi_mol : Optional[gemmi.Residue]
        The PDB molecule, as a gemmi Residue object, if any.
    is_covalent : bool
        Whether the residue is covalently bound; leaving atoms that are
        absent from the PDB are skipped in that case.

    Returns
    -------
    ParsedResidue, optional
        The output ParsedResidue, if successful.
        NOTE(review): every path in this body returns a ParsedResidue or
        raises — confirm whether the Optional return is still needed.

    """
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
    # Check if we have a PDB structure for this residue,
    # it could be a missing residue from the sequence
    is_present = gemmi_mol is not None

    # Save original index (required for parsing connections)
    if is_present:
        # "<num><icode>" string, matching the keys used by the
        # connection-parsing code.
        orig_idx = gemmi_mol.seqid
        orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
    else:
        orig_idx = None

    # Get reference component
    ref_mol = components[name]

    # Remove hydrogens
    ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)

    # Check if this is a single atom CCD residue
    if ref_mol.GetNumAtoms() == 1:
        pos = (0, 0, 0)
        if is_present:
            # Single-atom residue: take the first (only) gemmi atom.
            pos = (
                gemmi_mol[0].pos.x,
                gemmi_mol[0].pos.y,
                gemmi_mol[0].pos.z,
            )
        ref_atom = ref_mol.GetAtoms()[0]
        chirality_type = const.chirality_type_ids.get(
            str(ref_atom.GetChiralTag()), unk_chirality
        )
        atom = ParsedAtom(
            name=ref_atom.GetProp("name"),
            element=ref_atom.GetAtomicNum(),
            charge=ref_atom.GetFormalCharge(),
            coords=pos,
            conformer=(0, 0, 0),
            is_present=is_present,
            chirality=chirality_type,
        )
        unk_prot_id = const.unk_token_ids["PROTEIN"]
        residue = ParsedResidue(
            name=name,
            type=unk_prot_id,
            atoms=[atom],
            bonds=[],
            idx=res_idx,
            orig_idx=orig_idx,
            atom_center=0,  # Placeholder, no center
            atom_disto=0,  # Placeholder, no center
            is_standard=False,
            is_present=is_present,
        )
        return residue

    # If multi-atom, start by getting the PDB coordinates
    pdb_pos = {}
    if is_present:
        # Match atoms based on names
        for atom in gemmi_mol:
            atom: gemmi.Atom
            pos = (atom.pos.x, atom.pos.y, atom.pos.z)
            pdb_pos[atom.name] = pos

    # Get reference conformer coordinates
    # (raises ValueError if the component has no usable conformer).
    conformer = get_conformer(ref_mol)

    # Parse each atom in order of the reference mol
    atoms = []
    atom_idx = 0
    idx_map = {}  # Used for bonds later

    for i, atom in enumerate(ref_mol.GetAtoms()):
        # Get atom name, charge, element and reference coordinates
        atom_name = atom.GetProp("name")
        charge = atom.GetFormalCharge()
        element = atom.GetAtomicNum()
        ref_coords = conformer.GetAtomPosition(atom.GetIdx())
        ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
        chirality_type = const.chirality_type_ids.get(
            str(atom.GetChiralTag()), unk_chirality
        )

        # If the atom is a leaving atom, skip if not in the PDB and is_covalent
        # NOTE(review): assumes every CCD atom carries a "leaving_atom"
        # property set during preprocessing — confirm against the CCD script.
        if (
            int(atom.GetProp("leaving_atom")) == 1
            and is_covalent
            and (atom_name not in pdb_pos)
        ):
            continue

        # Get PDB coordinates, if any
        coords = pdb_pos.get(atom_name)
        if coords is None:
            atom_is_present = False
            coords = (0, 0, 0)
        else:
            atom_is_present = True

        # Add atom to list
        atoms.append(
            ParsedAtom(
                name=atom_name,
                element=element,
                charge=charge,
                coords=coords,
                conformer=ref_coords,
                is_present=atom_is_present,
                chirality=chirality_type,
            )
        )
        # Map the rdkit atom index to the (possibly compacted) output
        # index, so bonds can be remapped after skips.
        idx_map[i] = atom_idx
        atom_idx += 1

    # Load bonds
    bonds = []
    unk_bond = const.bond_type_ids[const.unk_bond_type]
    for bond in ref_mol.GetBonds():
        idx_1 = bond.GetBeginAtomIdx()
        idx_2 = bond.GetEndAtomIdx()

        # Skip bonds with atoms ignored
        if (idx_1 not in idx_map) or (idx_2 not in idx_map):
            continue

        idx_1 = idx_map[idx_1]
        idx_2 = idx_map[idx_2]
        # Normalize so atom_1 <= atom_2.
        start = min(idx_1, idx_2)
        end = max(idx_1, idx_2)
        bond_type = bond.GetBondType().name
        bond_type = const.bond_type_ids.get(bond_type, unk_bond)
        bonds.append(ParsedBond(start, end, bond_type))

    unk_prot_id = const.unk_token_ids["PROTEIN"]
    return ParsedResidue(
        name=name,
        type=unk_prot_id,
        atoms=atoms,
        bonds=bonds,
        idx=res_idx,
        atom_center=0,
        atom_disto=0,
        orig_idx=orig_idx,
        is_standard=False,
        is_present=is_present,
    )
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def parse_polymer(  # noqa: C901, PLR0915, PLR0912
    polymer: gemmi.ResidueSpan,
    polymer_type: gemmi.PolymerType,
    sequence: list[str],
    chain_id: str,
    entity: str,
    components: dict[str, Mol],
) -> Optional[ParsedChain]:
    """Process a gemmi Polymer into a chain object.

    Performs alignment of the full sequence to the polymer
    residues. Loads coordinates and masks for the atoms in
    the polymer, following the ordering in const.atom_order.

    Parameters
    ----------
    polymer : gemmi.ResidueSpan
        The polymer to process.
    polymer_type : gemmi.PolymerType
        The polymer type.
    sequence : list[str]
        The full sequence of the polymer.
    chain_id : str
        The chain identifier.
    entity : str
        The entity name.
    components : dict[str, Mol]
        The preprocessed PDB components dictionary.

    Returns
    -------
    ParsedChain, optional
        The output chain, if successful.

    Raises
    ------
    ValueError
        If the alignment fails.

    """
    # Get unknown chirality token
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    # Ignore microheterogeneities (pick the first monomer)
    sequence = [gemmi.Entity.first_mon(item) for item in sequence]

    # Align full sequence to polymer residues
    # This is a simple way to handle all the different numbering schemes
    result = gemmi.align_sequence_to_polymer(
        sequence,
        polymer,
        polymer_type,
        gemmi.AlignmentScoring(),
    )

    # Get coordinates and masks
    # i tracks the next unconsumed structure residue; j walks the full sequence.
    i = 0
    ref_res = set(const.tokens)
    parsed = []
    for j, match in enumerate(result.match_string):
        # Get residue name from sequence
        res_name = sequence[j]

        # Check if we have a match in the structure
        res = None
        name_to_atom = {}

        if match == "|":
            # Get pdb residue
            res = polymer[i]
            name_to_atom = {a.name.upper(): a for a in res}

            # Double check the match
            if res.name != res_name:
                msg = "Alignment mismatch!"
                raise ValueError(msg)

            # Increment polymer index
            i += 1

        # Map MSE to MET, put the selenium atom in the sulphur column
        if res_name == "MSE":
            res_name = "MET"
            if "SE" in name_to_atom:
                name_to_atom["SD"] = name_to_atom["SE"]

        # Handle non-standard residues
        elif res_name not in ref_res:
            # Fall back to generic CCD parsing, marked covalent since the
            # residue is embedded in the polymer.
            residue = parse_ccd_residue(
                name=res_name,
                components=components,
                res_idx=j,
                gemmi_mol=res,
                is_covalent=True,
            )
            parsed.append(residue)
            continue

        # Load regular residues
        ref_mol = components[res_name]
        ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
        ref_conformer = get_conformer(ref_mol)

        # Only use reference atoms set in constants
        ref_name_to_atom = {a.GetProp("name"): a for a in ref_mol.GetAtoms()}
        ref_atoms = [ref_name_to_atom[a] for a in const.ref_atoms[res_name]]

        # Iterate, always in the same order
        atoms: list[ParsedAtom] = []

        for ref_atom in ref_atoms:
            # Get atom name
            atom_name = ref_atom.GetProp("name")
            idx = ref_atom.GetIdx()

            # Get conformer coordinates
            ref_coords = ref_conformer.GetAtomPosition(idx)
            ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)

            # Get coordinates from the PDB structure, if resolved
            if atom_name in name_to_atom:
                atom = name_to_atom[atom_name]
                atom_is_present = True
                coords = (atom.pos.x, atom.pos.y, atom.pos.z)
            else:
                atom_is_present = False
                coords = (0, 0, 0)

            # Add atom to list
            atoms.append(
                ParsedAtom(
                    name=atom_name,
                    element=ref_atom.GetAtomicNum(),
                    charge=ref_atom.GetFormalCharge(),
                    coords=coords,
                    conformer=ref_coords,
                    is_present=atom_is_present,
                    chirality=const.chirality_type_ids.get(
                        str(ref_atom.GetChiralTag()), unk_chirality
                    ),
                )
            )

        # Fix naming errors in arginine residues where NH2 is
        # incorrectly assigned to be closer to CD than NH1
        if (res is not None) and (res_name == "ARG"):
            ref_atoms: list[str] = const.ref_atoms["ARG"]
            cd = atoms[ref_atoms.index("CD")]
            nh1 = atoms[ref_atoms.index("NH1")]
            nh2 = atoms[ref_atoms.index("NH2")]

            cd_coords = np.array(cd.coords)
            nh1_coords = np.array(nh1.coords)
            nh2_coords = np.array(nh2.coords)

            # Only swap when all three atoms are resolved and NH1 is
            # genuinely farther from CD than NH2.
            if all(atom.is_present for atom in (cd, nh1, nh2)) and (
                np.linalg.norm(nh1_coords - cd_coords)
                > np.linalg.norm(nh2_coords - cd_coords)
            ):
                atoms[ref_atoms.index("NH1")] = replace(nh1, coords=nh2.coords)
                atoms[ref_atoms.index("NH2")] = replace(nh2, coords=nh1.coords)

        # Add residue to parsed list
        if res is not None:
            # "<num><icode>" string, matching the connection-parsing keys.
            orig_idx = res.seqid
            orig_idx = str(orig_idx.num) + str(orig_idx.icode).strip()
        else:
            orig_idx = None

        atom_center = const.res_to_center_atom_id[res_name]
        atom_disto = const.res_to_disto_atom_id[res_name]
        parsed.append(
            ParsedResidue(
                name=res_name,
                type=const.token_ids[res_name],
                atoms=atoms,
                bonds=[],
                idx=j,
                atom_center=atom_center,
                atom_disto=atom_disto,
                is_standard=True,
                is_present=res is not None,
                orig_idx=orig_idx,
            )
        )

    # Get polymer class
    # NOTE(review): chain_type is left unbound (NameError at the return
    # below) for any polymer type other than PeptideL/Dna/Rna — confirm
    # whether callers guarantee only these three types reach this point.
    if polymer_type == gemmi.PolymerType.PeptideL:
        chain_type = const.chain_type_ids["PROTEIN"]
    elif polymer_type == gemmi.PolymerType.Dna:
        chain_type = const.chain_type_ids["DNA"]
    elif polymer_type == gemmi.PolymerType.Rna:
        chain_type = const.chain_type_ids["RNA"]

    # Return polymer object
    return ParsedChain(
        name=chain_id,
        entity=entity,
        residues=parsed,
        type=chain_type,
        sequence=gemmi.one_letter_code(sequence),
    )
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def parse_connection(
    connection: gemmi.Connection,
    chains: list[ParsedChain],
    subchain_map: dict[tuple[str, int], str],
) -> ParsedConnection:
    """Parse a (covalent) connection from a gemmi Connection.

    Parameters
    ----------
    connection : gemmi.Connection
        The connection to parse.
    chains : list[ParsedChain]
        The parsed chains.
    subchain_map : dict[tuple[str, int], str]
        The mapping from (chain name, residue id) to subchain name.

    Returns
    -------
    ParsedConnection
        The parsed connection.

    """
    # Map to the correct subchains: a gemmi Connection references author
    # chain names, while the parsed chains are keyed by subchain name.
    chain_1_name = connection.partner1.chain_name
    chain_2_name = connection.partner2.chain_name

    # Residue ids combine the sequence number with any insertion code,
    # matching the format used when building `subchain_map`.
    res_1_id = connection.partner1.res_id.seqid
    res_1_id = str(res_1_id.num) + str(res_1_id.icode).strip()

    res_2_id = connection.partner2.res_id.seqid
    res_2_id = str(res_2_id.num) + str(res_2_id.icode).strip()

    subchain_1 = subchain_map[(chain_1_name, res_1_id)]
    subchain_2 = subchain_map[(chain_2_name, res_2_id)]

    # Get chain indices
    # NOTE(review): raises StopIteration if the subchain was filtered out
    # during chain parsing — callers appear to rely on upstream filtering.
    chain_1 = next(chain for chain in chains if (chain.name == subchain_1))
    chain_2 = next(chain for chain in chains if (chain.name == subchain_2))

    # Get residue indices, matched on the original (author) residue id
    res_1_idx, res_1 = next(
        (idx, res)
        for idx, res in enumerate(chain_1.residues)
        if (res.orig_idx == res_1_id)
    )
    res_2_idx, res_2 = next(
        (idx, res)
        for idx, res in enumerate(chain_2.residues)
        if (res.orig_idx == res_2_id)
    )

    # Get atom indices within each residue (local, not global, indices)
    atom_index_1 = next(
        idx
        for idx, atom in enumerate(res_1.atoms)
        if atom.name == connection.partner1.atom_name
    )
    atom_index_2 = next(
        idx
        for idx, atom in enumerate(res_2.atoms)
        if atom.name == connection.partner2.atom_name
    )

    conn = ParsedConnection(
        chain_1=subchain_1,
        chain_2=subchain_2,
        residue_index_1=res_1_idx,
        residue_index_2=res_2_idx,
        atom_index_1=atom_index_1,
        atom_index_2=atom_index_2,
    )

    return conn
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def parse_mmcif(  # noqa: C901, PLR0915, PLR0912
    path: str,
    components: dict[str, Mol],
    use_assembly: bool = True,
) -> ParsedStructure:
    """Parse a structure in MMCIF format.

    Parameters
    ----------
    path : str
        Path to the MMCIF file.
    components: dict[str, Mol]
        The preprocessed PDB components dictionary. Any mapping with
        dict-like ``get``/``__getitem__`` works (e.g. a Redis-backed view).
    use_assembly: bool
        Whether to expand the first biological assembly.

    Returns
    -------
    ParsedStructure
        The parsed structure.

    Raises
    ------
    ValueError
        If no chains could be parsed from the file.

    """
    # Disable rdkit warnings
    blocker = rdBase.BlockLogs()  # noqa: F841

    # Parse MMCIF input file
    block = gemmi.cif.read(str(path))[0]

    # Extract metadata
    deposit_date, release_date, revision_date = get_dates(block)
    resolution = get_resolution(block)
    method = get_method(block)

    # Load structure object
    structure = gemmi.make_structure_from_block(block)

    # Clean up the structure
    structure.merge_chain_parts()
    structure.remove_waters()
    structure.remove_hydrogens()
    structure.remove_alternative_conformations()
    structure.remove_empty_chains()

    # Expand assembly 1
    if use_assembly and structure.assemblies:
        how = gemmi.HowToNameCopiedChain.AddNumber
        assembly_name = structure.assemblies[0].name
        structure.transform_to_assembly(assembly_name, how=how)

    # Parse entities
    # Create mapping from subchain id to entity (waters are excluded)
    entities: dict[str, gemmi.Entity] = {}
    entity_ids: dict[str, int] = {}
    for entity_id, entity in enumerate(structure.entities):
        entity: gemmi.Entity
        if entity.entity_type.name == "Water":
            continue
        for subchain_id in entity.subchains:
            entities[subchain_id] = entity
            entity_ids[subchain_id] = entity_id

    # Create mapping from (chain, residue id) to subchains,
    # since a Connection uses the chains and not subchains
    subchain_map = {}
    for chain in structure[0]:
        for residue in chain:
            seq_id = residue.seqid
            seq_id = str(seq_id.num) + str(seq_id.icode).strip()
            subchain_map[(chain.name, seq_id)] = residue.subchain

    # Find covalent ligands
    covalent_chain_ids = compute_covalent_ligands(
        connections=structure.connections,
        subchain_map=subchain_map,
        entities=entities,
    )

    # Parse chains
    chains: list[ParsedChain] = []
    chain_seqs = []  # NOTE(review): collected but unused below — confirm
    for raw_chain in structure[0].subchains():
        # Check chain type
        subchain_id = raw_chain.subchain_id()
        entity: gemmi.Entity = entities[subchain_id]
        entity_type = entity.entity_type.name

        # Parse a polymer
        if entity_type == "Polymer":
            # Skip PeptideD, DnaRnaHybrid, Pna, Other
            if entity.polymer_type.name not in {
                "PeptideL",
                "Dna",
                "Rna",
            }:
                continue

            # Add polymer if successful
            parsed_polymer = parse_polymer(
                polymer=raw_chain,
                polymer_type=entity.polymer_type,
                sequence=entity.full_sequence,
                chain_id=subchain_id,
                entity=entity.name,
                components=components,
            )
            if parsed_polymer is not None:
                chains.append(parsed_polymer)
                chain_seqs.append(parsed_polymer.sequence)

        # Parse a non-polymer
        elif entity_type in {"NonPolymer", "Branched"}:
            # Skip UNL or other missing ligands
            if any(components.get(lig.name) is None for lig in raw_chain):
                continue

            residues = []
            for lig_idx, ligand in enumerate(raw_chain):
                # Check if ligand is covalent: branched entities are
                # always covalent, otherwise consult the computed set.
                if entity_type == "Branched":
                    is_covalent = True
                else:
                    is_covalent = subchain_id in covalent_chain_ids

                ligand: gemmi.Residue
                residue = parse_ccd_residue(
                    name=ligand.name,
                    components=components,
                    res_idx=lig_idx,
                    gemmi_mol=ligand,
                    is_covalent=is_covalent,
                )
                residues.append(residue)

            if residues:
                chains.append(
                    ParsedChain(
                        name=subchain_id,
                        entity=entity.name,
                        residues=residues,
                        type=const.chain_type_ids["NONPOLYMER"],
                        sequence=None,
                    )
                )

    # If no chains parsed fail
    if not chains:
        msg = "No chains parsed!"
        raise ValueError(msg)

    # Parse covalent connections
    connections: list[ParsedConnection] = []
    for connection in structure.connections:
        # Skip non-covalent connections
        connection: gemmi.Connection
        if connection.type.name != "Covale":
            continue

        parsed_connection = parse_connection(
            connection=connection,
            chains=chains,
            subchain_map=subchain_map,
        )
        connections.append(parsed_connection)

    # Create tables
    atom_data = []
    bond_data = []
    res_data = []
    chain_data = []
    connection_data = []

    # Convert parsed chains to tables; atom_idx/res_idx are global
    # running offsets across all chains.
    atom_idx = 0
    res_idx = 0
    asym_id = 0
    sym_count = {}
    chain_to_idx = {}
    res_to_idx = {}

    for asym_id, chain in enumerate(chains):
        # Compute number of atoms and residues
        res_num = len(chain.residues)
        atom_num = sum(len(res.atoms) for res in chain.residues)

        # Find all copies of this chain in the assembly:
        # sym_id counts prior chains sharing the same entity.
        entity_id = entity_ids[chain.name]
        sym_id = sym_count.get(entity_id, 0)
        chain_data.append(
            (
                chain.name,
                chain.type,
                entity_id,
                sym_id,
                asym_id,
                atom_idx,
                atom_num,
                res_idx,
                res_num,
            )
        )
        chain_to_idx[chain.name] = asym_id
        sym_count[entity_id] = sym_id + 1

        # Add residue, atom, bond, data
        for i, res in enumerate(chain.residues):
            # Center/disto atom indices are made global by offsetting
            atom_center = atom_idx + res.atom_center
            atom_disto = atom_idx + res.atom_disto
            res_data.append(
                (
                    res.name,
                    res.type,
                    res.idx,
                    atom_idx,
                    len(res.atoms),
                    atom_center,
                    atom_disto,
                    res.is_standard,
                    res.is_present,
                )
            )
            res_to_idx[(chain.name, i)] = (res_idx, atom_idx)

            for bond in res.bonds:
                atom_1 = atom_idx + bond.atom_1
                atom_2 = atom_idx + bond.atom_2
                bond_data.append((atom_1, atom_2, bond.type))

            for atom in res.atoms:
                atom_data.append(
                    (
                        convert_atom_name(atom.name),
                        atom.element,
                        atom.charge,
                        atom.coords,
                        atom.conformer,
                        atom.is_present,
                        atom.chirality,
                    )
                )
                atom_idx += 1

            res_idx += 1

    # Convert connections to tables (residue/atom indices become global)
    for conn in connections:
        chain_1_idx = chain_to_idx[conn.chain_1]
        chain_2_idx = chain_to_idx[conn.chain_2]
        res_1_idx, atom_1_offset = res_to_idx[(conn.chain_1, conn.residue_index_1)]
        res_2_idx, atom_2_offset = res_to_idx[(conn.chain_2, conn.residue_index_2)]
        atom_1_idx = atom_1_offset + conn.atom_index_1
        atom_2_idx = atom_2_offset + conn.atom_index_2
        connection_data.append(
            (
                chain_1_idx,
                chain_2_idx,
                res_1_idx,
                res_2_idx,
                atom_1_idx,
                atom_2_idx,
            )
        )

    # Convert into structured numpy arrays
    atoms = np.array(atom_data, dtype=Atom)
    bonds = np.array(bond_data, dtype=Bond)
    residues = np.array(res_data, dtype=Residue)
    chains = np.array(chain_data, dtype=Chain)
    connections = np.array(connection_data, dtype=Connection)
    mask = np.ones(len(chain_data), dtype=bool)

    # Compute interface chains (find chains with a heavy atom within 5A)
    interfaces = compute_interfaces(atoms, chains)

    # Return parsed structure
    info = StructureInfo(
        deposited=deposit_date,
        revised=revision_date,
        released=release_date,
        resolution=resolution,
        method=method,
        num_chains=len(chains),
        num_interfaces=len(interfaces),
    )

    data = Structure(
        atoms=atoms,
        bonds=bonds,
        residues=residues,
        chains=chains,
        connections=connections,
        interfaces=interfaces,
        mask=mask,
    )

    return ParsedStructure(data=data, info=info, covalents=[])
|
protify/FastPLMs/boltz/scripts/process/msa.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import multiprocessing
|
| 3 |
+
from dataclasses import asdict
|
| 4 |
+
from functools import partial
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from p_tqdm import p_umap
|
| 10 |
+
from redis import Redis
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from boltz.data.parse.a3m import parse_a3m
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Resource:
    """A shared read-only resource backed by a Redis database."""

    def __init__(self, host: str, port: int) -> None:
        """Open a connection to the Redis server at ``host:port``."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Return the raw value stored at ``key``, or None if absent."""
        return self._redis.get(key)

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Return the value at ``key``, raising KeyError when missing."""
        if (value := self.get(key)) is None:
            raise KeyError(key)
        return value
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def process_msa(
    path: Path,
    outdir: str,
    max_seqs: int,
    resource: Resource,
) -> None:
    """Parse one a3m MSA file and cache it as a compressed ``.npz``.

    Does nothing when the output file already exists, so the task is
    safe to resume.
    """
    target = Path(outdir) / f"{path.stem}.npz"
    if target.exists():
        return
    msa = parse_a3m(path, resource, max_seqs)
    np.savez_compressed(target, **asdict(msa))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def process(args) -> None:
    """Run the MSA preprocessing task described by the CLI arguments."""
    # Ensure the output directory exists
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Connect to the shared Redis-backed resource
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Collect all a3m files (possibly gzipped) under the MSA directory
    print("Fetching data...")
    data = list(args.msadir.rglob("*.a3m*"))
    print(f"Found {len(data)} MSA's.")

    # Clamp the worker count to the CPU count and the number of files
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))

    if num_processes > 1:
        # Parallel path: map the worker function over all files
        worker = partial(
            process_msa,
            outdir=args.outdir,
            max_seqs=args.max_seqs,
            resource=resource,
        )
        p_umap(worker, data, num_cpus=num_processes)
    else:
        # Serial fallback
        for path in tqdm(data):
            process_msa(
                path,
                outdir=args.outdir,
                max_seqs=args.max_seqs,
                resource=resource,
            )
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
    # Command-line entry point: define arguments and launch processing.
    parser = argparse.ArgumentParser(description="Process MSA data.")
    parser.add_argument(
        "--msadir",
        type=Path,
        required=True,
        help="The MSA data directory.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--max-seqs",
        type=int,
        default=16384,
        help="The maximum number of sequences.",
    )
    args = parser.parse_args()
    process(args)
|
protify/FastPLMs/boltz/scripts/process/rcsb.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import multiprocessing
|
| 4 |
+
import pickle
|
| 5 |
+
import traceback
|
| 6 |
+
from dataclasses import asdict, dataclass, replace
|
| 7 |
+
from functools import partial
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import rdkit
|
| 13 |
+
from mmcif import parse_mmcif
|
| 14 |
+
from p_tqdm import p_umap
|
| 15 |
+
from redis import Redis
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
|
| 18 |
+
from boltz.data.filter.static.filter import StaticFilter
|
| 19 |
+
from boltz.data.filter.static.ligand import ExcludedLigands
|
| 20 |
+
from boltz.data.filter.static.polymer import (
|
| 21 |
+
ClashingChainsFilter,
|
| 22 |
+
ConsecutiveCA,
|
| 23 |
+
MinimumLengthFilter,
|
| 24 |
+
UnknownFilter,
|
| 25 |
+
)
|
| 26 |
+
from boltz.data.types import ChainInfo, InterfaceInfo, Record, Target
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True, slots=True)
class PDB:
    """A raw MMCIF PDB file."""

    # PDB entry identifier (lowercased file stem, see `fetch`).
    id: str
    # Filesystem path to the .cif file, stored as a string.
    path: str
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Resource:
    """A shared resource of pickled objects stored in Redis."""

    def __init__(self, host: str, port: int) -> None:
        """Open a connection to the Redis server at ``host:port``."""
        self._redis = Redis(host=host, port=port)

    def get(self, key: str) -> Any:  # noqa: ANN401
        """Return the unpickled value at ``key``, or None if absent."""
        raw = self._redis.get(key)
        if raw is None:
            return None
        # Values are stored pickled by the producer; the database is a
        # trusted local cache, hence the S301 suppression.
        return pickle.loads(raw)  # noqa: S301

    def __getitem__(self, key: str) -> Any:  # noqa: ANN401
        """Return the value at ``key``, raising KeyError when missing."""
        if (value := self.get(key)) is None:
            raise KeyError(key)
        return value
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def fetch(datadir: Path, max_file_size: Optional[int] = None) -> list[PDB]:
    """Collect the PDB entries found under ``datadir``.

    Files larger than ``max_file_size`` bytes (when given) are skipped
    and counted, and the count is reported on stdout.
    """
    entries: list[PDB] = []
    skipped = 0
    for file in datadir.rglob("*.cif*"):
        # Skip files that exceed the size budget
        if max_file_size is not None and file.stat().st_size > max_file_size:
            skipped += 1
            continue

        # The clustering file is annotated by pdb_entity id, so the
        # entry id is the lowercased file stem.
        entries.append(PDB(id=str(file.stem).lower(), path=str(file)))

    print(f"Excluded {skipped} files due to size.")  # noqa: T201
    return entries
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def finalize(outdir: Path) -> None:
    """Group per-entry record files into a single manifest.

    Reads every file in ``outdir / "records"``, collects the ones that
    parse as JSON, and writes the aggregate to ``outdir / "manifest.json"``.
    Unreadable or malformed record files are counted and reported but do
    not abort the aggregation.

    Parameters
    ----------
    outdir : Path
        The output directory.

    """
    # Group records into a manifest
    records_dir = outdir / "records"

    failed_count = 0
    records = []
    for path in records_dir.iterdir():
        try:
            with path.open("r") as f:
                records.append(json.load(f))
        # Narrowed from a bare `except:` so control-flow exceptions
        # (KeyboardInterrupt, SystemExit) are no longer swallowed.
        # ValueError covers json.JSONDecodeError and bad encodings.
        except (OSError, ValueError):
            failed_count += 1
            print(f"Failed to parse {path}")  # noqa: T201
    if failed_count > 0:
        print(f"Failed to parse {failed_count} entries.")  # noqa: T201
    else:
        print("All entries parsed successfully.")  # noqa: T201

    # Save manifest
    outpath = outdir / "manifest.json"
    with outpath.open("w") as f:
        json.dump(records, f)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def parse(data: PDB, resource: Resource, clusters: dict) -> Target:
    """Process a structure into a Target (structure + record metadata).

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource: Resource
        The shared resource (passed to the parser as the components
        lookup).
    clusters : dict
        Mapping from "{pdb_id}_{entity_id}" keys to cluster ids;
        chains missing from the mapping get cluster_id -1.

    Returns
    -------
    Target
        The processed data.

    """
    # Get the PDB id
    pdb_id = data.id.lower()

    # Parse structure
    parsed = parse_mmcif(data.path, resource)
    structure = parsed.data
    structure_info = parsed.info

    # Create chain metadata, one entry per chain in order
    chain_info = []
    for i, chain in enumerate(structure.chains):
        key = f"{pdb_id}_{chain['entity_id']}"
        chain_info.append(
            ChainInfo(
                chain_id=i,
                chain_name=chain["name"],
                msa_id="",  # TODO: populate the MSA identifier
                mol_type=int(chain["mol_type"]),
                cluster_id=clusters.get(key, -1),
                num_residues=int(chain["res_num"]),
            )
        )

    # Get interface metadata
    interface_info = []
    for interface in structure.interfaces:
        chain_1 = int(interface["chain_1"])
        chain_2 = int(interface["chain_2"])
        interface_info.append(
            InterfaceInfo(
                chain_1=chain_1,
                chain_2=chain_2,
            )
        )

    # Create record
    record = Record(
        id=data.id,
        structure=structure_info,
        chains=chain_info,
        interfaces=interface_info,
    )

    return Target(structure=structure, record=record)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def process_structure(
    data: PDB,
    resource: Resource,
    outdir: Path,
    filters: list[StaticFilter],
    clusters: dict,
) -> None:
    """Process a target: parse, filter, and dump structure + record.

    Parameters
    ----------
    data : PDB
        The raw input data.
    resource: Resource
        The shared resource.
    outdir : Path
        The output directory (must contain ``structures`` and
        ``records`` subdirectories).
    filters : list[StaticFilter]
        Static filters applied to the parsed structure; their boolean
        masks are AND-combined with the structure's own mask.
    clusters : dict
        Mapping used to assign cluster ids to chains (see ``parse``).

    """
    # Skip entries whose outputs already exist, so runs are resumable
    struct_path = outdir / "structures" / f"{data.id}.npz"
    record_path = outdir / "records" / f"{data.id}.json"

    if struct_path.exists() and record_path.exists():
        return

    try:
        # Parse the target
        target: Target = parse(data, resource, clusters)
        structure = target.structure

        # Apply the filters
        mask = structure.mask
        if filters is not None:
            for f in filters:
                filter_mask = f.filter(structure)
                mask = mask & filter_mask
    # Broad on purpose: a single bad entry must not kill the batch.
    except Exception:  # noqa: BLE001
        traceback.print_exc()
        print(f"Failed to parse {data.id}")
        return

    # Replace chains and interfaces with validity flags from the mask
    chains = []
    for i, chain in enumerate(target.record.chains):
        chains.append(replace(chain, valid=bool(mask[i])))

    # An interface is valid only if both of its chains are valid
    interfaces = []
    for interface in target.record.interfaces:
        chain_1 = bool(mask[interface.chain_1])
        chain_2 = bool(mask[interface.chain_2])
        interfaces.append(replace(interface, valid=(chain_1 and chain_2)))

    # Replace structure and record (dataclasses are frozen, so rebuild)
    structure = replace(structure, mask=mask)
    record = replace(target.record, chains=chains, interfaces=interfaces)
    target = replace(target, structure=structure, record=record)

    # Dump structure
    np.savez_compressed(struct_path, **asdict(structure))

    # Dump record
    with record_path.open("w") as f:
        json.dump(asdict(record), f)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def process(args) -> None:
    """Run the structure data processing task.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (see the ``__main__`` block).

    """
    # Create output directory
    args.outdir.mkdir(parents=True, exist_ok=True)

    # Create output directories
    records_dir = args.outdir / "records"
    records_dir.mkdir(parents=True, exist_ok=True)

    structure_dir = args.outdir / "structures"
    structure_dir.mkdir(parents=True, exist_ok=True)

    # Load clusters, normalizing keys and values to lowercase
    with Path(args.clusters).open("r") as f:
        clusters: dict[str, str] = json.load(f)
        clusters = {k.lower(): v.lower() for k, v in clusters.items()}

    # Load filters
    filters = [
        ExcludedLigands(),
        MinimumLengthFilter(min_len=4, max_len=5000),
        UnknownFilter(),
        ConsecutiveCA(max_dist=10.0),
        ClashingChainsFilter(freq=0.3, dist=1.7),
    ]

    # Set default pickle properties so rdkit molecules round-trip intact
    pickle_option = rdkit.Chem.PropertyPickleOptions.AllProps
    rdkit.Chem.SetDefaultPickleProperties(pickle_option)

    # Load shared data from redis
    resource = Resource(host=args.redis_host, port=args.redis_port)

    # Get data points.
    # Fix: forward --max-file-size, which was previously parsed but
    # never passed, so the size limit had no effect.
    print("Fetching data...")
    data = fetch(args.datadir, args.max_file_size)

    # NOTE(review): args.use_assembly is accepted by the CLI but not
    # forwarded anywhere here — confirm whether it should reach parsing.

    # Check if we can run in parallel
    max_processes = multiprocessing.cpu_count()
    num_processes = max(1, min(args.num_processes, max_processes, len(data)))
    parallel = num_processes > 1

    # Run processing
    print("Processing data...")
    if parallel:
        # Create processing function
        fn = partial(
            process_structure,
            resource=resource,
            outdir=args.outdir,
            clusters=clusters,
            filters=filters,
        )
        # Run processing in parallel
        p_umap(fn, data, num_cpus=num_processes)
    else:
        for item in tqdm(data):
            process_structure(
                item,
                resource=resource,
                outdir=args.outdir,
                clusters=clusters,
                filters=filters,
            )

    # Aggregate the per-entry records into a manifest
    finalize(args.outdir)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
if __name__ == "__main__":
    # Command-line entry point for RCSB structure preprocessing.
    # Fix: the description was copy-pasted from the MSA script.
    parser = argparse.ArgumentParser(description="Process RCSB structure data.")
    parser.add_argument(
        "--datadir",
        type=Path,
        required=True,
        help="The data containing the MMCIF files.",
    )
    parser.add_argument(
        "--clusters",
        type=Path,
        required=True,
        help="Path to the cluster file.",
    )
    parser.add_argument(
        "--outdir",
        type=Path,
        default="data",
        help="The output directory.",
    )
    parser.add_argument(
        "--num-processes",
        type=int,
        default=multiprocessing.cpu_count(),
        help="The number of processes.",
    )
    parser.add_argument(
        "--redis-host",
        type=str,
        default="localhost",
        help="The Redis host.",
    )
    parser.add_argument(
        "--redis-port",
        type=int,
        default=7777,
        help="The Redis port.",
    )
    parser.add_argument(
        "--use-assembly",
        action="store_true",
        help="Whether to use assembly 1.",
    )
    parser.add_argument(
        "--max-file-size",
        type=int,
        default=None,
        help="Skip input files larger than this many bytes.",
    )
    args = parser.parse_args()
    process(args)
|
protify/FastPLMs/boltz/scripts/train/train.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import string
|
| 4 |
+
import sys
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import hydra
|
| 10 |
+
import omegaconf
|
| 11 |
+
import pytorch_lightning as pl
|
| 12 |
+
import torch
|
| 13 |
+
import torch.multiprocessing
|
| 14 |
+
from omegaconf import OmegaConf, listconfig
|
| 15 |
+
from pytorch_lightning import LightningModule
|
| 16 |
+
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
| 17 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 18 |
+
from pytorch_lightning.strategies import DDPStrategy
|
| 19 |
+
from pytorch_lightning.utilities import rank_zero_only
|
| 20 |
+
|
| 21 |
+
from boltz.data.module.training import BoltzTrainingDataModule, DataConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class TrainConfig:
    """Train configuration.

    Consumed by ``train``: the raw yaml is instantiated by hydra and then
    unpacked into this dataclass.

    Attributes
    ----------
    data : DataConfig
        The data configuration.
    model : LightningModule
        The model module to train (instantiated by hydra from the config).
    output : str
        The output directory.
    trainer : Optional[dict]
        The trainer configuration (forwarded as kwargs to ``pl.Trainer``).
    resume : Optional[str]
        The resume checkpoint.
    pretrained : Optional[str]
        The pretrained model.
    wandb : Optional[dict]
        The wandb configuration.
    disable_checkpoint : bool
        Disable checkpoint.
    matmul_precision : Optional[str]
        The matmul precision.
    find_unused_parameters : Optional[bool]
        Find unused parameters.
    save_top_k : Optional[int]
        Save top k checkpoints.
    validation_only : bool
        Run validation only.
    debug : bool
        Debug mode.
    strict_loading : bool
        Fail on mismatched checkpoint weights.
    load_confidence_from_trunk : Optional[bool]
        Load pre-trained confidence weights from trunk.

    """

    data: DataConfig
    model: LightningModule
    output: str
    trainer: Optional[dict] = None
    resume: Optional[str] = None
    pretrained: Optional[str] = None
    wandb: Optional[dict] = None
    disable_checkpoint: bool = False
    matmul_precision: Optional[str] = None
    find_unused_parameters: Optional[bool] = False
    save_top_k: Optional[int] = 1
    validation_only: bool = False
    debug: bool = False
    strict_loading: bool = True
    load_confidence_from_trunk: Optional[bool] = False
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def train(raw_config: str, args: list[str]) -> None:  # noqa: C901, PLR0912, PLR0915
    """Run training.

    Loads the yaml configuration, applies command-line dotlist overrides,
    instantiates data/model modules via hydra, then launches a PyTorch
    Lightning fit (or validation-only) run.

    Parameters
    ----------
    raw_config : str
        The input yaml configuration.
    args : list[str]
        Any command line overrides.

    """
    # Load the configuration
    raw_config = omegaconf.OmegaConf.load(raw_config)

    # Apply input arguments; dotlist overrides take precedence over the file.
    args = omegaconf.OmegaConf.from_dotlist(args)
    raw_config = omegaconf.OmegaConf.merge(raw_config, args)

    # Instantiate the task
    cfg = hydra.utils.instantiate(raw_config)
    cfg = TrainConfig(**cfg)

    # Set matmul precision
    if cfg.matmul_precision is not None:
        torch.set_float32_matmul_precision(cfg.matmul_precision)

    # Create trainer dict
    trainer = cfg.trainer
    if trainer is None:
        trainer = {}

    # Flip some arguments in debug mode
    devices = trainer.get("devices", 1)

    wandb = cfg.wandb
    if cfg.debug:
        # Debug runs: single device, no dataloader workers, no wandb.
        if isinstance(devices, int):
            devices = 1
        elif isinstance(devices, (list, listconfig.ListConfig)):
            devices = [devices[0]]
        trainer["devices"] = devices
        cfg.data.num_workers = 0
        if wandb:
            wandb = None

    # Create objects
    data_config = DataConfig(**cfg.data)
    data_module = BoltzTrainingDataModule(data_config)
    model_module = cfg.model

    # A resume checkpoint wins over a pretrained one: only load pretrained
    # weights when we are not resuming.
    if cfg.pretrained and not cfg.resume:
        # Load the pretrained weights into the confidence module
        if cfg.load_confidence_from_trunk:
            checkpoint = torch.load(cfg.pretrained, map_location="cpu")

            # Modify parameter names in the state_dict: every non-structure,
            # non-distogram key is duplicated under a "confidence_module."
            # prefix; the update() below keeps the original keys as well, so
            # trunk weights load under both names.
            new_state_dict = {}
            for key, value in checkpoint["state_dict"].items():
                if not key.startswith("structure_module") and not key.startswith(
                    "distogram_module"
                ):
                    new_key = "confidence_module." + key
                    new_state_dict[new_key] = value
            new_state_dict.update(checkpoint["state_dict"])

            # Update the checkpoint with the new state_dict
            checkpoint["state_dict"] = new_state_dict

            # Save the modified checkpoint under a random temporary name next
            # to the original; it is deleted again after loading (see below).
            random_string = "".join(
                random.choices(string.ascii_lowercase + string.digits, k=10)
            )
            file_path = os.path.dirname(cfg.pretrained) + "/" + random_string + ".ckpt"
            print(
                f"Saving modified checkpoint to {file_path} created by broadcasting trunk of {cfg.pretrained} to confidence module."
            )
            torch.save(checkpoint, file_path)
        else:
            file_path = cfg.pretrained

        print(f"Loading model from {file_path}")
        # strict=False: prefixed confidence keys may not all exist on the model.
        model_module = type(model_module).load_from_checkpoint(
            file_path, map_location="cpu", strict=False, **(model_module.hparams)
        )

        if cfg.load_confidence_from_trunk:
            # Clean up the temporary broadcast checkpoint.
            os.remove(file_path)

    # Create checkpoint callback
    callbacks = []
    dirpath = cfg.output
    if not cfg.disable_checkpoint:
        # Track the best validation lddt (higher is better), keep the last too.
        mc = ModelCheckpoint(
            monitor="val/lddt",
            save_top_k=cfg.save_top_k,
            save_last=True,
            mode="max",
            every_n_epochs=1,
        )
        callbacks = [mc]

    # Create wandb logger
    loggers = []
    if wandb:
        wdb_logger = WandbLogger(
            name=wandb["name"],
            group=wandb["name"],
            save_dir=cfg.output,
            project=wandb["project"],
            entity=wandb["entity"],
            log_model=False,
        )
        loggers.append(wdb_logger)
        # Save the config to wandb (rank zero only, to avoid duplicate writes
        # in distributed runs).

        @rank_zero_only
        def save_config_to_wandb() -> None:
            config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
            # NOTE: Path.open is invoked on the instance, i.e. config_out.open("w").
            with Path.open(config_out, "w") as f:
                OmegaConf.save(raw_config, f)
            wdb_logger.experiment.save(str(config_out))

        save_config_to_wandb()

    # Set up trainer: switch to DDP only when more than one device is requested.
    strategy = "auto"
    if (isinstance(devices, int) and devices > 1) or (
        isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
    ):
        strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)

    trainer = pl.Trainer(
        default_root_dir=str(dirpath),
        strategy=strategy,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=not cfg.disable_checkpoint,
        reload_dataloaders_every_n_epochs=1,
        **trainer,
    )

    if not cfg.strict_loading:
        # Allow checkpoints with missing/unexpected keys when resuming.
        model_module.strict_loading = False

    if cfg.validation_only:
        trainer.validate(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
    else:
        trainer.fit(
            model_module,
            datamodule=data_module,
            ckpt_path=cfg.resume,
        )
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
    # Usage: python train.py <config.yaml> [dotlist overrides ...]
    # Fix: previously sys.argv[1] was read unconditionally, so running the
    # script with no arguments died with an opaque IndexError.
    if len(sys.argv) < 2:
        sys.exit("usage: train.py <config.yaml> [key=value ...]")
    config_path = sys.argv[1]
    overrides = sys.argv[2:]
    train(config_path, overrides)
|
protify/FastPLMs/boltz/src/boltz/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib.metadata import PackageNotFoundError, version

# Resolve the installed distribution version at import time. When boltz is
# used from a source tree without being pip-installed, the metadata lookup
# raises PackageNotFoundError and ``__version__`` is simply left undefined.
try:
    __version__ = version("boltz")
except PackageNotFoundError:
    pass
|
protify/FastPLMs/boltz/src/boltz/data/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/const.py
ADDED
|
@@ -0,0 +1,1184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
####################################################################################################
# CHAINS
####################################################################################################

# Top-level chain categories recognised by the pipeline.
chain_types = [
    "PROTEIN",
    "DNA",
    "RNA",
    "NONPOLYMER",
]
# Integer id for each chain type, assigned in declaration order.
chain_type_ids = {chain: i for i, chain in enumerate(chain_types)}

# Interaction categories used to bucket pairwise outputs.
out_types = [
    "dna_protein",
    "rna_protein",
    "ligand_protein",
    "dna_ligand",
    "rna_ligand",
    "intra_ligand",
    "intra_dna",
    "intra_rna",
    "intra_protein",
    "protein_protein",
    "modified",
]

# Per-category weights; the `_af3` suffix suggests the AF3 recipe.
# "modified" is effectively excluded via a 0.0 weight.
out_types_weights_af3 = {
    "dna_protein": 10.0,
    "rna_protein": 10.0,
    "ligand_protein": 10.0,
    "dna_ligand": 5.0,
    "rna_ligand": 5.0,
    "intra_ligand": 20.0,
    "intra_dna": 4.0,
    "intra_rna": 16.0,
    "intra_protein": 20.0,
    "protein_protein": 20.0,
    "modified": 0.0,
}

# Default per-category weights used by this codebase.
out_types_weights = {
    "dna_protein": 5.0,
    "rna_protein": 5.0,
    "ligand_protein": 20.0,
    "dna_ligand": 2.0,
    "rna_ligand": 2.0,
    "intra_ligand": 20.0,
    "intra_dna": 2.0,
    "intra_rna": 8.0,
    "intra_protein": 20.0,
    "protein_protein": 20.0,
    "modified": 0.0,
}


# Single-chain categories (lower-case naming used on the output side).
out_single_types = ["protein", "ligand", "dna", "rna"]

# Chain-type pair categories used when reporting clashes.
clash_types = [
    "dna_protein",
    "rna_protein",
    "ligand_protein",
    "protein_protein",
    "dna_ligand",
    "rna_ligand",
    "ligand_ligand",
    "rna_dna",
    "dna_dna",
    "rna_rna",
]

# Unordered pair of chain types -> clash category. frozenset keys make the
# lookup independent of argument order; one-element sets cover same-type pairs.
chain_types_to_clash_type = {
    frozenset(("PROTEIN", "DNA")): "dna_protein",
    frozenset(("PROTEIN", "RNA")): "rna_protein",
    frozenset(("PROTEIN", "NONPOLYMER")): "ligand_protein",
    frozenset(("PROTEIN",)): "protein_protein",
    frozenset(("NONPOLYMER", "DNA")): "dna_ligand",
    frozenset(("NONPOLYMER", "RNA")): "rna_ligand",
    frozenset(("NONPOLYMER",)): "ligand_ligand",
    frozenset(("DNA", "RNA")): "rna_dna",
    frozenset(("DNA",)): "dna_dna",
    frozenset(("RNA",)): "rna_rna",
}

# Upper-case chain type -> lower-case single-chain output category.
chain_type_to_out_single_type = {
    "PROTEIN": "protein",
    "DNA": "dna",
    "RNA": "rna",
    "NONPOLYMER": "ligand",
}
|
| 90 |
+
####################################################################################################
# RESIDUES & TOKENS
####################################################################################################


# The 20 standard amino acids plus the unknown-residue token.
canonical_tokens = [
    "ALA",
    "ARG",
    "ASN",
    "ASP",
    "CYS",
    "GLN",
    "GLU",
    "GLY",
    "HIS",
    "ILE",
    "LEU",
    "LYS",
    "MET",
    "PHE",
    "PRO",
    "SER",
    "THR",
    "TRP",
    "TYR",
    "VAL",
    "UNK",  # unknown protein token
]

# Full token vocabulary: padding, gap, protein residues, RNA bases, DNA bases.
tokens = [
    "<pad>",
    "-",
    *canonical_tokens,
    "A",
    "G",
    "C",
    "U",
    "N",  # unknown rna token
    "DA",
    "DG",
    "DC",
    "DT",
    "DN",  # unknown dna token
]

# Token string -> vocabulary index (declaration order).
token_ids = {token: i for i, token in enumerate(tokens)}
num_tokens = len(tokens)
# Per-molecule-type unknown token and its vocabulary id.
unk_token = {"PROTEIN": "UNK", "DNA": "DN", "RNA": "N"}
unk_token_ids = {m: token_ids[t] for m, t in unk_token.items()}

# One-letter amino-acid code -> three-letter token. Ambiguous or rare codes
# (X, J, B, Z, O, U) all collapse to UNK; "-" maps to the gap token.
prot_letter_to_token = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "E": "GLU",
    "Q": "GLN",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
    "J": "UNK",
    "B": "UNK",
    "Z": "UNK",
    "O": "UNK",
    "U": "UNK",
    "-": "-",
}

# Inverse map. Since several letters map to UNK, the comprehension leaves UNK
# pointing at whichever letter was inserted last, so it is pinned to "X" below.
prot_token_to_letter = {v: k for k, v in prot_letter_to_token.items()}
prot_token_to_letter["UNK"] = "X"

# RNA one-letter base -> token (identity mapping) and its inverse.
rna_letter_to_token = {
    "A": "A",
    "G": "G",
    "C": "C",
    "U": "U",
    "N": "N",
}
rna_token_to_letter = {v: k for k, v in rna_letter_to_token.items()}

# DNA one-letter base -> "D"-prefixed token and its inverse.
dna_letter_to_token = {
    "A": "DA",
    "G": "DG",
    "C": "DC",
    "T": "DT",
    "N": "DN",
}
dna_token_to_letter = {v: k for k, v in dna_letter_to_token.items()}
|
| 190 |
+
|
| 191 |
+
####################################################################################################
# ATOMS
####################################################################################################

# Size of the element-type vocabulary (indexing presumably by atomic number —
# TODO confirm against the featurizer).
num_elements = 128

# Chirality categories; the names mirror RDKit's ChiralType values.
chirality_types = [
    "CHI_UNSPECIFIED",
    "CHI_TETRAHEDRAL_CW",
    "CHI_TETRAHEDRAL_CCW",
    "CHI_SQUAREPLANAR",
    "CHI_OCTAHEDRAL",
    "CHI_TRIGONALBIPYRAMIDAL",
    "CHI_OTHER",
]
chirality_type_ids = {chirality: i for i, chirality in enumerate(chirality_types)}
# Fallback category for unrecognised chirality values.
unk_chirality_type = "CHI_OTHER"

# Orbital hybridisation labels (names mirror RDKit's HybridizationType).
hybridization_map = [
    "S",
    "SP",
    "SP2",
    "SP2D",
    "SP3",
    "SP3D",
    "SP3D2",
    "OTHER",
    "UNSPECIFIED",
]
hybridization_type_ids = {hybrid: i for i, hybrid in enumerate(hybridization_map)}
# Fallback category for unrecognised hybridisation values.
unk_hybridization_type = "UNSPECIFIED"
|
| 222 |
+
|
| 223 |
+
# Hand-aligned lookup tables below; keep formatter off.
# fmt: off
# Canonical heavy-atom name order per token. Positions in these lists define
# the per-residue atom indexing used by ref_symmetries and the *_atom_id maps.
ref_atoms = {
    "PAD": [],
    "UNK": ["N", "CA", "C", "O", "CB"],
    "-": [],
    "ALA": ["N", "CA", "C", "O", "CB"],
    "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
    "CYS": ["N", "CA", "C", "O", "CB", "SG"],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
    "GLY": ["N", "CA", "C", "O"],
    "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
    "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
    "SER": ["N", "CA", "C", "O", "CB", "OG"],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
    "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],  # noqa: E501
    "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
    "A": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"],  # noqa: E501
    "G": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"],  # noqa: E501
    "C": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"],  # noqa: E501
    "U": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"],  # noqa: E501
    "N": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"],  # noqa: E501
    "DA": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"],  # noqa: E501
    "DG": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"],  # noqa: E501
    "DC": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"],  # noqa: E501
    "DT": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"],  # noqa: E501
    "DN": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'"]
}

# Backbone atom name orderings (protein vs nucleic acid).
protein_backbone_atom_names = ["N", "CA", "C", "O"]
nucleic_backbone_atom_names = ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"]

protein_backbone_atom_index = {name: i for i, name in enumerate(protein_backbone_atom_names)}
nucleic_backbone_atom_index = {name: i for i, name in enumerate(nucleic_backbone_atom_names)}

# Groups of index swaps describing chemically equivalent atom relabelings
# (e.g. ASP OD1/OD2). Pairs (i, j) are positions in the ref_atoms list for
# that token; each inner list is one simultaneous permutation.
ref_symmetries = {
    "PAD": [],
    "ALA": [],
    "ARG": [],
    "ASN": [],
    "ASP": [[(6, 7), (7, 6)]],
    "CYS": [],
    "GLN": [],
    "GLU": [[(7, 8), (8, 7)]],
    "GLY": [],
    "HIS": [],
    "ILE": [],
    "LEU": [],
    "LYS": [],
    "MET": [],
    "PHE": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
    "PRO": [],
    "SER": [],
    "THR": [],
    "TRP": [],
    "TYR": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
    "VAL": [],
    "A": [[(1, 2), (2, 1)]],
    "G": [[(1, 2), (2, 1)]],
    "C": [[(1, 2), (2, 1)]],
    "U": [[(1, 2), (2, 1)]],
    #"N": [[(1, 2), (2, 1)]],
    "DA": [[(1, 2), (2, 1)]],
    "DG": [[(1, 2), (2, 1)]],
    "DC": [[(1, 2), (2, 1)]],
    "DT": [[(1, 2), (2, 1)]],
    #"DN": [[(1, 2), (2, 1)]]
}


# Representative "center" atom per token: CA for amino acids, C1' for
# nucleotides.
res_to_center_atom = {
    "UNK": "CA",
    "ALA": "CA",
    "ARG": "CA",
    "ASN": "CA",
    "ASP": "CA",
    "CYS": "CA",
    "GLN": "CA",
    "GLU": "CA",
    "GLY": "CA",
    "HIS": "CA",
    "ILE": "CA",
    "LEU": "CA",
    "LYS": "CA",
    "MET": "CA",
    "PHE": "CA",
    "PRO": "CA",
    "SER": "CA",
    "THR": "CA",
    "TRP": "CA",
    "TYR": "CA",
    "VAL": "CA",
    "A": "C1'",
    "G": "C1'",
    "C": "C1'",
    "U": "C1'",
    "N": "C1'",
    "DA": "C1'",
    "DG": "C1'",
    "DC": "C1'",
    "DT": "C1'",
    "DN": "C1'"
}

# Atom used for distance-based ("disto") supervision per token — presumably
# the distogram reference atom (CB for amino acids, GLY falls back to CA).
res_to_disto_atom = {
    "UNK": "CB",
    "ALA": "CB",
    "ARG": "CB",
    "ASN": "CB",
    "ASP": "CB",
    "CYS": "CB",
    "GLN": "CB",
    "GLU": "CB",
    "GLY": "CA",
    "HIS": "CB",
    "ILE": "CB",
    "LEU": "CB",
    "LYS": "CB",
    "MET": "CB",
    "PHE": "CB",
    "PRO": "CB",
    "SER": "CB",
    "THR": "CB",
    "TRP": "CB",
    "TYR": "CB",
    "VAL": "CB",
    "A": "C4",
    "G": "C4",
    "C": "C2",
    "U": "C2",
    "N": "C1'",
    "DA": "C4",
    "DG": "C4",
    "DC": "C2",
    "DT": "C2",
    "DN": "C1'"
}

# Same maps resolved to integer positions within each token's ref_atoms list.
res_to_center_atom_id = {
    res: ref_atoms[res].index(atom)
    for res, atom in res_to_center_atom.items()
}

res_to_disto_atom_id = {
    res: ref_atoms[res].index(atom)
    for res, atom in res_to_disto_atom.items()
}

# fmt: on
|
| 380 |
+
|
| 381 |
+
####################################################################################################
# BONDS
####################################################################################################

# Distance cutoffs, presumably in angstroms — TODO confirm against consumers.
atom_interface_cutoff = 5.0
interface_cutoff = 15.0

# Bond categories; ids assigned in declaration order below.
bond_types = [
    "OTHER",
    "SINGLE",
    "DOUBLE",
    "TRIPLE",
    "AROMATIC",
    "COVALENT",
]
bond_type_ids = {bond: i for i, bond in enumerate(bond_types)}
# Fallback category for unrecognised bond types.
unk_bond_type = "OTHER"


####################################################################################################
# Contacts
####################################################################################################


# Label ids for pocket-conditioning of residues/atoms.
pocket_contact_info = {
    "UNSPECIFIED": 0,
    "UNSELECTED": 1,
    "POCKET": 2,
    "BINDER": 3,
}

# Label ids for directional contact conditioning.
contact_conditioning_info = {
    "UNSPECIFIED": 0,
    "UNSELECTED": 1,
    "POCKET>BINDER": 2,
    "BINDER>POCKET": 3,
    "CONTACT": 4,
}


####################################################################################################
# MSA
####################################################################################################

# Caps on the number of MSA sequences kept (total / paired).
max_msa_seqs = 16384
max_paired_seqs = 8192


####################################################################################################
# CHUNKING
####################################################################################################

chunk_size_threshold = 384

####################################################################################################
# Method conditioning
####################################################################################################

# Methods: experimental/synthetic data-source ids for method conditioning.
# Many rarer experimental techniques intentionally share bucket 4.
method_types_ids = {
    "MD": 0,
    "X-RAY DIFFRACTION": 1,
    "ELECTRON MICROSCOPY": 2,
    "SOLUTION NMR": 3,
    "SOLID-STATE NMR": 4,
    "NEUTRON DIFFRACTION": 4,
    "ELECTRON CRYSTALLOGRAPHY": 4,
    "FIBER DIFFRACTION": 4,
    "POWDER DIFFRACTION": 4,
    "INFRARED SPECTROSCOPY": 4,
    "FLUORESCENCE TRANSFER": 4,
    "EPR": 4,
    "THEORETICAL MODEL": 4,
    "SOLUTION SCATTERING": 4,
    "OTHER": 4,
    "AFDB": 5,
    "BOLTZ-1": 6,
    "FUTURE1": 7,  # Placeholder for future supervision sources
    "FUTURE2": 8,
    "FUTURE3": 9,
    "FUTURE4": 10,
    "FUTURE5": 11,
}
# Keys are normalised to lower case, so lookups must use lower-cased names.
method_types_ids = {k.lower(): v for k, v in method_types_ids.items()}
|
| 465 |
+
num_method_types = len(set(method_types_ids.values()))
|
| 466 |
+
|
| 467 |
+
# Temperature
|
| 468 |
+
temperature_bins = [(265, 280), (280, 295), (295, 310)]
|
| 469 |
+
temperature_bins_ids = {temp: i for i, temp in enumerate(temperature_bins)}
|
| 470 |
+
temperature_bins_ids["other"] = len(temperature_bins)
|
| 471 |
+
num_temp_bins = len(temperature_bins_ids)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
# pH
|
| 475 |
+
ph_bins = [(0, 6), (6, 8), (8, 14)]
|
| 476 |
+
ph_bins_ids = {ph: i for i, ph in enumerate(ph_bins)}
|
| 477 |
+
ph_bins_ids["other"] = len(ph_bins)
|
| 478 |
+
num_ph_bins = len(ph_bins_ids)
|
| 479 |
+
|
| 480 |
+
####################################################################################################
|
| 481 |
+
# VDW_RADII
|
| 482 |
+
####################################################################################################
|
| 483 |
+
|
| 484 |
+
# fmt: off
|
| 485 |
+
vdw_radii = [
|
| 486 |
+
1.2, 1.4, 2.2, 1.9, 1.8, 1.7, 1.6, 1.55, 1.5, 1.54,
|
| 487 |
+
2.4, 2.2, 2.1, 2.1, 1.95, 1.8, 1.8, 1.88, 2.8, 2.4,
|
| 488 |
+
2.3, 2.15, 2.05, 2.05, 2.05, 2.05, 2.0, 2.0, 2.0, 2.1,
|
| 489 |
+
2.1, 2.1, 2.05, 1.9, 1.9, 2.02, 2.9, 2.55, 2.4, 2.3,
|
| 490 |
+
2.15, 2.1, 2.05, 2.05, 2.0, 2.05, 2.1, 2.2, 2.2, 2.25,
|
| 491 |
+
2.2, 2.1, 2.1, 2.16, 3.0, 2.7, 2.5, 2.48, 2.47, 2.45,
|
| 492 |
+
2.43, 2.42, 2.4, 2.38, 2.37, 2.35, 2.33, 2.32, 2.3, 2.28,
|
| 493 |
+
2.27, 2.25, 2.2, 2.1, 2.05, 2.0, 2.0, 2.05, 2.1, 2.05,
|
| 494 |
+
2.2, 2.3, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.4,
|
| 495 |
+
2.0, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
|
| 496 |
+
2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
|
| 497 |
+
2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0
|
| 498 |
+
]
|
| 499 |
+
# fmt: on
|
| 500 |
+
|
| 501 |
+
####################################################################################################
|
| 502 |
+
# Excluded ligands
|
| 503 |
+
####################################################################################################
|
| 504 |
+
|
| 505 |
+
ligand_exclusion = {
|
| 506 |
+
"144",
|
| 507 |
+
"15P",
|
| 508 |
+
"1PE",
|
| 509 |
+
"2F2",
|
| 510 |
+
"2JC",
|
| 511 |
+
"3HR",
|
| 512 |
+
"3SY",
|
| 513 |
+
"7N5",
|
| 514 |
+
"7PE",
|
| 515 |
+
"9JE",
|
| 516 |
+
"AAE",
|
| 517 |
+
"ABA",
|
| 518 |
+
"ACE",
|
| 519 |
+
"ACN",
|
| 520 |
+
"ACT",
|
| 521 |
+
"ACY",
|
| 522 |
+
"AZI",
|
| 523 |
+
"BAM",
|
| 524 |
+
"BCN",
|
| 525 |
+
"BCT",
|
| 526 |
+
"BDN",
|
| 527 |
+
"BEN",
|
| 528 |
+
"BME",
|
| 529 |
+
"BO3",
|
| 530 |
+
"BTB",
|
| 531 |
+
"BTC",
|
| 532 |
+
"BU1",
|
| 533 |
+
"C8E",
|
| 534 |
+
"CAD",
|
| 535 |
+
"CAQ",
|
| 536 |
+
"CBM",
|
| 537 |
+
"CCN",
|
| 538 |
+
"CIT",
|
| 539 |
+
"CL",
|
| 540 |
+
"CLR",
|
| 541 |
+
"CM",
|
| 542 |
+
"CMO",
|
| 543 |
+
"CO3",
|
| 544 |
+
"CPT",
|
| 545 |
+
"CXS",
|
| 546 |
+
"D10",
|
| 547 |
+
"DEP",
|
| 548 |
+
"DIO",
|
| 549 |
+
"DMS",
|
| 550 |
+
"DN",
|
| 551 |
+
"DOD",
|
| 552 |
+
"DOX",
|
| 553 |
+
"EDO",
|
| 554 |
+
"EEE",
|
| 555 |
+
"EGL",
|
| 556 |
+
"EOH",
|
| 557 |
+
"EOX",
|
| 558 |
+
"EPE",
|
| 559 |
+
"ETF",
|
| 560 |
+
"FCY",
|
| 561 |
+
"FJO",
|
| 562 |
+
"FLC",
|
| 563 |
+
"FMT",
|
| 564 |
+
"FW5",
|
| 565 |
+
"GOL",
|
| 566 |
+
"GSH",
|
| 567 |
+
"GTT",
|
| 568 |
+
"GYF",
|
| 569 |
+
"HED",
|
| 570 |
+
"IHP",
|
| 571 |
+
"IHS",
|
| 572 |
+
"IMD",
|
| 573 |
+
"IOD",
|
| 574 |
+
"IPA",
|
| 575 |
+
"IPH",
|
| 576 |
+
"LDA",
|
| 577 |
+
"MB3",
|
| 578 |
+
"MEG",
|
| 579 |
+
"MES",
|
| 580 |
+
"MLA",
|
| 581 |
+
"MLI",
|
| 582 |
+
"MOH",
|
| 583 |
+
"MPD",
|
| 584 |
+
"MRD",
|
| 585 |
+
"MSE",
|
| 586 |
+
"MYR",
|
| 587 |
+
"N",
|
| 588 |
+
"NA",
|
| 589 |
+
"NH2",
|
| 590 |
+
"NH4",
|
| 591 |
+
"NHE",
|
| 592 |
+
"NO3",
|
| 593 |
+
"O4B",
|
| 594 |
+
"OHE",
|
| 595 |
+
"OLA",
|
| 596 |
+
"OLC",
|
| 597 |
+
"OMB",
|
| 598 |
+
"OME",
|
| 599 |
+
"OXA",
|
| 600 |
+
"P6G",
|
| 601 |
+
"PE3",
|
| 602 |
+
"PE4",
|
| 603 |
+
"PEG",
|
| 604 |
+
"PEO",
|
| 605 |
+
"PEP",
|
| 606 |
+
"PG0",
|
| 607 |
+
"PG4",
|
| 608 |
+
"PGE",
|
| 609 |
+
"PGR",
|
| 610 |
+
"PLM",
|
| 611 |
+
"PO4",
|
| 612 |
+
"POL",
|
| 613 |
+
"POP",
|
| 614 |
+
"PVO",
|
| 615 |
+
"SAR",
|
| 616 |
+
"SCN",
|
| 617 |
+
"SEO",
|
| 618 |
+
"SEP",
|
| 619 |
+
"SIN",
|
| 620 |
+
"SO4",
|
| 621 |
+
"SPD",
|
| 622 |
+
"SPM",
|
| 623 |
+
"SR",
|
| 624 |
+
"STE",
|
| 625 |
+
"STO",
|
| 626 |
+
"STU",
|
| 627 |
+
"TAR",
|
| 628 |
+
"TBU",
|
| 629 |
+
"TME",
|
| 630 |
+
"TPO",
|
| 631 |
+
"TRS",
|
| 632 |
+
"UNK",
|
| 633 |
+
"UNL",
|
| 634 |
+
"UNX",
|
| 635 |
+
"UPL",
|
| 636 |
+
"URE",
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
####################################################################################################
|
| 641 |
+
# TEMPLATES
|
| 642 |
+
####################################################################################################
|
| 643 |
+
|
| 644 |
+
min_coverage_residues = 10
|
| 645 |
+
min_coverage_fraction = 0.1
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
####################################################################################################
|
| 649 |
+
# Ambiguous atoms
|
| 650 |
+
####################################################################################################
|
| 651 |
+
|
| 652 |
+
ambiguous_atoms = {
|
| 653 |
+
"CA": {
|
| 654 |
+
"*": "C",
|
| 655 |
+
"OEX": "CA",
|
| 656 |
+
"OEC": "CA",
|
| 657 |
+
"543": "CA",
|
| 658 |
+
"OC6": "CA",
|
| 659 |
+
"OC1": "CA",
|
| 660 |
+
"OC7": "CA",
|
| 661 |
+
"OEY": "CA",
|
| 662 |
+
"OC4": "CA",
|
| 663 |
+
"OC3": "CA",
|
| 664 |
+
"ICA": "CA",
|
| 665 |
+
"CA": "CA",
|
| 666 |
+
"OC2": "CA",
|
| 667 |
+
"OC5": "CA",
|
| 668 |
+
},
|
| 669 |
+
"CD": {"*": "C", "CD": "CD", "CD3": "CD", "CD5": "CD", "CD1": "CD"},
|
| 670 |
+
"BR": "BR",
|
| 671 |
+
"CL": {
|
| 672 |
+
"*": "CL",
|
| 673 |
+
"C8P": "C",
|
| 674 |
+
"L3T": "C",
|
| 675 |
+
"TLC": "C",
|
| 676 |
+
"TZ0": "C",
|
| 677 |
+
"471": "C",
|
| 678 |
+
"NLK": "C",
|
| 679 |
+
"PGM": "C",
|
| 680 |
+
"PNE": "C",
|
| 681 |
+
"RCY": "C",
|
| 682 |
+
"11F": "C",
|
| 683 |
+
"PII": "C",
|
| 684 |
+
"C1Q": "C",
|
| 685 |
+
"4MD": "C",
|
| 686 |
+
"R5A": "C",
|
| 687 |
+
"KW2": "C",
|
| 688 |
+
"I7M": "C",
|
| 689 |
+
"R48": "C",
|
| 690 |
+
"FC3": "C",
|
| 691 |
+
"55V": "C",
|
| 692 |
+
"KPF": "C",
|
| 693 |
+
"SPZ": "C",
|
| 694 |
+
"0TT": "C",
|
| 695 |
+
"R9A": "C",
|
| 696 |
+
"5NA": "C",
|
| 697 |
+
"C55": "C",
|
| 698 |
+
"NIX": "C",
|
| 699 |
+
"5PM": "C",
|
| 700 |
+
"PP8": "C",
|
| 701 |
+
"544": "C",
|
| 702 |
+
"812": "C",
|
| 703 |
+
"NPM": "C",
|
| 704 |
+
"KU8": "C",
|
| 705 |
+
"A1AMM": "C",
|
| 706 |
+
"4S0": "C",
|
| 707 |
+
"AQC": "C",
|
| 708 |
+
"2JK": "C",
|
| 709 |
+
"WJR": "C",
|
| 710 |
+
"A1AAW": "C",
|
| 711 |
+
"85E": "C",
|
| 712 |
+
"MB0": "C",
|
| 713 |
+
"ZAB": "C",
|
| 714 |
+
"85K": "C",
|
| 715 |
+
"GBP": "C",
|
| 716 |
+
"A1H80": "C",
|
| 717 |
+
"A1AFR": "C",
|
| 718 |
+
"L9M": "C",
|
| 719 |
+
"MYK": "C",
|
| 720 |
+
"MB9": "C",
|
| 721 |
+
"38R": "C",
|
| 722 |
+
"EKB": "C",
|
| 723 |
+
"NKF": "C",
|
| 724 |
+
"UMQ": "C",
|
| 725 |
+
"T4K": "C",
|
| 726 |
+
"3PT": "C",
|
| 727 |
+
"A1A7S": "C",
|
| 728 |
+
"1Q9": "C",
|
| 729 |
+
"11R": "C",
|
| 730 |
+
"D2V": "C",
|
| 731 |
+
"SM8": "C",
|
| 732 |
+
"IFC": "C",
|
| 733 |
+
"DB5": "C",
|
| 734 |
+
"L2T": "C",
|
| 735 |
+
"GNB": "C",
|
| 736 |
+
"PP7": "C",
|
| 737 |
+
"072": "C",
|
| 738 |
+
"P88": "C",
|
| 739 |
+
"DRL": "C",
|
| 740 |
+
"C9W": "C",
|
| 741 |
+
"NTP": "C",
|
| 742 |
+
"4HJ": "C",
|
| 743 |
+
"7NA": "C",
|
| 744 |
+
"LPC": "C",
|
| 745 |
+
"T8W": "C",
|
| 746 |
+
"63R": "C",
|
| 747 |
+
"570": "C",
|
| 748 |
+
"R4A": "C",
|
| 749 |
+
"3BG": "C",
|
| 750 |
+
"4RB": "C",
|
| 751 |
+
"GSO": "C",
|
| 752 |
+
"BQ6": "C",
|
| 753 |
+
"R4P": "C",
|
| 754 |
+
"5CP": "C",
|
| 755 |
+
"TTR": "C",
|
| 756 |
+
"6UZ": "C",
|
| 757 |
+
"SPJ": "C",
|
| 758 |
+
"0SA": "C",
|
| 759 |
+
"ZL1": "C",
|
| 760 |
+
"BYG": "C",
|
| 761 |
+
"F0E": "C",
|
| 762 |
+
"PC0": "C",
|
| 763 |
+
"B2Q": "C",
|
| 764 |
+
"KV6": "C",
|
| 765 |
+
"NTO": "C",
|
| 766 |
+
"CLG": "C",
|
| 767 |
+
"R7U": "C",
|
| 768 |
+
"SMQ": "C",
|
| 769 |
+
"GM2": "C",
|
| 770 |
+
"Z7P": "C",
|
| 771 |
+
"NXF": "C",
|
| 772 |
+
"C6Q": "C",
|
| 773 |
+
"A1G": "C",
|
| 774 |
+
"433": "C",
|
| 775 |
+
"L9N": "C",
|
| 776 |
+
"7OX": "C",
|
| 777 |
+
"A1H84": "C",
|
| 778 |
+
"97L": "C",
|
| 779 |
+
"HDV": "C",
|
| 780 |
+
"LUO": "C",
|
| 781 |
+
"R6A": "C",
|
| 782 |
+
"1PC": "C",
|
| 783 |
+
"4PT": "C",
|
| 784 |
+
"SBZ": "C",
|
| 785 |
+
"EAB": "C",
|
| 786 |
+
"FL4": "C",
|
| 787 |
+
"OPS": "C",
|
| 788 |
+
"C2X": "C",
|
| 789 |
+
"SLL": "C",
|
| 790 |
+
"BFC": "C",
|
| 791 |
+
"GIP": "C",
|
| 792 |
+
"7CP": "C",
|
| 793 |
+
"CLH": "C",
|
| 794 |
+
"34E": "C",
|
| 795 |
+
"5NE": "C",
|
| 796 |
+
"PBF": "C",
|
| 797 |
+
"ABD": "C",
|
| 798 |
+
"ABC": "C",
|
| 799 |
+
"LPF": "C",
|
| 800 |
+
"TIZ": "C",
|
| 801 |
+
"4HH": "C",
|
| 802 |
+
"AFC": "C",
|
| 803 |
+
"WQH": "C",
|
| 804 |
+
"9JL": "C",
|
| 805 |
+
"CS3": "C",
|
| 806 |
+
"NL0": "C",
|
| 807 |
+
"KPY": "C",
|
| 808 |
+
"DNA": "C",
|
| 809 |
+
"B3C": "C",
|
| 810 |
+
"TKL": "C",
|
| 811 |
+
"KVS": "C",
|
| 812 |
+
"HO6": "C",
|
| 813 |
+
"NLH": "C",
|
| 814 |
+
"1PB": "C",
|
| 815 |
+
"CYF": "C",
|
| 816 |
+
"G4M": "C",
|
| 817 |
+
"R5B": "C",
|
| 818 |
+
"N4S": "C",
|
| 819 |
+
"N11": "C",
|
| 820 |
+
"C8F": "C",
|
| 821 |
+
"PIJ": "C",
|
| 822 |
+
"WIN": "C",
|
| 823 |
+
"NT1": "C",
|
| 824 |
+
"WJW": "C",
|
| 825 |
+
"HF7": "C",
|
| 826 |
+
"TY1": "C",
|
| 827 |
+
"VM1": "C",
|
| 828 |
+
},
|
| 829 |
+
"OS": {"*": "O", "DWC": "OS", "OHX": "OS", "OS": "OS", "8WV": "OS", "OS4": "OS"},
|
| 830 |
+
"PB": {"*": "P", "ZN9": "PB", "ZN7": "PB", "PBM": "PB", "PB": "PB", "CSB": "PB"},
|
| 831 |
+
"CE": {"*": "C", "CE": "CE"},
|
| 832 |
+
"FE": {"*": "FE", "TFR": "F", "PF5": "F", "IFC": "F", "F5C": "F"},
|
| 833 |
+
"NA": {"*": "N", "CGO": "NA", "R2K": "NA", "LVQ": "NA", "NA": "NA"},
|
| 834 |
+
"ND": {"*": "N", "ND": "ND"},
|
| 835 |
+
"CF": {"*": "C", "CF": "CF"},
|
| 836 |
+
"RU": "RU",
|
| 837 |
+
"BRAF": "BR",
|
| 838 |
+
"EU": "EU",
|
| 839 |
+
"CLAA": "CL",
|
| 840 |
+
"CLBQ": "CL",
|
| 841 |
+
"CM": {"*": "C", "ZCM": "CM"},
|
| 842 |
+
"SN": {"*": "SN", "TAP": "S", "SND": "S", "TAD": "S", "XPT": "S"},
|
| 843 |
+
"AG": "AG",
|
| 844 |
+
"CLN": "CL",
|
| 845 |
+
"CLM": "CL",
|
| 846 |
+
"CLA": {"*": "CL", "PII": "C", "TDL": "C", "D0J": "C", "GM2": "C", "PIJ": "C"},
|
| 847 |
+
"CLB": {
|
| 848 |
+
"*": "CL",
|
| 849 |
+
"TD5": "C",
|
| 850 |
+
"PII": "C",
|
| 851 |
+
"TDL": "C",
|
| 852 |
+
"GM2": "C",
|
| 853 |
+
"TD7": "C",
|
| 854 |
+
"TD6": "C",
|
| 855 |
+
"PIJ": "C",
|
| 856 |
+
},
|
| 857 |
+
"CR": {
|
| 858 |
+
"*": "C",
|
| 859 |
+
"BW9": "CR",
|
| 860 |
+
"CQ4": "CR",
|
| 861 |
+
"AC9": "CR",
|
| 862 |
+
"TIL": "CR",
|
| 863 |
+
"J7U": "CR",
|
| 864 |
+
"CR": "CR",
|
| 865 |
+
},
|
| 866 |
+
"CLAY": "CL",
|
| 867 |
+
"CLBC": "CL",
|
| 868 |
+
"PD": {
|
| 869 |
+
"*": "P",
|
| 870 |
+
"F6Q": "PD",
|
| 871 |
+
"SVP": "PD",
|
| 872 |
+
"SXC": "PD",
|
| 873 |
+
"U5U": "PD",
|
| 874 |
+
"PD": "PD",
|
| 875 |
+
"PLL": "PD",
|
| 876 |
+
},
|
| 877 |
+
"CO": {
|
| 878 |
+
"*": "C",
|
| 879 |
+
"J1S": "CO",
|
| 880 |
+
"OCN": "CO",
|
| 881 |
+
"OL3": "CO",
|
| 882 |
+
"OL4": "CO",
|
| 883 |
+
"B12": "CO",
|
| 884 |
+
"XCO": "CO",
|
| 885 |
+
"UFU": "CO",
|
| 886 |
+
"CON": "CO",
|
| 887 |
+
"OL5": "CO",
|
| 888 |
+
"B13": "CO",
|
| 889 |
+
"7KI": "CO",
|
| 890 |
+
"PL1": "CO",
|
| 891 |
+
"OCO": "CO",
|
| 892 |
+
"J1R": "CO",
|
| 893 |
+
"COH": "CO",
|
| 894 |
+
"SIR": "CO",
|
| 895 |
+
"6KI": "CO",
|
| 896 |
+
"NCO": "CO",
|
| 897 |
+
"9CO": "CO",
|
| 898 |
+
"PC3": "CO",
|
| 899 |
+
"BWU": "CO",
|
| 900 |
+
"B1Z": "CO",
|
| 901 |
+
"J83": "CO",
|
| 902 |
+
"CO": "CO",
|
| 903 |
+
"COY": "CO",
|
| 904 |
+
"CNC": "CO",
|
| 905 |
+
"3CO": "CO",
|
| 906 |
+
"OCL": "CO",
|
| 907 |
+
"R5Q": "CO",
|
| 908 |
+
"X5Z": "CO",
|
| 909 |
+
"CBY": "CO",
|
| 910 |
+
"OLS": "CO",
|
| 911 |
+
"F0X": "CO",
|
| 912 |
+
"I2A": "CO",
|
| 913 |
+
"OCM": "CO",
|
| 914 |
+
},
|
| 915 |
+
"CU": {
|
| 916 |
+
"*": "C",
|
| 917 |
+
"8ZR": "CU",
|
| 918 |
+
"K7E": "CU",
|
| 919 |
+
"CU3": "CU",
|
| 920 |
+
"SI9": "CU",
|
| 921 |
+
"35N": "CU",
|
| 922 |
+
"C2O": "CU",
|
| 923 |
+
"SI7": "CU",
|
| 924 |
+
"B15": "CU",
|
| 925 |
+
"SI0": "CU",
|
| 926 |
+
"CUP": "CU",
|
| 927 |
+
"SQ1": "CU",
|
| 928 |
+
"CUK": "CU",
|
| 929 |
+
"CUL": "CU",
|
| 930 |
+
"SI8": "CU",
|
| 931 |
+
"IC4": "CU",
|
| 932 |
+
"CUM": "CU",
|
| 933 |
+
"MM2": "CU",
|
| 934 |
+
"B30": "CU",
|
| 935 |
+
"S32": "CU",
|
| 936 |
+
"V79": "CU",
|
| 937 |
+
"IMF": "CU",
|
| 938 |
+
"CUN": "CU",
|
| 939 |
+
"MM1": "CU",
|
| 940 |
+
"MP1": "CU",
|
| 941 |
+
"IME": "CU",
|
| 942 |
+
"B17": "CU",
|
| 943 |
+
"C2C": "CU",
|
| 944 |
+
"1CU": "CU",
|
| 945 |
+
"CU6": "CU",
|
| 946 |
+
"C1O": "CU",
|
| 947 |
+
"CU1": "CU",
|
| 948 |
+
"B22": "CU",
|
| 949 |
+
"CUS": "CU",
|
| 950 |
+
"RUQ": "CU",
|
| 951 |
+
"CUF": "CU",
|
| 952 |
+
"CUA": "CU",
|
| 953 |
+
"CU": "CU",
|
| 954 |
+
"CUO": "CU",
|
| 955 |
+
"0TE": "CU",
|
| 956 |
+
"SI4": "CU",
|
| 957 |
+
},
|
| 958 |
+
"CS": {"*": "C", "CS": "CS"},
|
| 959 |
+
"CLQ": "CL",
|
| 960 |
+
"CLR": "CL",
|
| 961 |
+
"CLU": "CL",
|
| 962 |
+
"TE": "TE",
|
| 963 |
+
"NI": {
|
| 964 |
+
"*": "N",
|
| 965 |
+
"USN": "NI",
|
| 966 |
+
"NFO": "NI",
|
| 967 |
+
"NI2": "NI",
|
| 968 |
+
"NFS": "NI",
|
| 969 |
+
"NFR": "NI",
|
| 970 |
+
"82N": "NI",
|
| 971 |
+
"R5N": "NI",
|
| 972 |
+
"NFU": "NI",
|
| 973 |
+
"A1ICD": "NI",
|
| 974 |
+
"NI3": "NI",
|
| 975 |
+
"M43": "NI",
|
| 976 |
+
"MM5": "NI",
|
| 977 |
+
"BF8": "NI",
|
| 978 |
+
"TCN": "NI",
|
| 979 |
+
"NIK": "NI",
|
| 980 |
+
"CUV": "NI",
|
| 981 |
+
"MM6": "NI",
|
| 982 |
+
"J52": "NI",
|
| 983 |
+
"NI": "NI",
|
| 984 |
+
"SNF": "NI",
|
| 985 |
+
"XCC": "NI",
|
| 986 |
+
"F0L": "NI",
|
| 987 |
+
"UWE": "NI",
|
| 988 |
+
"NFC": "NI",
|
| 989 |
+
"3NI": "NI",
|
| 990 |
+
"HNI": "NI",
|
| 991 |
+
"F43": "NI",
|
| 992 |
+
"RQM": "NI",
|
| 993 |
+
"NFE": "NI",
|
| 994 |
+
"NFB": "NI",
|
| 995 |
+
"B51": "NI",
|
| 996 |
+
"NI1": "NI",
|
| 997 |
+
"WCC": "NI",
|
| 998 |
+
"NUF": "NI",
|
| 999 |
+
},
|
| 1000 |
+
"SB": {"*": "S", "UJI": "SB", "SB": "SB", "118": "SB", "SBO": "SB", "3CG": "SB"},
|
| 1001 |
+
"MO": "MO",
|
| 1002 |
+
"SEG": "SE",
|
| 1003 |
+
"CLL": "CL",
|
| 1004 |
+
"CLAH": "CL",
|
| 1005 |
+
"CLC": {
|
| 1006 |
+
"*": "CL",
|
| 1007 |
+
"TD5": "C",
|
| 1008 |
+
"PII": "C",
|
| 1009 |
+
"TDL": "C",
|
| 1010 |
+
"GM2": "C",
|
| 1011 |
+
"TD7": "C",
|
| 1012 |
+
"TD6": "C",
|
| 1013 |
+
"PIJ": "C",
|
| 1014 |
+
},
|
| 1015 |
+
"CLD": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
|
| 1016 |
+
"CLAD": "CL",
|
| 1017 |
+
"CLAE": "CL",
|
| 1018 |
+
"LA": "LA",
|
| 1019 |
+
"RH": "RH",
|
| 1020 |
+
"BRAC": "BR",
|
| 1021 |
+
"BRAD": "BR",
|
| 1022 |
+
"CLBN": "CL",
|
| 1023 |
+
"CLAC": "CL",
|
| 1024 |
+
"BRAB": "BR",
|
| 1025 |
+
"BRAE": "BR",
|
| 1026 |
+
"MG": "MG",
|
| 1027 |
+
"IR": "IR",
|
| 1028 |
+
"SE": {
|
| 1029 |
+
"*": "SE",
|
| 1030 |
+
"HII": "S",
|
| 1031 |
+
"NT2": "S",
|
| 1032 |
+
"R2P": "S",
|
| 1033 |
+
"S2P": "S",
|
| 1034 |
+
"0IU": "S",
|
| 1035 |
+
"QMB": "S",
|
| 1036 |
+
"81S": "S",
|
| 1037 |
+
"0QB": "S",
|
| 1038 |
+
"UB4": "S",
|
| 1039 |
+
"OHS": "S",
|
| 1040 |
+
"Q78": "S",
|
| 1041 |
+
"0Y2": "S",
|
| 1042 |
+
"B3M": "S",
|
| 1043 |
+
"NT1": "S",
|
| 1044 |
+
"81R": "S",
|
| 1045 |
+
},
|
| 1046 |
+
"BRAG": "BR",
|
| 1047 |
+
"CLF": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
|
| 1048 |
+
"CLE": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
|
| 1049 |
+
"BRAX": "BR",
|
| 1050 |
+
"CLK": "CL",
|
| 1051 |
+
"ZN": "ZN",
|
| 1052 |
+
"AS": "AS",
|
| 1053 |
+
"AU": "AU",
|
| 1054 |
+
"PT": "PT",
|
| 1055 |
+
"CLAS": "CL",
|
| 1056 |
+
"MN": "MN",
|
| 1057 |
+
"CLBE": "CL",
|
| 1058 |
+
"CLBF": "CL",
|
| 1059 |
+
"CLAF": "CL",
|
| 1060 |
+
"NA'": {"*": "N", "CGO": "NA"},
|
| 1061 |
+
"BRAH": "BR",
|
| 1062 |
+
"BRAI": "BR",
|
| 1063 |
+
"BRA": "BR",
|
| 1064 |
+
"BRB": "BR",
|
| 1065 |
+
"BRAV": "BR",
|
| 1066 |
+
"HG": {
|
| 1067 |
+
"*": "HG",
|
| 1068 |
+
"BBA": "H",
|
| 1069 |
+
"MID": "H",
|
| 1070 |
+
"APM": "H",
|
| 1071 |
+
"4QQ": "H",
|
| 1072 |
+
"0ZG": "H",
|
| 1073 |
+
"APH": "H",
|
| 1074 |
+
},
|
| 1075 |
+
"AR": "AR",
|
| 1076 |
+
"D": "H",
|
| 1077 |
+
"CLAN": "CL",
|
| 1078 |
+
"SI": "SI",
|
| 1079 |
+
"CLS": "CL",
|
| 1080 |
+
"ZR": "ZR",
|
| 1081 |
+
"CLAR": {"*": "CL", "ZM4": "C"},
|
| 1082 |
+
"HO": "HO",
|
| 1083 |
+
"CLI": {"*": "CL", "GM2": "C"},
|
| 1084 |
+
"CLH": {"*": "CL", "GM2": "C"},
|
| 1085 |
+
"CLAP": "CL",
|
| 1086 |
+
"CLBL": "CL",
|
| 1087 |
+
"CLBM": "CL",
|
| 1088 |
+
"PR": {"*": "PR", "UF0": "P", "252": "P"},
|
| 1089 |
+
"IN": "IN",
|
| 1090 |
+
"CLJ": "CL",
|
| 1091 |
+
"BRU": "BR",
|
| 1092 |
+
"SC": {"*": "S", "SFL": "SC"},
|
| 1093 |
+
"CLG": {"*": "CL", "GM2": "C"},
|
| 1094 |
+
"BRAT": "BR",
|
| 1095 |
+
"BRAR": "BR",
|
| 1096 |
+
"CLAG": "CL",
|
| 1097 |
+
"CLAB": "CL",
|
| 1098 |
+
"CLV": "CL",
|
| 1099 |
+
"TI": "TI",
|
| 1100 |
+
"CLAX": "CL",
|
| 1101 |
+
"CLAJ": "CL",
|
| 1102 |
+
"CL'": {"*": "CL", "BNR": "C", "25A": "C", "BDA": "C"},
|
| 1103 |
+
"CLAW": "CL",
|
| 1104 |
+
"BRF": "BR",
|
| 1105 |
+
"BRE": "BR",
|
| 1106 |
+
"RE": "RE",
|
| 1107 |
+
"GD": "GD",
|
| 1108 |
+
"SM": {"*": "S", "SM": "SM"},
|
| 1109 |
+
"CLBH": "CL",
|
| 1110 |
+
"CLBI": "CL",
|
| 1111 |
+
"CLAI": "CL",
|
| 1112 |
+
"CLY": "CL",
|
| 1113 |
+
"CLZ": "CL",
|
| 1114 |
+
"AC": "AC",
|
| 1115 |
+
"BR'": "BR",
|
| 1116 |
+
"CLT": "CL",
|
| 1117 |
+
"CLO": "CL",
|
| 1118 |
+
"CLP": "CL",
|
| 1119 |
+
"LU": "LU",
|
| 1120 |
+
"BA": {"*": "B", "BA": "BA"},
|
| 1121 |
+
"CLAU": "CL",
|
| 1122 |
+
"RB": "RB",
|
| 1123 |
+
"LI": "LI",
|
| 1124 |
+
"MOM": "MO",
|
| 1125 |
+
"BRAQ": "BR",
|
| 1126 |
+
"SR": {"*": "S", "SR": "SR", "OER": "SR"},
|
| 1127 |
+
"CLAT": "CL",
|
| 1128 |
+
"BRAL": "BR",
|
| 1129 |
+
"SEB": "SE",
|
| 1130 |
+
"CLW": "CL",
|
| 1131 |
+
"CLX": "CL",
|
| 1132 |
+
"BE": "BE",
|
| 1133 |
+
"BRG": "BR",
|
| 1134 |
+
"SEA": "SE",
|
| 1135 |
+
"BRAW": "BR",
|
| 1136 |
+
"BRBB": "BR",
|
| 1137 |
+
"ER": "ER",
|
| 1138 |
+
"TH": "TH",
|
| 1139 |
+
"BRR": "BR",
|
| 1140 |
+
"CLBV": "CL",
|
| 1141 |
+
"AL": "AL",
|
| 1142 |
+
"CLAV": "CL",
|
| 1143 |
+
"BRH": "BR",
|
| 1144 |
+
"CLAQ": "CL",
|
| 1145 |
+
"GA": "GA",
|
| 1146 |
+
"X": "*",
|
| 1147 |
+
"TL": "TL",
|
| 1148 |
+
"CLBB": "CL",
|
| 1149 |
+
"TB": "TB",
|
| 1150 |
+
"CLAK": "CL",
|
| 1151 |
+
"XE": {"*": "*", "XE": "XE"},
|
| 1152 |
+
"SEL": "SE",
|
| 1153 |
+
"PU": {"*": "P", "4PU": "PU"},
|
| 1154 |
+
"CLAZ": "CL",
|
| 1155 |
+
"SE'": "SE",
|
| 1156 |
+
"CLBA": "CL",
|
| 1157 |
+
"SEN": "SE",
|
| 1158 |
+
"SNN": "SN",
|
| 1159 |
+
"MOB": "MO",
|
| 1160 |
+
"YB": "YB",
|
| 1161 |
+
"BRC": "BR",
|
| 1162 |
+
"BRD": "BR",
|
| 1163 |
+
"CLAM": "CL",
|
| 1164 |
+
"DA": "H",
|
| 1165 |
+
"DB": "H",
|
| 1166 |
+
"DC": "H",
|
| 1167 |
+
"DXT": "H",
|
| 1168 |
+
"DXU": "H",
|
| 1169 |
+
"DXX": "H",
|
| 1170 |
+
"DXY": "H",
|
| 1171 |
+
"DXZ": "H",
|
| 1172 |
+
"DY": "DY",
|
| 1173 |
+
"TA": "TA",
|
| 1174 |
+
"XD": "*",
|
| 1175 |
+
"SED": "SE",
|
| 1176 |
+
"CLAL": "CL",
|
| 1177 |
+
"BRAJ": "BR",
|
| 1178 |
+
"AM": "AM",
|
| 1179 |
+
"CLAO": "CL",
|
| 1180 |
+
"BI": "BI",
|
| 1181 |
+
"KR": "KR",
|
| 1182 |
+
"BRBJ": "BR",
|
| 1183 |
+
"UNK": "*",
|
| 1184 |
+
}
|
protify/FastPLMs/boltz/src/boltz/data/crop/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/crop/affinity.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from boltz.data import const
|
| 7 |
+
from boltz.data.crop.cropper import Cropper
|
| 8 |
+
from boltz.data.types import Tokenized
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class AffinityCropper(Cropper):
|
| 12 |
+
"""Interpolate between contiguous and spatial crops."""
|
| 13 |
+
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
neighborhood_size: int = 10,
|
| 17 |
+
max_tokens_protein: int = 200,
|
| 18 |
+
) -> None:
|
| 19 |
+
"""Initialize the cropper.
|
| 20 |
+
|
| 21 |
+
Parameters
|
| 22 |
+
----------
|
| 23 |
+
neighborhood_size : int
|
| 24 |
+
Modulates the type of cropping to be performed.
|
| 25 |
+
Smaller neighborhoods result in more spatial
|
| 26 |
+
cropping. Larger neighborhoods result in more
|
| 27 |
+
continuous cropping.
|
| 28 |
+
|
| 29 |
+
"""
|
| 30 |
+
self.neighborhood_size = neighborhood_size
|
| 31 |
+
self.max_tokens_protein = max_tokens_protein
|
| 32 |
+
|
| 33 |
+
def crop(
|
| 34 |
+
self,
|
| 35 |
+
data: Tokenized,
|
| 36 |
+
max_tokens: int,
|
| 37 |
+
max_atoms: Optional[int] = None,
|
| 38 |
+
) -> Tokenized:
|
| 39 |
+
"""Crop the data to a maximum number of tokens.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
data : Tokenized
|
| 44 |
+
The tokenized data.
|
| 45 |
+
max_tokens : int
|
| 46 |
+
The maximum number of tokens to crop.
|
| 47 |
+
random : np.random.RandomState
|
| 48 |
+
The random state for reproducibility.
|
| 49 |
+
max_atoms : Optional[int]
|
| 50 |
+
The maximum number of atoms to consider.
|
| 51 |
+
|
| 52 |
+
Returns
|
| 53 |
+
-------
|
| 54 |
+
Tokenized
|
| 55 |
+
The cropped data.
|
| 56 |
+
|
| 57 |
+
"""
|
| 58 |
+
# Get token data
|
| 59 |
+
token_data = data.tokens
|
| 60 |
+
token_bonds = data.bonds
|
| 61 |
+
|
| 62 |
+
# Filter to resolved tokens
|
| 63 |
+
valid_tokens = token_data[token_data["resolved_mask"]]
|
| 64 |
+
|
| 65 |
+
# Check if we have any valid tokens
|
| 66 |
+
if not valid_tokens.size:
|
| 67 |
+
msg = "No valid tokens in structure"
|
| 68 |
+
raise ValueError(msg)
|
| 69 |
+
|
| 70 |
+
# compute minimum distance to ligand
|
| 71 |
+
ligand_coords = valid_tokens[valid_tokens["affinity_mask"]]["center_coords"]
|
| 72 |
+
dists = np.min(
|
| 73 |
+
np.sum(
|
| 74 |
+
(valid_tokens["center_coords"][:, None] - ligand_coords[None]) ** 2,
|
| 75 |
+
axis=-1,
|
| 76 |
+
)
|
| 77 |
+
** 0.5,
|
| 78 |
+
axis=1,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
indices = np.argsort(dists)
|
| 82 |
+
|
| 83 |
+
# Select cropped indices
|
| 84 |
+
cropped: set[int] = set()
|
| 85 |
+
total_atoms = 0
|
| 86 |
+
|
| 87 |
+
# protein tokens
|
| 88 |
+
cropped_protein: set[int] = set()
|
| 89 |
+
ligand_ids = set(
|
| 90 |
+
valid_tokens[
|
| 91 |
+
valid_tokens["mol_type"] == const.chain_type_ids["NONPOLYMER"]
|
| 92 |
+
]["token_idx"]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
for idx in indices:
|
| 96 |
+
# Get the token
|
| 97 |
+
token = valid_tokens[idx]
|
| 98 |
+
|
| 99 |
+
# Get all tokens from this chain
|
| 100 |
+
chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
|
| 101 |
+
|
| 102 |
+
# Pick the whole chain if possible, otherwise select
|
| 103 |
+
# a contiguous subset centered at the query token
|
| 104 |
+
if len(chain_tokens) <= self.neighborhood_size:
|
| 105 |
+
new_tokens = chain_tokens
|
| 106 |
+
else:
|
| 107 |
+
# First limit to the maximum set of tokens, with the
|
| 108 |
+
# neighborhood on both sides to handle edges. This
|
| 109 |
+
# is mostly for efficiency with the while loop below.
|
| 110 |
+
min_idx = token["res_idx"] - self.neighborhood_size
|
| 111 |
+
max_idx = token["res_idx"] + self.neighborhood_size
|
| 112 |
+
|
| 113 |
+
max_token_set = chain_tokens
|
| 114 |
+
max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
|
| 115 |
+
max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]
|
| 116 |
+
|
| 117 |
+
# Start by adding just the query token
|
| 118 |
+
new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]
|
| 119 |
+
|
| 120 |
+
# Expand the neighborhood until we have enough tokens, one
|
| 121 |
+
# by one to handle some edge cases with non-standard chains.
|
| 122 |
+
# We switch to the res_idx instead of the token_idx to always
|
| 123 |
+
# include all tokens from modified residues or from ligands.
|
| 124 |
+
min_idx = max_idx = token["res_idx"]
|
| 125 |
+
while new_tokens.size < self.neighborhood_size:
|
| 126 |
+
min_idx = min_idx - 1
|
| 127 |
+
max_idx = max_idx + 1
|
| 128 |
+
new_tokens = max_token_set
|
| 129 |
+
new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
|
| 130 |
+
new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]
|
| 131 |
+
|
| 132 |
+
# Compute new tokens and new atoms
|
| 133 |
+
new_indices = set(new_tokens["token_idx"]) - cropped
|
| 134 |
+
new_tokens = token_data[list(new_indices)]
|
| 135 |
+
new_atoms = np.sum(new_tokens["atom_num"])
|
| 136 |
+
|
| 137 |
+
# Stop if we exceed the max number of tokens or atoms
|
| 138 |
+
if (
|
| 139 |
+
(len(new_indices) > (max_tokens - len(cropped)))
|
| 140 |
+
or ((max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms))
|
| 141 |
+
or (
|
| 142 |
+
len(cropped_protein | new_indices - ligand_ids)
|
| 143 |
+
> self.max_tokens_protein
|
| 144 |
+
)
|
| 145 |
+
):
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
# Add new indices
|
| 149 |
+
cropped.update(new_indices)
|
| 150 |
+
total_atoms += new_atoms
|
| 151 |
+
|
| 152 |
+
# Add protein indices
|
| 153 |
+
cropped_protein.update(new_indices - ligand_ids)
|
| 154 |
+
|
| 155 |
+
# Get the cropped tokens sorted by index
|
| 156 |
+
token_data = token_data[sorted(cropped)]
|
| 157 |
+
|
| 158 |
+
# Only keep bonds within the cropped tokens
|
| 159 |
+
indices = token_data["token_idx"]
|
| 160 |
+
token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
|
| 161 |
+
token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]
|
| 162 |
+
|
| 163 |
+
# Return the cropped tokens
|
| 164 |
+
return replace(data, tokens=token_data, bonds=token_bonds)
|
protify/FastPLMs/boltz/src/boltz/data/crop/boltz.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import replace
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.spatial.distance import cdist
|
| 6 |
+
|
| 7 |
+
from boltz.data import const
|
| 8 |
+
from boltz.data.crop.cropper import Cropper
|
| 9 |
+
from boltz.data.types import Tokenized
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def pick_random_token(
|
| 13 |
+
tokens: np.ndarray,
|
| 14 |
+
random: np.random.RandomState,
|
| 15 |
+
) -> np.ndarray:
|
| 16 |
+
"""Pick a random token from the data.
|
| 17 |
+
|
| 18 |
+
Parameters
|
| 19 |
+
----------
|
| 20 |
+
tokens : np.ndarray
|
| 21 |
+
The token data.
|
| 22 |
+
random : np.ndarray
|
| 23 |
+
The random state for reproducibility.
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
np.ndarray
|
| 28 |
+
The selected token.
|
| 29 |
+
|
| 30 |
+
"""
|
| 31 |
+
return tokens[random.randint(len(tokens))]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pick_chain_token(
|
| 35 |
+
tokens: np.ndarray,
|
| 36 |
+
chain_id: int,
|
| 37 |
+
random: np.random.RandomState,
|
| 38 |
+
) -> np.ndarray:
|
| 39 |
+
"""Pick a random token from a chain.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
tokens : np.ndarray
|
| 44 |
+
The token data.
|
| 45 |
+
chain_id : int
|
| 46 |
+
The chain ID.
|
| 47 |
+
random : np.ndarray
|
| 48 |
+
The random state for reproducibility.
|
| 49 |
+
|
| 50 |
+
Returns
|
| 51 |
+
-------
|
| 52 |
+
np.ndarray
|
| 53 |
+
The selected token.
|
| 54 |
+
|
| 55 |
+
"""
|
| 56 |
+
# Filter to chain
|
| 57 |
+
chain_tokens = tokens[tokens["asym_id"] == chain_id]
|
| 58 |
+
|
| 59 |
+
# Pick from chain, fallback to all tokens
|
| 60 |
+
if chain_tokens.size:
|
| 61 |
+
query = pick_random_token(chain_tokens, random)
|
| 62 |
+
else:
|
| 63 |
+
query = pick_random_token(tokens, random)
|
| 64 |
+
|
| 65 |
+
return query
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def pick_interface_token(
    tokens: np.ndarray,
    interface: np.ndarray,
    random: np.random.RandomState,
) -> np.ndarray:
    """Pick a random token from an interface.

    Parameters
    ----------
    tokens : np.ndarray
        The token data.
    interface : np.ndarray
        The interface record, with "chain_1" and "chain_2" fields.
    random : np.random.RandomState
        The random state for reproducibility.

    Returns
    -------
    np.ndarray
        The selected token.

    """
    # Resolve the two chains forming the interface
    chain_1 = int(interface["chain_1"])
    chain_2 = int(interface["chain_2"])

    tokens_1 = tokens[tokens["asym_id"] == chain_1]
    tokens_2 = tokens[tokens["asym_id"] == chain_2]

    # If either chain has no tokens, fall back to whatever is available
    if tokens_1.size and (not tokens_2.size):
        query = pick_random_token(tokens_1, random)
    elif tokens_2.size and (not tokens_1.size):
        query = pick_random_token(tokens_2, random)
    elif (not tokens_1.size) and (not tokens_2.size):
        query = pick_random_token(tokens, random)
    else:
        # Both chains have tokens: restrict to tokens that are
        # actually in contact across the interface.
        tokens_1_coords = tokens_1["center_coords"]
        tokens_2_coords = tokens_2["center_coords"]

        dists = cdist(tokens_1_coords, tokens_2_coords)
        cutoff = dists < const.interface_cutoff

        # In rare cases, the interface cutoff is slightly
        # too small, then we slightly expand it if it happens
        if not np.any(cutoff):
            cutoff = dists < (const.interface_cutoff + 5.0)

        contact_1 = tokens_1[np.any(cutoff, axis=1)]
        contact_2 = tokens_2[np.any(cutoff, axis=0)]

        # Select random token among contacting ones. If even the expanded
        # cutoff found no contacts, both sides are empty and sampling from
        # the empty concatenation would raise, so fall back to all tokens.
        candidates = np.concatenate([contact_1, contact_2])
        if not candidates.size:
            candidates = tokens
        query = pick_random_token(candidates, random)

    return query
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class BoltzCropper(Cropper):
    """Interpolate between contiguous and spatial crops.

    Small neighborhood sizes make each addition to the crop behave like a
    purely spatial crop (one token at a time, nearest-first); large sizes
    make it behave like a contiguous crop (whole residue windows).
    """

    def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
        """Initialize the cropper.

        Modulates the type of cropping to be performed.
        Smaller neighborhoods result in more spatial
        cropping. Larger neighborhoods result in more
        continuous cropping. A mix can be achieved by
        providing a range over which to sample.

        Parameters
        ----------
        min_neighborhood : int
            The minimum neighborhood size, by default 0.
        max_neighborhood : int
            The maximum neighborhood size, by default 40.

        """
        # Even sizes only (step of 2), inclusive of max_neighborhood.
        sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
        self.neighborhood_sizes = sizes

    def crop(  # noqa: PLR0915
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : int, optional
            The maximum number of atoms to consider.
        chain_id : int, optional
            The chain ID to crop.
        interface_id : int, optional
            The interface ID to crop.

        Returns
        -------
        Tokenized
            The cropped data.

        Raises
        ------
        ValueError
            If both chain_id and interface_id are given, or if the
            structure has no resolved tokens.

        """
        # Check inputs: the two targeting modes are mutually exclusive.
        if chain_id is not None and interface_id is not None:
            msg = "Only one of chain_id or interface_id can be provided."
            raise ValueError(msg)

        # Randomly select a neighborhood size for this crop.
        neighborhood_size = random.choice(self.neighborhood_sizes)

        # Get token data
        token_data = data.tokens
        token_bonds = data.bonds
        mask = data.structure.mask
        chains = data.structure.chains
        interfaces = data.structure.interfaces

        # Filter to valid chains
        valid_chains = chains[mask]

        # Filter to valid interfaces (both endpoint chains must be valid)
        valid_interfaces = interfaces
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
        valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]

        # Filter to resolved tokens
        valid_tokens = token_data[token_data["resolved_mask"]]

        # Check if we have any valid tokens
        if not valid_tokens.size:
            msg = "No valid tokens in structure"
            raise ValueError(msg)

        # Pick a query token: from the requested chain, the requested
        # interface, a random valid interface, or a random valid chain.
        if chain_id is not None:
            query = pick_chain_token(valid_tokens, chain_id, random)
        elif interface_id is not None:
            interface = interfaces[interface_id]
            query = pick_interface_token(valid_tokens, interface, random)
        elif valid_interfaces.size:
            idx = random.randint(len(valid_interfaces))
            interface = valid_interfaces[idx]
            query = pick_interface_token(valid_tokens, interface, random)
        else:
            idx = random.randint(len(valid_chains))
            chain_id = valid_chains[idx]["asym_id"]
            query = pick_chain_token(valid_tokens, chain_id, random)

        # Sort all tokens by distance to the query token's center
        dists = valid_tokens["center_coords"] - query["center_coords"]
        indices = np.argsort(np.linalg.norm(dists, axis=1))

        # Select cropped indices, growing nearest-first from the query.
        cropped: set[int] = set()
        total_atoms = 0
        for idx in indices:
            # Get the token
            token = valid_tokens[idx]

            # Get all tokens from this chain
            chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]

            # Pick the whole chain if possible, otherwise select
            # a contiguous subset centered at the query token
            if len(chain_tokens) <= neighborhood_size:
                new_tokens = chain_tokens
            else:
                # First limit to the maximum set of tokens, with the
                # neighborhood on both sides to handle edges. This
                # is mostly for efficiency with the while loop below.
                min_idx = token["res_idx"] - neighborhood_size
                max_idx = token["res_idx"] + neighborhood_size

                max_token_set = chain_tokens
                max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
                max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]

                # Start by adding just the query token
                new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]

                # Expand the neighborhood until we have enough tokens, one
                # by one to handle some edge cases with non-standard chains.
                # We switch to the res_idx instead of the token_idx to always
                # include all tokens from modified residues or from ligands.
                # NOTE(review): if res_idx values are sparse, max_token_set
                # may hold fewer than neighborhood_size tokens, in which case
                # this loop cannot terminate — confirm res_idx is dense here.
                min_idx = max_idx = token["res_idx"]
                while new_tokens.size < neighborhood_size:
                    min_idx = min_idx - 1
                    max_idx = max_idx + 1
                    new_tokens = max_token_set
                    new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
                    new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]

            # Compute new tokens and new atoms.
            # NOTE(review): indexing token_data with token_idx values assumes
            # token_idx equals the positional index — confirm in the tokenizer.
            new_indices = set(new_tokens["token_idx"]) - cropped
            new_tokens = token_data[list(new_indices)]
            new_atoms = np.sum(new_tokens["atom_num"])

            # Stop if we exceed the max number of tokens or atoms
            if (len(new_indices) > (max_tokens - len(cropped))) or (
                (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
            ):
                break

            # Add new indices
            cropped.update(new_indices)
            total_atoms += new_atoms

        # Get the cropped tokens sorted by index
        token_data = token_data[sorted(cropped)]

        # Only keep bonds within the cropped tokens
        indices = token_data["token_idx"]
        token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
        token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]

        # Return the cropped tokens
        return replace(data, tokens=token_data, bonds=token_bonds)
|
protify/FastPLMs/boltz/src/boltz/data/crop/cropper.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from boltz.data.types import Tokenized
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Cropper(ABC):
    """Abstract base class for croppers.

    A cropper reduces a tokenized structure to at most a given number
    of tokens (and optionally atoms), returning a new Tokenized object.
    """

    @abstractmethod
    def crop(
        self,
        data: Tokenized,
        max_tokens: int,
        random: np.random.RandomState,
        max_atoms: Optional[int] = None,
        chain_id: Optional[int] = None,
        interface_id: Optional[int] = None,
    ) -> Tokenized:
        """Crop the data to a maximum number of tokens.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        max_tokens : int
            The maximum number of tokens to crop.
        random : np.random.RandomState
            The random state for reproducibility.
        max_atoms : Optional[int]
            The maximum number of atoms to consider.
        chain_id : Optional[int]
            The chain ID to crop around.
        interface_id : Optional[int]
            The interface ID to crop around.

        Returns
        -------
        Tokenized
            The cropped data.

        """
        raise NotImplementedError
|
protify/FastPLMs/boltz/src/boltz/data/feature/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/feature/featurizer.py
ADDED
|
@@ -0,0 +1,1225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from collections import deque
|
| 5 |
+
import numba
|
| 6 |
+
import numpy as np
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
import torch
|
| 9 |
+
from numba import types
|
| 10 |
+
from torch import Tensor, from_numpy
|
| 11 |
+
from torch.nn.functional import one_hot
|
| 12 |
+
|
| 13 |
+
from boltz.data import const
|
| 14 |
+
from boltz.data.feature.symmetry import (
|
| 15 |
+
get_amino_acids_symmetries,
|
| 16 |
+
get_chain_symmetries,
|
| 17 |
+
get_ligand_symmetries,
|
| 18 |
+
)
|
| 19 |
+
from boltz.data.pad import pad_dim
|
| 20 |
+
from boltz.data.types import (
|
| 21 |
+
MSA,
|
| 22 |
+
MSADeletion,
|
| 23 |
+
MSAResidue,
|
| 24 |
+
MSASequence,
|
| 25 |
+
Tokenized,
|
| 26 |
+
)
|
| 27 |
+
from boltz.model.modules.utils import center_random_augmentation
|
| 28 |
+
|
| 29 |
+
####################################################################################################
|
| 30 |
+
# HELPERS
|
| 31 |
+
####################################################################################################
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def compute_frames_nonpolymer(
    data: Tokenized,
    coords,
    resolved_mask,
    atom_to_token,
    frame_data: list,
    resolved_frame_data: list,
) -> tuple[list, list]:
    """Get the frames for non-polymer tokens.

    For each non-polymer chain with at least 3 atoms, builds per-atom
    frames from the three closest mutually-resolved atoms, then masks
    out frames whose axes are (near-)collinear or degenerate.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    coords : array-like
        Atom coordinates; reshaped to (-1, 3) before use.
    resolved_mask : np.ndarray
        Per-atom resolved flags, indexed like the flattened coords.
    atom_to_token : np.ndarray
        Maps each atom index to its token index.
    frame_data : list
        The frame data (one (3,) atom-index triplet per token).
    resolved_frame_data : list
        The resolved frame data (one flag per token).

    Returns
    -------
    tuple[list, list]
        The frame data and resolved frame data.

    """
    frame_data = np.array(frame_data)
    resolved_frame_data = np.array(resolved_frame_data)
    asym_id_token = data.tokens["asym_id"]
    asym_id_atom = data.tokens["asym_id"][atom_to_token]
    # Running offsets into the token- and atom-indexed arrays; relies on
    # tokens/atoms being sorted by asym_id so each chain is contiguous.
    token_idx = 0
    atom_idx = 0
    for id in np.unique(data.tokens["asym_id"]):  # noqa: A001 — shadows builtin `id`
        mask_chain_token = asym_id_token == id
        mask_chain_atom = asym_id_atom == id
        num_tokens = mask_chain_token.sum()
        num_atoms = mask_chain_atom.sum()
        # Skip polymer chains and chains too small to form a 3-atom frame.
        if (
            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
            or num_atoms < 3
        ):
            token_idx += num_tokens
            atom_idx += num_atoms
            continue
        # Pairwise Euclidean distances between the chain's atoms.
        dist_mat = (
            (
                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
            )
            ** 2
        ).sum(-1) ** 0.5
        # Pairs where either atom is unresolved get +inf so argsort
        # pushes them past all resolved pairs.
        resolved_pair = 1 - (
            resolved_mask[mask_chain_atom][None, :]
            * resolved_mask[mask_chain_atom][:, None]
        ).astype(np.float32)
        resolved_pair[resolved_pair == 1] = math.inf
        indices = np.argsort(dist_mat + resolved_pair, axis=1)
        # Frame per atom: (nearest neighbor, self, second-nearest),
        # shifted into global atom indices.
        frames = (
            np.concatenate(
                [
                    indices[:, 1:2],
                    indices[:, 0:1],
                    indices[:, 2:3],
                ],
                axis=1,
            )
            + atom_idx
        )
        # NOTE(review): writes num_atoms rows into the token-indexed arrays,
        # which assumes nonpolymer chains tokenize one token per atom
        # (num_tokens == num_atoms) — confirm against the tokenizer.
        frame_data[token_idx : token_idx + num_atoms, :] = frames
        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
            frames
        ].all(axis=1)
        token_idx += num_tokens
        atom_idx += num_atoms
    # Gather the frame atoms' coordinates: (n_tokens, 3 atoms, 3 coords).
    frames_expanded = coords.reshape(-1, 3)[frame_data]

    # Drop frames whose two axis vectors are degenerate or collinear.
    mask_collinear = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    )
    return frame_data, resolved_frame_data & mask_collinear
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def compute_collinear_mask(v1, v2):
    """Mask out degenerate frames.

    A row is kept only when both axis vectors have non-negligible
    length (> 1e-2) and are not nearly collinear (|cos| < 0.9063,
    i.e. the angle between them exceeds roughly 25 degrees).
    """
    len1 = np.linalg.norm(v1, axis=1, keepdims=True)
    len2 = np.linalg.norm(v2, axis=1, keepdims=True)
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    cos_angle = np.abs(np.sum(unit1 * unit2, axis=1))
    non_collinear = cos_angle < 0.9063
    long_enough_1 = len1.reshape(-1) > 1e-2
    long_enough_2 = len2.reshape(-1) > 1e-2
    return non_collinear & long_enough_1 & long_enough_2
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def dummy_msa(residues: np.ndarray) -> MSA:
    """Create a placeholder single-sequence MSA for a chain.

    Parameters
    ----------
    residues : np.ndarray
        The residues for the chain.

    Returns
    -------
    MSA
        The dummy MSA.

    """
    # The only row is the query itself: one sequence spanning every
    # residue, with no taxonomy and no deletions.
    res_types = [residue["res_type"] for residue in residues]
    single_sequence = [(0, -1, 0, len(res_types), 0, 0)]
    return MSA(
        residues=np.array(res_types, dtype=MSAResidue),
        deletions=np.array([], dtype=MSADeletion),
        sequences=np.array(single_sequence, dtype=MSASequence),
    )
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def construct_paired_msa(  # noqa: C901, PLR0915, PLR0912
    data: Tokenized,
    max_seqs: int,
    max_pairs: int = 8192,
    max_total: int = 16384,
    random_subset: bool = False,
) -> tuple[Tensor, Tensor, Tensor]:
    """Pair the MSA data across chains by taxonomy.

    Parameters
    ----------
    data : Input
        The input data.
    max_seqs : int
        The maximum number of MSA rows to keep.
    max_pairs : int
        The maximum number of taxonomy-paired rows, by default 8192.
    max_total : int
        The maximum number of total rows, by default 16384.
    random_subset : bool
        Whether to sample rows at random instead of truncating,
        by default False.

    Returns
    -------
    Tensor
        The MSA data.
    Tensor
        The deletion data.
    Tensor
        Mask indicating paired sequences.

    """
    # Get unique chains (ensuring monotonicity in the order)
    assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
    chain_ids = np.unique(data.tokens["asym_id"])

    # Get relevant MSA, and create a dummy for chains without
    msa = {k: data.msa[k] for k in chain_ids if k in data.msa}
    for chain_id in chain_ids:
        if chain_id not in msa:
            chain = data.structure.chains[chain_id]
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = data.structure.residues[res_start:res_end]
            msa[chain_id] = dummy_msa(residues)

    # Map taxonomies to (chain_id, seq_idx)
    taxonomy_map: dict[str, list] = {}
    for chain_id, chain_msa in msa.items():
        sequences = chain_msa.sequences
        sequences = sequences[sequences["taxonomy"] != -1]
        for sequence in sequences:
            seq_idx = sequence["seq_idx"]
            taxon = sequence["taxonomy"]
            taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))

    # Remove taxonomies with only one sequence and sort by the
    # number of chain_id present in each of the taxonomies
    taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
    taxonomy_map = sorted(
        taxonomy_map.items(),
        key=lambda x: len({c for c, _ in x[1]}),
        reverse=True,
    )

    # Keep track of the sequences available per chain, keeping the original
    # order of the sequences in the MSA to favor the best matching sequences.
    # Bug fix: iterate the (chain_id, seq_idx) pairs themselves. The previous
    # version built {(taxon, pair)} tuples, so the (c, i) membership test
    # below never matched and paired sequences were re-used as unpaired rows.
    visited = {pair for _, pairs in taxonomy_map for pair in pairs}
    available = {}
    for c in chain_ids:
        available[c] = deque(
            i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
        )

    # Create sequence pairs
    is_paired = []
    pairing = []

    # Start with the first sequence for each chain (the query row)
    is_paired.append({c: 1 for c in chain_ids})
    pairing.append({c: 0 for c in chain_ids})

    # Then add up to max_pairs - 1 taxonomy-paired rows
    for _, pairs in taxonomy_map:
        # Group occurences by chain_id in case we have multiple
        # sequences from the same chain and same taxonomy
        chain_occurences = {}
        for chain_id, seq_idx in pairs:
            chain_occurences.setdefault(chain_id, []).append(seq_idx)

        # We create as many pairings as the maximum number of occurences
        max_occurences = max(len(v) for v in chain_occurences.values())
        for i in range(max_occurences):
            row_pairing = {}
            row_is_paired = {}

            # Add the chains present in the taxonomy
            for chain_id, seq_idxs in chain_occurences.items():
                # Roll over the sequence index to maximize diversity
                idx = i % len(seq_idxs)
                seq_idx = seq_idxs[idx]

                # Add the sequence to the pairing
                row_pairing[chain_id] = seq_idx
                row_is_paired[chain_id] = 1

            # Add any missing chains
            for chain_id in chain_ids:
                if chain_id not in row_pairing:
                    row_is_paired[chain_id] = 0
                    if available[chain_id]:
                        # Add the next available sequence
                        row_pairing[chain_id] = available[chain_id].popleft()
                    else:
                        # No more sequences available, we place a gap
                        row_pairing[chain_id] = -1

            pairing.append(row_pairing)
            is_paired.append(row_is_paired)

            # Break if we have enough pairs
            if len(pairing) >= max_pairs:
                break

        # Break if we have enough pairs
        if len(pairing) >= max_pairs:
            break

    # Now add unpaired rows, up to max_total rows in total
    max_left = max(len(v) for v in available.values())
    for _ in range(min(max_total - len(pairing), max_left)):
        row_pairing = {}
        row_is_paired = {}
        for chain_id in chain_ids:
            row_is_paired[chain_id] = 0
            if available[chain_id]:
                # Add the next available sequence
                row_pairing[chain_id] = available[chain_id].popleft()
            else:
                # No more sequences available, we place a gap
                row_pairing[chain_id] = -1

        pairing.append(row_pairing)
        is_paired.append(row_is_paired)

        # Break if we have enough sequences
        if len(pairing) >= max_total:
            break

    # Randomly sample a subset of the pairs,
    # ensuring the first row is always present
    if random_subset:
        num_seqs = len(pairing)
        if num_seqs > max_seqs:
            # NOTE(review): uses the global NumPy RNG rather than a
            # passed-in RandomState — confirm this is intentional.
            indices = np.random.choice(
                list(range(1, num_seqs)), size=max_seqs - 1, replace=False
            )  # noqa: NPY002
            pairing = [pairing[0]] + [pairing[i] for i in indices]
            is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
    else:
        # Deterministic downsample to max_seqs
        pairing = pairing[:max_seqs]
        is_paired = is_paired[:max_seqs]

    # Map (chain_id, seq_idx, res_idx) to deletion, in a numba typed
    # dict so the jitted array-preparation step can read it.
    deletions = numba.typed.Dict.empty(
        key_type=numba.types.Tuple(
            [numba.types.int64, numba.types.int64, numba.types.int64]),
        value_type=numba.types.int64
    )
    for chain_id, chain_msa in msa.items():
        chain_deletions = chain_msa.deletions
        for sequence in chain_msa.sequences:
            seq_idx = sequence["seq_idx"]
            del_start = sequence["del_start"]
            del_end = sequence["del_end"]
            chain_deletions = chain_deletions[del_start:del_end]
            for deletion_data in chain_deletions:
                res_idx = deletion_data["res_idx"]
                deletion_values = deletion_data["deletion"]
                deletions[(chain_id, seq_idx, res_idx)] = deletion_values

    # Add all the token MSA data
    msa_data, del_data, paired_data = prepare_msa_arrays(
        data.tokens, pairing, is_paired, deletions, msa
    )

    msa_data = torch.tensor(msa_data, dtype=torch.long)
    del_data = torch.tensor(del_data, dtype=torch.float)
    paired_data = torch.tensor(paired_data, dtype=torch.float)

    return msa_data, del_data, paired_data
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def prepare_msa_arrays(
    tokens,
    pairing: list[dict[int, int]],
    is_paired: list[dict[int, int]],
    deletions: dict[tuple[int, int, int], int],
    msa: dict[int, MSA],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Reshape data to play nicely with numba jit.

    Converts the dict-based pairing rows and per-chain MSA tables into
    dense int64 arrays, then delegates the heavy per-token/per-row loop
    to the jitted ``_prepare_msa_arrays_inner``.

    Parameters
    ----------
    tokens
        Token records; each must expose ``asym_id`` and ``res_idx``.
    pairing : list[dict[int, int]]
        One dict per MSA row, mapping chain_id -> sequence index
        (-1 means no sequence for that chain in this row).
    is_paired : list[dict[int, int]]
        Same layout as ``pairing``; 1 when the entry is paired, else 0.
    deletions : dict[tuple[int, int, int], int]
        Maps (chain_id, seq_idx, res_idx) -> deletion count.
    msa : dict[int, MSA]
        Per-chain MSA objects (``sequences`` / ``residues`` tables).

    Returns
    -------
    tuple of int64 arrays
        (msa_data, del_data, paired_data), each (n_tokens, n_rows).
    """
    token_asym_ids_arr = np.array([t["asym_id"] for t in tokens], dtype=np.int64)
    token_res_idxs_arr = np.array([t["res_idx"] for t in tokens], dtype=np.int64)

    chain_ids = sorted(msa.keys())

    # chain_ids are not necessarily contiguous (e.g. they might be 0, 24, 25).
    # This allows us to look up a chain_id by it's index in the chain_ids list.
    chain_id_to_idx = {chain_id: i for i, chain_id in enumerate(chain_ids)}
    token_asym_ids_idx_arr = np.array(
        [chain_id_to_idx[asym_id] for asym_id in token_asym_ids_arr], dtype=np.int64
    )

    # Densify the per-row dicts into (n_rows, n_chains) arrays indexed by
    # the compact chain index rather than the raw chain_id.
    pairing_arr = np.zeros((len(pairing), len(chain_ids)), dtype=np.int64)
    is_paired_arr = np.zeros((len(is_paired), len(chain_ids)), dtype=np.int64)

    for i, row_pairing in enumerate(pairing):
        for chain_id in chain_ids:
            pairing_arr[i, chain_id_to_idx[chain_id]] = row_pairing[chain_id]

    for i, row_is_paired in enumerate(is_paired):
        for chain_id in chain_ids:
            is_paired_arr[i, chain_id_to_idx[chain_id]] = row_is_paired[chain_id]

    max_seq_len = max(len(msa[chain_id].sequences) for chain_id in chain_ids)

    # we want res_start from sequences; -1 pads ragged rows
    msa_sequences = np.full((len(chain_ids), max_seq_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        for i, seq in enumerate(msa[chain_id].sequences):
            msa_sequences[chain_id_to_idx[chain_id], i] = seq["res_start"]

    # Flat residue-type table per chain, also -1 padded to the widest chain.
    max_residues_len = max(len(msa[chain_id].residues) for chain_id in chain_ids)
    msa_residues = np.full((len(chain_ids), max_residues_len), -1, dtype=np.int64)
    for chain_id in chain_ids:
        residues = msa[chain_id].residues.astype(np.int64)
        idxs = np.arange(len(residues))
        chain_idx = chain_id_to_idx[chain_id]
        msa_residues[chain_idx, idxs] = residues

    return _prepare_msa_arrays_inner(
        token_asym_ids_arr,
        token_res_idxs_arr,
        token_asym_ids_idx_arr,
        pairing_arr,
        is_paired_arr,
        deletions,
        msa_sequences,
        msa_residues,
        const.token_ids["-"],
    )
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
# numba type of the `deletions` argument: (chain_id, seq_idx, res_idx) -> count.
deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)


@numba.njit(
    [
        types.Tuple(
            (
                types.int64[:, ::1],  # msa_data
                types.int64[:, ::1],  # del_data
                types.int64[:, ::1],  # paired_data
            )
        )(
            types.int64[::1],  # token_asym_ids
            types.int64[::1],  # token_res_idxs
            types.int64[::1],  # token_asym_ids_idx
            types.int64[:, ::1],  # pairing
            types.int64[:, ::1],  # is_paired
            deletions_dict_type,  # deletions
            types.int64[:, ::1],  # msa_sequences
            types.int64[:, ::1],  # msa_residues
            types.int64,  # gap_token
        )
    ],
    cache=True,
)
def _prepare_msa_arrays_inner(
    token_asym_ids: npt.NDArray[np.int64],
    token_res_idxs: npt.NDArray[np.int64],
    token_asym_ids_idx: npt.NDArray[np.int64],
    pairing: npt.NDArray[np.int64],
    is_paired: npt.NDArray[np.int64],
    deletions: dict[tuple[int, int, int], int],
    msa_sequences: npt.NDArray[np.int64],
    msa_residues: npt.NDArray[np.int64],
    gap_token: int,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Fill the (n_tokens, n_rows) MSA / deletion / paired arrays.

    Jitted hot loop behind :func:`prepare_msa_arrays`. Tokens whose row
    has no sequence for their chain (seq_idx == -1) keep the gap token
    and zero deletion.
    """
    n_tokens = len(token_asym_ids)
    n_pairs = len(pairing)
    # Default to gap token / zeros; only overwritten where a sequence exists.
    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)

    # Add all the token MSA data
    for token_idx in range(n_tokens):
        chain_id_idx = token_asym_ids_idx[token_idx]  # compact chain index
        chain_id = token_asym_ids[token_idx]  # raw asym_id (deletion key)
        res_idx = token_res_idxs[token_idx]

        for pair_idx in range(n_pairs):
            seq_idx = pairing[pair_idx, chain_id_idx]
            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]

            # Add residue type
            if seq_idx != -1:
                # res_start offsets into the chain's flat residue table.
                res_start = msa_sequences[chain_id_idx, seq_idx]
                res_type = msa_residues[chain_id_idx, res_start + res_idx]
                k = (chain_id, seq_idx, res_idx)
                if k in deletions:
                    del_data[token_idx, pair_idx] = deletions[k]
                msa_data[token_idx, pair_idx] = res_type

    return msa_data, del_data, paired_data
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
####################################################################################################
|
| 462 |
+
# FEATURES
|
| 463 |
+
####################################################################################################
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def select_subset_from_mask(mask, p):
    """Randomly keep a geometric-sized subset of the set entries of *mask*.

    Draws ``v ~ Geometric(p) + 1`` and keeps ``k = min(v, #set)`` of the
    set positions, chosen uniformly without replacement. Returns a new
    mask of the same shape/dtype with only the kept positions set.
    """
    set_positions = np.where(mask)[0]

    # Geometric draw (>= 1) shifted by one, capped at the available count.
    draw = np.random.geometric(p) + 1
    keep_count = min(draw, np.sum(mask))

    kept = np.random.choice(set_positions, size=keep_count, replace=False)

    subset_mask = np.zeros_like(mask)
    subset_mask[kept] = 1
    return subset_mask
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def process_token_features(
    data: Tokenized,
    max_tokens: Optional[int] = None,
    binder_pocket_conditioned_prop: Optional[float] = 0.0,
    binder_pocket_cutoff: Optional[float] = 6.0,
    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
    only_ligand_binder_pocket: Optional[bool] = False,
    inference_binder: Optional[list[int]] = None,
    inference_pocket: Optional[list[tuple[int, int]]] = None,
) -> dict[str, Tensor]:
    """Get the token features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_tokens : int
        The maximum number of tokens.
    binder_pocket_conditioned_prop : float, optional
        Probability of applying random binder/pocket conditioning
        during training.
    binder_pocket_cutoff : float, optional
        Distance cutoff for a token to count as pocket.
    binder_pocket_sampling_geometric_p : float, optional
        If > 0, geometric-subsample the pocket tokens.
    only_ligand_binder_pocket : bool, optional
        If True, never fall back to protein chains as binder.
    inference_binder : list[int], optional
        Binder asym_id for inference-time conditioning.
    inference_pocket : list[tuple[int, int]], optional
        (asym_id, res_idx) pairs defining the pocket at inference.

    Returns
    -------
    dict[str, Tensor]
        The token features.

    """
    # Token data
    token_data = data.tokens
    token_bonds = data.bonds

    # Token core features
    token_index = torch.arange(len(token_data), dtype=torch.long)
    residue_index = from_numpy(token_data["res_idx"].copy()).long()
    asym_id = from_numpy(token_data["asym_id"].copy()).long()
    entity_id = from_numpy(token_data["entity_id"].copy()).long()
    sym_id = from_numpy(token_data["sym_id"].copy()).long()
    mol_type = from_numpy(token_data["mol_type"].copy()).long()
    res_type = from_numpy(token_data["res_type"].copy()).long()
    res_type = one_hot(res_type, num_classes=const.num_tokens)
    disto_center = from_numpy(token_data["disto_coords"].copy())

    # Token mask features
    pad_mask = torch.ones(len(token_data), dtype=torch.float)
    resolved_mask = from_numpy(token_data["resolved_mask"].copy()).float()
    disto_mask = from_numpy(token_data["disto_mask"].copy()).float()
    cyclic_period = from_numpy(token_data["cyclic_period"].copy())

    # Token bond features: bond matrix is sized to the padded token count
    # so it does not need re-padding below.
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        num_tokens = max_tokens if pad_len > 0 else len(token_data)
    else:
        num_tokens = len(token_data)

    # Bonds reference global token_idx values; remap to crop-local indices.
    tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
    bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
    for token_bond in token_bonds:
        token_1 = tok_to_idx[token_bond["token_1"]]
        token_2 = tok_to_idx[token_bond["token_2"]]
        bonds[token_1, token_2] = 1
        bonds[token_2, token_1] = 1

    bonds = bonds.unsqueeze(-1)

    # Pocket conditioned feature
    pocket_feature = (
        np.zeros(len(token_data)) + const.pocket_contact_info["UNSPECIFIED"]
    )
    if inference_binder is not None:
        # Inference path: pocket is given explicitly by the caller.
        assert inference_pocket is not None
        pocket_residues = set(inference_pocket)
        for idx, token in enumerate(token_data):
            if token["asym_id"] == inference_binder:
                pocket_feature[idx] = const.pocket_contact_info["BINDER"]
            elif (token["asym_id"], token["res_idx"]) in pocket_residues:
                pocket_feature[idx] = const.pocket_contact_info["POCKET"]
            else:
                pocket_feature[idx] = const.pocket_contact_info["UNSELECTED"]
    elif (
        binder_pocket_conditioned_prop > 0.0
        and random.random() < binder_pocket_conditioned_prop
    ):
        # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
        binder_asym_ids = np.unique(
            token_data["asym_id"][
                token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]
        )

        if len(binder_asym_ids) == 0:
            if not only_ligand_binder_pocket:
                binder_asym_ids = np.unique(token_data["asym_id"])

        if len(binder_asym_ids) > 0:
            pocket_asym_id = random.choice(binder_asym_ids)
            binder_mask = token_data["asym_id"] == pocket_asym_id

            # Collect all atom coordinates of the chosen binder chain.
            binder_coords = []
            for token in token_data:
                if token["asym_id"] == pocket_asym_id:
                    binder_coords.append(
                        data.structure.atoms["coords"][
                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                        ]
                    )
            binder_coords = np.concatenate(binder_coords, axis=0)

            # find the tokens in the pocket
            token_dist = np.zeros(len(token_data)) + 1000
            for i, token in enumerate(token_data):
                if (
                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and token["asym_id"] != pocket_asym_id
                    and token["resolved_mask"] == 1
                ):
                    token_coords = data.structure.atoms["coords"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]

                    # find chain and apply chain transformation
                    # NOTE(review): this loop only locates `chain`; no
                    # transformation is actually applied afterwards —
                    # confirm whether the transform was dropped intentionally.
                    for chain in data.structure.chains:
                        if chain["asym_id"] == token["asym_id"]:
                            break

                    # Min pairwise atom distance between token and binder.
                    token_dist[i] = np.min(
                        np.linalg.norm(
                            token_coords[:, None, :] - binder_coords[None, :, :],
                            axis=-1,
                        )
                    )

            pocket_mask = token_dist < binder_pocket_cutoff

            if np.sum(pocket_mask) > 0:
                pocket_feature = (
                    np.zeros(len(token_data)) + const.pocket_contact_info["UNSELECTED"]
                )
                pocket_feature[binder_mask] = const.pocket_contact_info["BINDER"]

                if binder_pocket_sampling_geometric_p > 0.0:
                    # select a subset of the pocket, according
                    # to a geometric distribution with one as minimum
                    pocket_mask = select_subset_from_mask(
                        pocket_mask, binder_pocket_sampling_geometric_p
                    )

                pocket_feature[pocket_mask] = const.pocket_contact_info["POCKET"]
    pocket_feature = from_numpy(pocket_feature).long()
    pocket_feature = one_hot(pocket_feature, num_classes=len(const.pocket_contact_info))

    # Pad to max tokens if given
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        if pad_len > 0:
            token_index = pad_dim(token_index, 0, pad_len)
            residue_index = pad_dim(residue_index, 0, pad_len)
            asym_id = pad_dim(asym_id, 0, pad_len)
            entity_id = pad_dim(entity_id, 0, pad_len)
            sym_id = pad_dim(sym_id, 0, pad_len)
            mol_type = pad_dim(mol_type, 0, pad_len)
            res_type = pad_dim(res_type, 0, pad_len)
            disto_center = pad_dim(disto_center, 0, pad_len)
            pad_mask = pad_dim(pad_mask, 0, pad_len)
            resolved_mask = pad_dim(resolved_mask, 0, pad_len)
            disto_mask = pad_dim(disto_mask, 0, pad_len)
            pocket_feature = pad_dim(pocket_feature, 0, pad_len)
            cyclic_period = pad_dim(cyclic_period, 0, pad_len)

    token_features = {
        "token_index": token_index,
        "residue_index": residue_index,
        "asym_id": asym_id,
        "entity_id": entity_id,
        "sym_id": sym_id,
        "mol_type": mol_type,
        "res_type": res_type,
        "disto_center": disto_center,
        "token_bonds": bonds,
        "token_pad_mask": pad_mask,
        "token_resolved_mask": resolved_mask,
        "token_disto_mask": disto_mask,
        "pocket_feature": pocket_feature,
        "cyclic_period": cyclic_period,
    }
    return token_features
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def process_atom_features(
    data: Tokenized,
    atoms_per_window_queries: int = 32,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
    max_atoms: Optional[int] = None,
    max_tokens: Optional[int] = None,
) -> dict[str, Tensor]:
    """Get the atom features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    atoms_per_window_queries : int, optional
        Window size the atom count is padded to a multiple of.
    min_dist, max_dist : float, optional
        Distogram bin range.
    num_bins : int, optional
        Number of distogram bins.
    max_atoms : int, optional
        The maximum number of atoms.
    max_tokens : int, optional
        The maximum number of tokens.

    Returns
    -------
    dict[str, Tensor]
        The atom features.

    """
    # Filter to tokens' atoms
    atom_data = []
    ref_space_uid = []
    coord_data = []
    frame_data = []
    resolved_frame_data = []
    atom_to_token = []
    token_to_rep_atom = []  # index on cropped atom table
    r_set_to_rep_atom = []
    disto_coords = []
    atom_idx = 0  # running offset into the cropped atom table

    chain_res_ids = {}
    for token_id, token in enumerate(data.tokens):
        # Get the chain residue ids
        chain_idx, res_id = token["asym_id"], token["res_idx"]
        chain = data.structure.chains[chain_idx]

        # ref_space_uid groups atoms by (chain, residue) with a compact id.
        if (chain_idx, res_id) not in chain_res_ids:
            new_idx = len(chain_res_ids)
            chain_res_ids[(chain_idx, res_id)] = new_idx
        else:
            new_idx = chain_res_ids[(chain_idx, res_id)]

        # Map atoms to token indices
        ref_space_uid.extend([new_idx] * token["atom_num"])
        atom_to_token.extend([token_id] * token["atom_num"])

        # Add atom data
        start = token["atom_idx"]
        end = token["atom_idx"] + token["atom_num"]
        token_atoms = data.structure.atoms[start:end]

        # Map token to representative atom (crop-local index).
        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
            "resolved_mask"
        ]:
            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)

        # Get token coordinates
        token_coords = np.array([token_atoms["coords"]])
        coord_data.append(token_coords)

        # Get frame data: pick three frame-defining atoms per token.
        res_type = const.tokens[token["res_type"]]

        if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
            # Too few atoms / unknown residue: no valid frame.
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
            res_type in const.ref_atoms
        ):
            # Protein backbone frame: N, CA, C.
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("N"),
                const.ref_atoms[res_type].index("CA"),
                const.ref_atoms[res_type].index("C"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        elif (
            token["mol_type"] == const.chain_type_ids["DNA"]
            or token["mol_type"] == const.chain_type_ids["RNA"]
        ) and (res_type in const.ref_atoms):
            # Nucleic acid frame: C1', C3', C4'.
            idx_frame_a, idx_frame_b, idx_frame_c = (
                const.ref_atoms[res_type].index("C1'"),
                const.ref_atoms[res_type].index("C3'"),
                const.ref_atoms[res_type].index("C4'"),
            )
            mask_frame = (
                token_atoms["is_present"][idx_frame_a]
                and token_atoms["is_present"][idx_frame_b]
                and token_atoms["is_present"][idx_frame_c]
            )
        else:
            idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
            mask_frame = False
        frame_data.append(
            [idx_frame_a + atom_idx, idx_frame_b + atom_idx, idx_frame_c + atom_idx]
        )
        resolved_frame_data.append(mask_frame)

        # Get distogram coordinates
        disto_coords_tok = data.structure.atoms[token["disto_idx"]]["coords"]
        disto_coords.append(disto_coords_tok)

        # Update atom data. This is technically never used again (we rely on coord_data),
        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
        token_atoms = token_atoms.copy()
        token_atoms["coords"] = token_coords[0]  # atom has a copy of first coords
        atom_data.append(token_atoms)
        atom_idx += len(token_atoms)

    disto_coords = np.array(disto_coords)

    # Compute distogram: bin pairwise rep-atom distances into num_bins classes.
    t_center = torch.Tensor(disto_coords)
    t_dists = torch.cdist(t_center, t_center)
    boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
    distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
    disto_target = one_hot(distogram, num_classes=num_bins)

    atom_data = np.concatenate(atom_data)
    coord_data = np.concatenate(coord_data, axis=1)
    ref_space_uid = np.array(ref_space_uid)

    # Compute features
    ref_atom_name_chars = from_numpy(atom_data["name"]).long()
    ref_element = from_numpy(atom_data["element"]).long()
    ref_charge = from_numpy(atom_data["charge"])
    ref_pos = from_numpy(
        atom_data["conformer"].copy()
    )  # not sure why I need to copy here..
    ref_space_uid = from_numpy(ref_space_uid)
    coords = from_numpy(coord_data.copy())
    resolved_mask = from_numpy(atom_data["is_present"])
    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
    frame_data, resolved_frame_data = compute_frames_nonpolymer(
        data,
        coord_data,
        atom_data["is_present"],
        atom_to_token,
        frame_data,
        resolved_frame_data,
    )  # Compute frames for NONPOLYMER tokens
    frames = from_numpy(frame_data.copy())
    frame_resolved_mask = from_numpy(resolved_frame_data.copy())
    # Convert to one-hot
    ref_atom_name_chars = one_hot(
        ref_atom_name_chars % num_bins, num_classes=num_bins
    )  # added for lower case letters
    ref_element = one_hot(ref_element, num_classes=const.num_elements)
    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
    token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
    r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))

    # Center the ground truth coordinates (mean over resolved atoms only).
    center = (coords * resolved_mask[None, :, None]).sum(dim=1)
    center = center / resolved_mask.sum().clamp(min=1)
    coords = coords - center[:, None]

    # Apply random roto-translation to the input atoms
    ref_pos = center_random_augmentation(
        ref_pos[None], resolved_mask[None], centering=False
    )[0]

    # Compute padding and apply: pad atom count to a multiple of the
    # query window size (or to max_atoms when given).
    if max_atoms is not None:
        assert max_atoms % atoms_per_window_queries == 0
        pad_len = max_atoms - len(atom_data)
    else:
        pad_len = (
            (len(atom_data) - 1) // atoms_per_window_queries + 1
        ) * atoms_per_window_queries - len(atom_data)

    if pad_len > 0:
        pad_mask = pad_dim(pad_mask, 0, pad_len)
        ref_pos = pad_dim(ref_pos, 0, pad_len)
        resolved_mask = pad_dim(resolved_mask, 0, pad_len)
        ref_element = pad_dim(ref_element, 0, pad_len)
        ref_charge = pad_dim(ref_charge, 0, pad_len)
        ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
        ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
        coords = pad_dim(coords, 1, pad_len)
        atom_to_token = pad_dim(atom_to_token, 0, pad_len)
        token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
        r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)

    # Pad the token dimension of token-indexed features.
    if max_tokens is not None:
        pad_len = max_tokens - token_to_rep_atom.shape[0]
        if pad_len > 0:
            atom_to_token = pad_dim(atom_to_token, 1, pad_len)
            token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
            r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
            disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
            frames = pad_dim(frames, 0, pad_len)
            frame_resolved_mask = pad_dim(frame_resolved_mask, 0, pad_len)

    return {
        "ref_pos": ref_pos,
        "atom_resolved_mask": resolved_mask,
        "ref_element": ref_element,
        "ref_charge": ref_charge,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_space_uid": ref_space_uid,
        "coords": coords,
        "atom_pad_mask": pad_mask,
        "atom_to_token": atom_to_token,
        "token_to_rep_atom": token_to_rep_atom,
        "r_set_to_rep_atom": r_set_to_rep_atom,
        "disto_target": disto_target,
        "frames_idx": frames,
        "frame_resolved_mask": frame_resolved_mask,
    }
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def process_msa_features(
    data: Tokenized,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Parameters
    ----------
    data : Tokenized
        The tokenized data.
    max_seqs_batch : int
        Sequence cap passed to ``construct_paired_msa``.
    max_seqs : int
        The maximum number of MSA sequences.
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Created paired MSA
    msa, deletion, paired = construct_paired_msa(data, max_seqs_batch)
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa[:, :, 0])
    # Per-residue residue-type frequency over the MSA.
    profile = msa.float().mean(dim=0)
    has_deletion = deletion > 0
    # Squash raw deletion counts into a bounded range.
    # NOTE(review): AlphaFold-style normalization is (2/pi)*arctan(d/3),
    # which maps to [0, 1); here the factor is pi/2 — confirm intended.
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            # Pad with the gap token so padded rows stay valid one-hots.
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)

    return {
        "msa": msa,
        "msa_paired": paired,
        "deletion_value": deletion,
        "has_deletion": has_deletion,
        "deletion_mean": deletion_mean,
        "profile": profile,
        "msa_mask": msa_mask,
    }
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Get the symmetry features.

    Aggregates chain-, amino-acid-, and ligand-level symmetry features
    for the cropped input into a single dictionary.

    Parameters
    ----------
    cropped : Tokenized
        The tokenized (cropped) data.
    symmetries : dict
        Precomputed ligand symmetry information.

    Returns
    -------
    dict[str, Tensor]
        The symmetry features.

    """
    # Start from chain symmetries and fold in the other feature groups.
    features = get_chain_symmetries(cropped)
    for extra in (
        get_amino_acids_symmetries(cropped),
        get_ligand_symmetries(cropped, symmetries),
    ):
        features.update(extra)
    return features
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
def process_residue_constraint_features(
    data: Tokenized,
) -> dict[str, Tensor]:
    """Convert residue-level geometry constraints into tensor features.

    Reads RDKit bounds, chiral-atom, stereo-bond, and planarity
    constraints from ``data.residue_constraints`` and packs each into
    index tensors (shape (k, N), atom indices per constraint) plus their
    mask/value tensors. When no constraints are present, returns empty
    tensors with the same shapes so downstream code needs no special case.
    """
    residue_constraints = data.residue_constraints
    if residue_constraints is not None:
        rdkit_bounds_constraints = residue_constraints.rdkit_bounds_constraints
        chiral_atom_constraints = residue_constraints.chiral_atom_constraints
        stereo_bond_constraints = residue_constraints.stereo_bond_constraints
        planar_bond_constraints = residue_constraints.planar_bond_constraints
        planar_ring_5_constraints = residue_constraints.planar_ring_5_constraints
        planar_ring_6_constraints = residue_constraints.planar_ring_6_constraints

        # Distance bounds: atom pairs with lower/upper limits, flagged as
        # arising from a bond or an angle.
        rdkit_bounds_index = torch.tensor(
            rdkit_bounds_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        rdkit_bounds_bond_mask = torch.tensor(
            rdkit_bounds_constraints["is_bond"].copy(), dtype=torch.bool
        )
        rdkit_bounds_angle_mask = torch.tensor(
            rdkit_bounds_constraints["is_angle"].copy(), dtype=torch.bool
        )
        rdkit_upper_bounds = torch.tensor(
            rdkit_bounds_constraints["upper_bound"].copy(), dtype=torch.float
        )
        rdkit_lower_bounds = torch.tensor(
            rdkit_bounds_constraints["lower_bound"].copy(), dtype=torch.float
        )

        # Chirality: 4 atoms per center, with R/S orientation.
        chiral_atom_index = torch.tensor(
            chiral_atom_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        chiral_reference_mask = torch.tensor(
            chiral_atom_constraints["is_reference"].copy(), dtype=torch.bool
        )
        chiral_atom_orientations = torch.tensor(
            chiral_atom_constraints["is_r"].copy(), dtype=torch.bool
        )

        # Stereo bonds: 4 atoms per bond, with E/Z orientation.
        stereo_bond_index = torch.tensor(
            stereo_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        stereo_reference_mask = torch.tensor(
            stereo_bond_constraints["is_reference"].copy(), dtype=torch.bool
        )
        stereo_bond_orientations = torch.tensor(
            stereo_bond_constraints["is_e"].copy(), dtype=torch.bool
        )

        # Planarity: atom groups required to be coplanar.
        planar_bond_index = torch.tensor(
            planar_bond_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_5_index = torch.tensor(
            planar_ring_5_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
        planar_ring_6_index = torch.tensor(
            planar_ring_6_constraints["atom_idxs"].copy(), dtype=torch.long
        ).T
    else:
        # No constraints: emit empty tensors with matching first dims.
        rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
        rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
        rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
        chiral_atom_index = torch.empty(
            (
                4,
                0,
            ),
            dtype=torch.long,
        )
        chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
        chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
        stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
        stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
        stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
        planar_bond_index = torch.empty((6, 0), dtype=torch.long)
        planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
        planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)

    return {
        "rdkit_bounds_index": rdkit_bounds_index,
        "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
        "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
        "rdkit_upper_bounds": rdkit_upper_bounds,
        "rdkit_lower_bounds": rdkit_lower_bounds,
        "chiral_atom_index": chiral_atom_index,
        "chiral_reference_mask": chiral_reference_mask,
        "chiral_atom_orientations": chiral_atom_orientations,
        "stereo_bond_index": stereo_bond_index,
        "stereo_reference_mask": stereo_reference_mask,
        "stereo_bond_orientations": stereo_bond_orientations,
        "planar_bond_index": planar_bond_index,
        "planar_ring_5_index": planar_ring_5_index,
        "planar_ring_6_index": planar_ring_6_index,
    }
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
def process_chain_feature_constraints(
    data: Tokenized,
) -> dict[str, Tensor]:
    """Build chain-level constraint features from the tokenized structure.

    Parameters
    ----------
    data : Tokenized
        The tokenized data; only ``data.structure`` is read.

    Returns
    -------
    dict[str, Tensor]
        ``connected_chain_index`` and ``connected_atom_index`` as (2, n_conn)
        long tensors of covalent-connection endpoints, and
        ``symmetric_chain_index`` as a (2, n_pairs) long tensor of chain pairs
        sharing an entity id. All are empty (2, 0) tensors when absent.

    """
    structure = data.structure

    # Covalent connections between chains, transposed to (2, n_conn) layout.
    if structure.connections.shape[0] > 0:
        chain_pairs = [
            [conn["chain_1"], conn["chain_2"]] for conn in structure.connections
        ]
        atom_pairs = [
            [conn["atom_1"], conn["atom_2"]] for conn in structure.connections
        ]
        connected_chain_index = torch.tensor(chain_pairs, dtype=torch.long).T
        connected_atom_index = torch.tensor(atom_pairs, dtype=torch.long).T
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    # Chains sharing an entity id are symmetric copies; record each unordered
    # pair exactly once (i < j).
    sym_pairs = [
        [i, j]
        for i, chain_i in enumerate(structure.chains)
        for j, chain_j in enumerate(structure.chains)
        if j > i and chain_i["entity_id"] == chain_j["entity_id"]
    ]
    if sym_pairs:
        symmetric_chain_index = torch.tensor(sym_pairs, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)

    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
class BoltzFeaturizer:
    """Boltz featurizer."""

    def process(
        self,
        data: Tokenized,
        training: bool,
        max_seqs: int = 4096,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        symmetries: Optional[dict] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        inference_binder: Optional[int] = None,
        inference_pocket: Optional[list[tuple[int, int]]] = None,
        compute_constraint_features: bool = False,
    ) -> dict[str, Tensor]:
        """Compute all model features for a tokenized input.

        Aggregates token, atom, MSA and (optionally) symmetry and constraint
        features into a single flat dictionary.

        Parameters
        ----------
        data : Tokenized
            The tokenized data.
        training : bool
            Whether the model is in training mode.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms.
        max_seqs : int, optional
            The maximum number of sequences.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # During training, subsample the effective MSA depth uniformly at
        # random; at inference use the full budget.
        max_seqs_batch = (
            np.random.randint(1, max_seqs + 1)  # noqa: NPY002
            if training and max_seqs is not None
            else max_seqs
        )

        features: dict[str, Tensor] = {}

        # Token features.
        features.update(
            process_token_features(
                data,
                max_tokens,
                binder_pocket_conditioned_prop,
                binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p,
                only_ligand_binder_pocket,
                inference_binder=inference_binder,
                inference_pocket=inference_pocket,
            )
        )

        # Atom features.
        features.update(
            process_atom_features(
                data,
                atoms_per_window_queries,
                min_dist,
                max_dist,
                num_bins,
                max_atoms,
                max_tokens,
            )
        )

        # MSA features.
        features.update(
            process_msa_features(
                data,
                max_seqs_batch,
                max_seqs,
                max_tokens,
                pad_to_max_seqs,
            )
        )

        # Optional symmetry features.
        if compute_symmetries:
            features.update(process_symmetry_features(data, symmetries))

        # Optional residue- and chain-level constraint features.
        if compute_constraint_features:
            features.update(process_residue_constraint_features(data))
            features.update(process_chain_feature_constraints(data))

        return features
|
protify/FastPLMs/boltz/src/boltz/data/feature/featurizerv2.py
ADDED
|
@@ -0,0 +1,2354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from collections import deque
|
| 4 |
+
import numba
|
| 5 |
+
import numpy as np
|
| 6 |
+
import numpy.typing as npt
|
| 7 |
+
import rdkit.Chem.Descriptors
|
| 8 |
+
import torch
|
| 9 |
+
from numba import types
|
| 10 |
+
from rdkit.Chem import Mol
|
| 11 |
+
from scipy.spatial.distance import cdist
|
| 12 |
+
from torch import Tensor, from_numpy
|
| 13 |
+
from torch.nn.functional import one_hot
|
| 14 |
+
|
| 15 |
+
from boltz.data import const
|
| 16 |
+
from boltz.data.mol import (
|
| 17 |
+
get_amino_acids_symmetries,
|
| 18 |
+
get_chain_symmetries,
|
| 19 |
+
get_ligand_symmetries,
|
| 20 |
+
get_symmetries,
|
| 21 |
+
)
|
| 22 |
+
from boltz.data.pad import pad_dim
|
| 23 |
+
from boltz.data.types import (
|
| 24 |
+
MSA,
|
| 25 |
+
MSADeletion,
|
| 26 |
+
MSAResidue,
|
| 27 |
+
MSASequence,
|
| 28 |
+
TemplateInfo,
|
| 29 |
+
Tokenized,
|
| 30 |
+
)
|
| 31 |
+
from boltz.model.modules.utils import center_random_augmentation
|
| 32 |
+
|
| 33 |
+
####################################################################################################
|
| 34 |
+
# HELPERS
|
| 35 |
+
####################################################################################################
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Convert an atom name to a standard format.

    Each character is mapped to ``ord(c) - 32`` and the result is
    zero-padded on the right to four entries.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    tuple[int, int, int, int]
        The converted atom name.

    """
    trimmed = str(name).strip()
    codes = [ord(ch) - 32 for ch in trimmed]
    codes.extend([0] * (4 - len(codes)))
    return tuple(codes)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def sample_d(
    min_d: float,
    max_d: float,
    n_samples: int,
    random: np.random.Generator,
) -> np.ndarray:
    """Draw samples from a 1/d (log-uniform) distribution on [min_d, max_d].

    Parameters
    ----------
    min_d : float
        Minimum value of d.
    max_d : float
        Maximum value of d.
    n_samples : int
        Number of samples to generate.
    random : numpy.random.Generator
        Random number generator.

    Returns
    -------
    numpy.ndarray
        Array of samples drawn from the distribution.

    Notes
    -----
    The density is f(d) = 1 / (d * ln(max_d / min_d)) on [min_d, max_d].
    Sampling uses the inverse-CDF transform
    d = min_d * (max_d / min_d) ** u with u ~ Uniform(0, 1).

    """
    # Uniform draws in [0, 1), pushed through the inverse CDF.
    uniform = random.random(n_samples)
    ratio = max_d / min_d
    return min_d * ratio**uniform
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def compute_frames_nonpolymer(
    data: Tokenized,
    coords,
    resolved_mask,
    atom_to_token,
    frame_data: list,
    resolved_frame_data: list,
) -> tuple[list, list]:
    """Get the frames for non-polymer tokens.

    For each non-polymer chain with at least three atoms, overwrites the frame
    of each of its tokens with an atom triplet (nearest neighbor, the atom
    itself, second-nearest neighbor), preferring resolved atoms. Frames whose
    three atoms are near-collinear or overlapping are marked unresolved.

    Parameters
    ----------
    data : Tokenized
        The input data to the model.
    coords
        Atom coordinates; reshaped to (-1, 3) internally.
    resolved_mask
        Per-atom mask of resolved atoms.
    atom_to_token
        Index mapping each atom to its token.
    frame_data : list
        The frame data (one atom-index triplet per token).
    resolved_frame_data : list
        The resolved frame data (one flag per token).

    Returns
    -------
    tuple[list, list]
        The frame data and resolved frame data (as numpy arrays).

    """
    frame_data = np.array(frame_data)
    resolved_frame_data = np.array(resolved_frame_data)
    # asym_id per token, and per atom (via the atom -> token mapping).
    asym_id_token = data.tokens["asym_id"]
    asym_id_atom = data.tokens["asym_id"][atom_to_token]
    token_idx = 0
    atom_idx = 0
    # Walk chains in asym_id order, tracking running token/atom offsets.
    for id in np.unique(data.tokens["asym_id"]):
        mask_chain_token = asym_id_token == id
        mask_chain_atom = asym_id_atom == id
        num_tokens = mask_chain_token.sum()
        num_atoms = mask_chain_atom.sum()
        # Skip polymer chains and ligands too small to define a 3-atom frame.
        if (
            data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
            or num_atoms < 3  # noqa: PLR2004
        ):
            token_idx += num_tokens
            atom_idx += num_atoms
            continue
        # Pairwise Euclidean distances between this chain's atoms.
        dist_mat = (
            (
                coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
                - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
            )
            ** 2
        ).sum(-1) ** 0.5
        # Pairs involving an unresolved atom get +inf so argsort pushes them
        # to the end of each row.
        resolved_pair = 1 - (
            resolved_mask[mask_chain_atom][None, :]
            * resolved_mask[mask_chain_atom][:, None]
        ).astype(np.float32)
        resolved_pair[resolved_pair == 1] = math.inf
        indices = np.argsort(dist_mat + resolved_pair, axis=1)
        # Frame = (nearest neighbor, self, second-nearest neighbor); column 0
        # of the argsort is the atom itself (zero distance) when resolved.
        # atom_idx shifts chain-local indices into global atom indices.
        frames = (
            np.concatenate(
                [
                    indices[:, 1:2],
                    indices[:, 0:1],
                    indices[:, 2:3],
                ],
                axis=1,
            )
            + atom_idx
        )
        # NOTE(review): these slices start at token_idx but span num_atoms
        # entries — this assumes every non-polymer atom is its own token
        # (num_tokens == num_atoms for such chains); confirm the tokenizer
        # guarantees this.
        frame_data[token_idx : token_idx + num_atoms, :] = frames
        resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
            frames
        ].all(axis=1)
        token_idx += num_tokens
        atom_idx += num_atoms
    frames_expanded = coords.reshape(-1, 3)[frame_data]

    # Invalidate frames whose two spanning vectors are near-collinear or
    # degenerate (zero-length).
    mask_collinear = compute_collinear_mask(
        frames_expanded[:, 1] - frames_expanded[:, 0],
        frames_expanded[:, 1] - frames_expanded[:, 2],
    )
    return frame_data, resolved_frame_data & mask_collinear
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def compute_collinear_mask(v1, v2):
    """Return a boolean mask of frames whose spanning vectors are valid.

    A frame is valid when the angle between its two spanning vectors is far
    enough from 0/180 degrees (|cos| < 0.9063) and both vectors are longer
    than 1e-2 (i.e. the three frame atoms do not overlap).
    """
    len1 = np.linalg.norm(v1, axis=1, keepdims=True)
    len2 = np.linalg.norm(v2, axis=1, keepdims=True)
    # Normalize with a small epsilon to avoid division by zero.
    unit1 = v1 / (len1 + 1e-6)
    unit2 = v2 / (len2 + 1e-6)
    angle_ok = np.abs(np.sum(unit1 * unit2, axis=1)) < 0.9063
    nonzero1 = len1.reshape(-1) > 1e-2
    nonzero2 = len2.reshape(-1) > 1e-2
    return angle_ok & nonzero1 & nonzero2
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def dummy_msa(residues: np.ndarray) -> MSA:
    """Create a dummy MSA for a chain.

    Builds a single-sequence MSA containing only the chain's own residue
    types, with no deletions, for chains that lack a real alignment.

    Parameters
    ----------
    residues : np.ndarray
        The residues for the chain.

    Returns
    -------
    MSA
        The dummy MSA.

    """
    res_types = [entry["res_type"] for entry in residues]
    # Single sequence record covering the whole chain; taxonomy -1 = unknown.
    seq_records = [(0, -1, 0, len(res_types), 0, 0)]
    return MSA(
        residues=np.array(res_types, dtype=MSAResidue),
        deletions=np.array([], dtype=MSADeletion),
        sequences=np.array(seq_records, dtype=MSASequence),
    )
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def construct_paired_msa( # noqa: C901, PLR0915, PLR0912
|
| 215 |
+
data: Tokenized,
|
| 216 |
+
random: np.random.Generator,
|
| 217 |
+
max_seqs: int,
|
| 218 |
+
max_pairs: int = 8192,
|
| 219 |
+
max_total: int = 16384,
|
| 220 |
+
random_subset: bool = False,
|
| 221 |
+
) -> tuple[Tensor, Tensor, Tensor]:
|
| 222 |
+
"""Pair the MSA data.
|
| 223 |
+
|
| 224 |
+
Parameters
|
| 225 |
+
----------
|
| 226 |
+
data : Tokenized
|
| 227 |
+
The input data to the model.
|
| 228 |
+
|
| 229 |
+
Returns
|
| 230 |
+
-------
|
| 231 |
+
Tensor
|
| 232 |
+
The MSA data.
|
| 233 |
+
Tensor
|
| 234 |
+
The deletion data.
|
| 235 |
+
Tensor
|
| 236 |
+
Mask indicating paired sequences.
|
| 237 |
+
|
| 238 |
+
"""
|
| 239 |
+
# Get unique chains (ensuring monotonicity in the order)
|
| 240 |
+
assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
|
| 241 |
+
chain_ids = np.unique(data.tokens["asym_id"])
|
| 242 |
+
|
| 243 |
+
# Get relevant MSA, and create a dummy for chains without
|
| 244 |
+
msa: dict[int, MSA] = {}
|
| 245 |
+
for chain_id in chain_ids:
|
| 246 |
+
# Get input sequence
|
| 247 |
+
chain = data.structure.chains[chain_id]
|
| 248 |
+
res_start = chain["res_idx"]
|
| 249 |
+
res_end = res_start + chain["res_num"]
|
| 250 |
+
residues = data.structure.residues[res_start:res_end]
|
| 251 |
+
|
| 252 |
+
# Check if we have an MSA, and that the
|
| 253 |
+
# first sequence matches the input sequence
|
| 254 |
+
if chain_id in data.msa:
|
| 255 |
+
# Set the MSA
|
| 256 |
+
msa[chain_id] = data.msa[chain_id]
|
| 257 |
+
|
| 258 |
+
# Run length and residue type checks
|
| 259 |
+
first = data.msa[chain_id].sequences[0]
|
| 260 |
+
first_start = first["res_start"]
|
| 261 |
+
first_end = first["res_end"]
|
| 262 |
+
msa_residues = data.msa[chain_id].residues
|
| 263 |
+
first_residues = msa_residues[first_start:first_end]
|
| 264 |
+
|
| 265 |
+
warning = "Warning: MSA does not match input sequence, creating dummy."
|
| 266 |
+
if len(residues) == len(first_residues):
|
| 267 |
+
# If there is a mismatch, check if it is between MET & UNK
|
| 268 |
+
# If so, replace the first sequence with the input sequence.
|
| 269 |
+
# Otherwise, replace with a dummy MSA for this chain.
|
| 270 |
+
mismatches = residues["res_type"] != first_residues["res_type"]
|
| 271 |
+
if mismatches.sum().item():
|
| 272 |
+
idx = np.where(mismatches)[0]
|
| 273 |
+
is_met = residues["res_type"][idx] == const.token_ids["MET"]
|
| 274 |
+
is_unk = residues["res_type"][idx] == const.token_ids["UNK"]
|
| 275 |
+
is_msa_unk = (
|
| 276 |
+
first_residues["res_type"][idx] == const.token_ids["UNK"]
|
| 277 |
+
)
|
| 278 |
+
if (np.all(is_met) and np.all(is_msa_unk)) or np.all(is_unk):
|
| 279 |
+
msa_residues[first_start:first_end]["res_type"] = residues[
|
| 280 |
+
"res_type"
|
| 281 |
+
]
|
| 282 |
+
else:
|
| 283 |
+
print(
|
| 284 |
+
warning,
|
| 285 |
+
"1",
|
| 286 |
+
residues["res_type"],
|
| 287 |
+
first_residues["res_type"],
|
| 288 |
+
data.record.id,
|
| 289 |
+
)
|
| 290 |
+
msa[chain_id] = dummy_msa(residues)
|
| 291 |
+
else:
|
| 292 |
+
print(
|
| 293 |
+
warning,
|
| 294 |
+
"2",
|
| 295 |
+
residues["res_type"],
|
| 296 |
+
first_residues["res_type"],
|
| 297 |
+
data.record.id,
|
| 298 |
+
)
|
| 299 |
+
msa[chain_id] = dummy_msa(residues)
|
| 300 |
+
else:
|
| 301 |
+
msa[chain_id] = dummy_msa(residues)
|
| 302 |
+
|
| 303 |
+
# Map taxonomies to (chain_id, seq_idx)
|
| 304 |
+
taxonomy_map: dict[str, list] = {}
|
| 305 |
+
for chain_id, chain_msa in msa.items():
|
| 306 |
+
sequences = chain_msa.sequences
|
| 307 |
+
sequences = sequences[sequences["taxonomy"] != -1]
|
| 308 |
+
for sequence in sequences:
|
| 309 |
+
seq_idx = sequence["seq_idx"]
|
| 310 |
+
taxon = sequence["taxonomy"]
|
| 311 |
+
taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))
|
| 312 |
+
|
| 313 |
+
# Remove taxonomies with only one sequence and sort by the
|
| 314 |
+
# number of chain_id present in each of the taxonomies
|
| 315 |
+
taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
|
| 316 |
+
taxonomy_map = sorted(
|
| 317 |
+
taxonomy_map.items(),
|
| 318 |
+
key=lambda x: len({c for c, _ in x[1]}),
|
| 319 |
+
reverse=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# Keep track of the sequences available per chain, keeping the original
|
| 323 |
+
# order of the sequences in the MSA to favor the best matching sequences
|
| 324 |
+
visited = {(c, s) for c, items in taxonomy_map for s in items}
|
| 325 |
+
available = {}
|
| 326 |
+
for c in chain_ids:
|
| 327 |
+
available[c] = deque(
|
| 328 |
+
i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# Create sequence pairs
|
| 332 |
+
is_paired = []
|
| 333 |
+
pairing = []
|
| 334 |
+
|
| 335 |
+
# Start with the first sequence for each chain
|
| 336 |
+
is_paired.append({c: 1 for c in chain_ids})
|
| 337 |
+
pairing.append({c: 0 for c in chain_ids})
|
| 338 |
+
|
| 339 |
+
# Then add up to 8191 paired rows
|
| 340 |
+
for _, pairs in taxonomy_map:
|
| 341 |
+
# Group occurences by chain_id in case we have multiple
|
| 342 |
+
# sequences from the same chain and same taxonomy
|
| 343 |
+
chain_occurences = {}
|
| 344 |
+
for chain_id, seq_idx in pairs:
|
| 345 |
+
chain_occurences.setdefault(chain_id, []).append(seq_idx)
|
| 346 |
+
|
| 347 |
+
# We create as many pairings as the maximum number of occurences
|
| 348 |
+
max_occurences = max(len(v) for v in chain_occurences.values())
|
| 349 |
+
for i in range(max_occurences):
|
| 350 |
+
row_pairing = {}
|
| 351 |
+
row_is_paired = {}
|
| 352 |
+
|
| 353 |
+
# Add the chains present in the taxonomy
|
| 354 |
+
for chain_id, seq_idxs in chain_occurences.items():
|
| 355 |
+
# Roll over the sequence index to maximize diversity
|
| 356 |
+
idx = i % len(seq_idxs)
|
| 357 |
+
seq_idx = seq_idxs[idx]
|
| 358 |
+
|
| 359 |
+
# Add the sequence to the pairing
|
| 360 |
+
row_pairing[chain_id] = seq_idx
|
| 361 |
+
row_is_paired[chain_id] = 1
|
| 362 |
+
|
| 363 |
+
# Add any missing chains
|
| 364 |
+
for chain_id in chain_ids:
|
| 365 |
+
if chain_id not in row_pairing:
|
| 366 |
+
row_is_paired[chain_id] = 0
|
| 367 |
+
if available[chain_id]:
|
| 368 |
+
# Add the next available sequence
|
| 369 |
+
row_pairing[chain_id] = available[chain_id].popleft()
|
| 370 |
+
else:
|
| 371 |
+
# No more sequences available, we place a gap
|
| 372 |
+
row_pairing[chain_id] = -1
|
| 373 |
+
|
| 374 |
+
pairing.append(row_pairing)
|
| 375 |
+
is_paired.append(row_is_paired)
|
| 376 |
+
|
| 377 |
+
# Break if we have enough pairs
|
| 378 |
+
if len(pairing) >= max_pairs:
|
| 379 |
+
break
|
| 380 |
+
|
| 381 |
+
# Break if we have enough pairs
|
| 382 |
+
if len(pairing) >= max_pairs:
|
| 383 |
+
break
|
| 384 |
+
|
| 385 |
+
# Now add up to 16384 unpaired rows total
|
| 386 |
+
max_left = max(len(v) for v in available.values())
|
| 387 |
+
for _ in range(min(max_total - len(pairing), max_left)):
|
| 388 |
+
row_pairing = {}
|
| 389 |
+
row_is_paired = {}
|
| 390 |
+
for chain_id in chain_ids:
|
| 391 |
+
row_is_paired[chain_id] = 0
|
| 392 |
+
if available[chain_id]:
|
| 393 |
+
# Add the next available sequence
|
| 394 |
+
row_pairing[chain_id] = available[chain_id].popleft()
|
| 395 |
+
else:
|
| 396 |
+
# No more sequences available, we place a gap
|
| 397 |
+
row_pairing[chain_id] = -1
|
| 398 |
+
|
| 399 |
+
pairing.append(row_pairing)
|
| 400 |
+
is_paired.append(row_is_paired)
|
| 401 |
+
|
| 402 |
+
# Break if we have enough sequences
|
| 403 |
+
if len(pairing) >= max_total:
|
| 404 |
+
break
|
| 405 |
+
|
| 406 |
+
# Randomly sample a subset of the pairs
|
| 407 |
+
# ensuring the first row is always present
|
| 408 |
+
if random_subset:
|
| 409 |
+
num_seqs = len(pairing)
|
| 410 |
+
if num_seqs > max_seqs:
|
| 411 |
+
indices = random.choice(
|
| 412 |
+
np.arange(1, num_seqs), size=max_seqs - 1, replace=False
|
| 413 |
+
) # noqa: NPY002
|
| 414 |
+
pairing = [pairing[0]] + [pairing[i] for i in indices]
|
| 415 |
+
is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
|
| 416 |
+
else:
|
| 417 |
+
# Deterministic downsample to max_seqs
|
| 418 |
+
pairing = pairing[:max_seqs]
|
| 419 |
+
is_paired = is_paired[:max_seqs]
|
| 420 |
+
|
| 421 |
+
# Map (chain_id, seq_idx, res_idx) to deletion
|
| 422 |
+
deletions = numba.typed.Dict.empty(
|
| 423 |
+
key_type=numba.types.Tuple(
|
| 424 |
+
[numba.types.int64, numba.types.int64, numba.types.int64]),
|
| 425 |
+
value_type=numba.types.int64
|
| 426 |
+
)
|
| 427 |
+
for chain_id, chain_msa in msa.items():
|
| 428 |
+
chain_deletions = chain_msa.deletions
|
| 429 |
+
for sequence in chain_msa.sequences:
|
| 430 |
+
seq_idx = sequence["seq_idx"]
|
| 431 |
+
del_start = sequence["del_start"]
|
| 432 |
+
del_end = sequence["del_end"]
|
| 433 |
+
chain_deletions = chain_deletions[del_start:del_end]
|
| 434 |
+
for deletion_data in chain_deletions:
|
| 435 |
+
res_idx = deletion_data["res_idx"]
|
| 436 |
+
deletion_values = deletion_data["deletion"]
|
| 437 |
+
deletions[(chain_id, seq_idx, res_idx)] = deletion_values
|
| 438 |
+
|
| 439 |
+
# Add all the token MSA data
|
| 440 |
+
msa_data, del_data, paired_data = prepare_msa_arrays(
|
| 441 |
+
data.tokens, pairing, is_paired, deletions, msa
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
msa_data = torch.tensor(msa_data, dtype=torch.long)
|
| 445 |
+
del_data = torch.tensor(del_data, dtype=torch.float)
|
| 446 |
+
paired_data = torch.tensor(paired_data, dtype=torch.float)
|
| 447 |
+
|
| 448 |
+
return msa_data, del_data, paired_data
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def prepare_msa_arrays(
    tokens,
    pairing: list[dict[int, int]],
    is_paired: list[dict[int, int]],
    deletions: dict[tuple[int, int, int], int],
    msa: dict[int, MSA],
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Flatten the per-chain pairing data into dense int64 arrays.

    The jitted kernel ``_prepare_msa_arrays_inner`` cannot consume Python
    dicts of objects, so everything is converted to contiguous numpy arrays
    first and the kernel does the per-token gathering.
    """
    token_asym_ids_arr = np.array([tok["asym_id"] for tok in tokens], dtype=np.int64)
    token_res_idxs_arr = np.array([tok["res_idx"] for tok in tokens], dtype=np.int64)

    chain_ids = sorted(msa.keys())
    n_chains = len(chain_ids)

    # chain_ids are not necessarily contiguous (e.g. 0, 24, 25); build a dense
    # index so each chain maps to a column of the arrays below.
    chain_id_to_idx = {cid: pos for pos, cid in enumerate(chain_ids)}
    token_asym_ids_idx_arr = np.array(
        [chain_id_to_idx[aid] for aid in token_asym_ids_arr], dtype=np.int64
    )

    pairing_arr = np.zeros((len(pairing), n_chains), dtype=np.int64)
    is_paired_arr = np.zeros((len(is_paired), n_chains), dtype=np.int64)

    for row, row_pairing in enumerate(pairing):
        for cid, col in chain_id_to_idx.items():
            pairing_arr[row, col] = row_pairing[cid]

    for row, row_is_paired in enumerate(is_paired):
        for cid, col in chain_id_to_idx.items():
            is_paired_arr[row, col] = row_is_paired[cid]

    # res_start of every sequence, padded with -1 for ragged chains.
    max_seq_len = max(len(msa[cid].sequences) for cid in chain_ids)
    msa_sequences = np.full((n_chains, max_seq_len), -1, dtype=np.int64)
    for cid in chain_ids:
        col = chain_id_to_idx[cid]
        for pos, seq in enumerate(msa[cid].sequences):
            msa_sequences[col, pos] = seq["res_start"]

    # Concatenated residue tables, padded with -1 for ragged chains.
    max_residues_len = max(len(msa[cid].residues) for cid in chain_ids)
    msa_residues = np.full((n_chains, max_residues_len), -1, dtype=np.int64)
    for cid in chain_ids:
        chain_res = msa[cid].residues.astype(np.int64)
        msa_residues[chain_id_to_idx[cid], : len(chain_res)] = chain_res

    return _prepare_msa_arrays_inner(
        token_asym_ids_arr,
        token_res_idxs_arr,
        token_asym_ids_idx_arr,
        pairing_arr,
        is_paired_arr,
        deletions,
        msa_sequences,
        msa_residues,
        const.token_ids["-"],
    )
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Numba type of the deletions mapping: (chain_id, seq_idx, res_idx) -> deletion count.
# Must match the numba.typed.Dict built by the caller and the @njit signature below.
deletions_dict_type = types.DictType(types.UniTuple(types.int64, 3), types.int64)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
@numba.njit(
    [
        types.Tuple(
            (
                types.int64[:, ::1],  # msa_data
                types.int64[:, ::1],  # del_data
                types.int64[:, ::1],  # paired_data
            )
        )(
            types.int64[::1],  # token_asym_ids
            types.int64[::1],  # token_res_idxs
            types.int64[::1],  # token_asym_ids_idx
            types.int64[:, ::1],  # pairing
            types.int64[:, ::1],  # is_paired
            deletions_dict_type,  # deletions
            types.int64[:, ::1],  # msa_sequences
            types.int64[:, ::1],  # msa_residues
            types.int64,  # gap_token
        )
    ],
    cache=True,
)
def _prepare_msa_arrays_inner(
    token_asym_ids: npt.NDArray[np.int64],
    token_res_idxs: npt.NDArray[np.int64],
    token_asym_ids_idx: npt.NDArray[np.int64],
    pairing: npt.NDArray[np.int64],
    is_paired: npt.NDArray[np.int64],
    deletions: dict[tuple[int, int, int], int],
    msa_sequences: npt.NDArray[np.int64],
    msa_residues: npt.NDArray[np.int64],
    gap_token: int,
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
    """Gather per-token MSA, deletion and pairing matrices (numba kernel).

    For every (token, pairing-row) cell, looks up which sequence of the
    token's chain is used in that row and copies the residue type and
    deletion count. Cells whose row has no sequence for the chain
    (seq_idx == -1) keep the gap token and zero deletion.
    """
    n_tokens = len(token_asym_ids)
    n_pairs = len(pairing)
    # Default every cell to a gap; only cells with a real sequence are filled.
    msa_data = np.full((n_tokens, n_pairs), gap_token, dtype=np.int64)
    paired_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)
    del_data = np.zeros((n_tokens, n_pairs), dtype=np.int64)

    # Add all the token MSA data
    for token_idx in range(n_tokens):
        chain_id_idx = token_asym_ids_idx[token_idx]  # dense column index
        chain_id = token_asym_ids[token_idx]  # original (sparse) asym id, used as deletions key
        res_idx = token_res_idxs[token_idx]

        for pair_idx in range(n_pairs):
            seq_idx = pairing[pair_idx, chain_id_idx]
            paired_data[token_idx, pair_idx] = is_paired[pair_idx, chain_id_idx]

            # Add residue type
            if seq_idx != -1:
                # res_start offsets into the chain's concatenated residue table.
                res_start = msa_sequences[chain_id_idx, seq_idx]
                res_type = msa_residues[chain_id_idx, res_start + res_idx]
                k = (chain_id, seq_idx, res_idx)
                if k in deletions:
                    del_data[token_idx, pair_idx] = deletions[k]
                msa_data[token_idx, pair_idx] = res_type

    return msa_data, del_data, paired_data
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
####################################################################################################
|
| 576 |
+
# FEATURES
|
| 577 |
+
####################################################################################################
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def select_subset_from_mask(mask, p, random: np.random.Generator) -> np.ndarray:
    """Return a copy of ``mask`` that keeps only a random subset of its set entries.

    The subset size is drawn as ``geometric(p) + 1`` and capped by the number
    of set entries in ``mask``; the surviving entries are chosen uniformly
    without replacement.
    """
    set_count = np.sum(mask)
    # Draw the target size first so the RNG call order matches callers' seeds.
    target = min(random.geometric(p) + 1, set_count)

    candidates = np.where(mask)[0]
    kept = random.choice(candidates, size=target, replace=False)

    subset = np.zeros_like(mask)
    subset[kept] = 1
    return subset
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def get_range_bin(value: float, range_dict: dict[tuple[float, float], int], default=0):
    """Return the bin index whose half-open range ``[low, high)`` contains ``value``.

    Keys equal to the string ``"other"`` are ignored; ``default`` is returned
    when no range matches.
    """
    value = float(value)
    numeric_items = [(k, idx) for k, idx in range_dict.items() if k != "other"]
    for (low, high), bin_idx in numeric_items:
        if low <= value < high:
            return bin_idx
    return default
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def process_token_features(  # noqa: C901, PLR0915, PLR0912
    data: Tokenized,
    random: np.random.Generator,
    max_tokens: Optional[int] = None,
    binder_pocket_conditioned_prop: Optional[float] = 0.0,
    contact_conditioned_prop: Optional[float] = 0.0,
    binder_pocket_cutoff_min: Optional[float] = 4.0,
    binder_pocket_cutoff_max: Optional[float] = 20.0,
    binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
    only_ligand_binder_pocket: Optional[bool] = False,
    only_pp_contact: Optional[bool] = False,
    inference_pocket_constraints: Optional[
        list[tuple[int, list[tuple[int, int]], float, bool]]
    ] = None,
    inference_contact_constraints: Optional[
        list[tuple[tuple[int, int], tuple[int, int], float, bool]]
    ] = None,
    override_method: Optional[str] = None,
) -> dict[str, Tensor]:
    """Get the token features.

    Builds the per-token identifier, mask, bond, conditioning and
    cyclic-period tensors. Pocket/contact conditioning is either sampled
    stochastically (controlled by the ``*_prop`` probabilities) or taken
    from explicit inference constraints.

    Fix vs. previous revision: ``inference_pocket_constraints`` and
    ``inference_contact_constraints`` defaulted to ``False``, but the guards
    below test ``is not None`` — so calling with the defaults iterated over
    ``False`` and raised ``TypeError``. Defaults are now ``None`` (the guards
    are unchanged), and the annotations declare the 4-tuples the loops
    actually unpack.

    Parameters
    ----------
    data : Tokenized
        The input data to the model.
    random : np.random.Generator
        RNG driving all stochastic conditioning choices.
    max_tokens : int, optional
        The maximum number of tokens; features are padded up to this size.
    binder_pocket_conditioned_prop : float, optional
        Per-iteration probability of sampling a (binder, pocket) pair.
    contact_conditioned_prop : float, optional
        Per-iteration probability of sampling a token-token contact pair.
    binder_pocket_cutoff_min, binder_pocket_cutoff_max : float, optional
        Range from which the pocket/contact distance cutoff is sampled.
    binder_pocket_sampling_geometric_p : float, optional
        If > 0, keep only a geometric-distributed subset of pocket tokens.
    only_ligand_binder_pocket : bool, optional
        If True, only non-polymer (ligand) chains may be chosen as binder.
    only_pp_contact : bool, optional
        If True, contact sampling is restricted to protein chains.
    inference_pocket_constraints : list, optional
        Explicit (binder_asym_id, contacts, max_distance, force) constraints.
    inference_contact_constraints : list, optional
        Explicit (token1, token2, max_distance, force) constraints.
    override_method : str, optional
        Forces the method conditioning feature to this method name.

    Returns
    -------
    dict[str, Tensor]
        The token features.

    """
    # Token data
    token_data = data.tokens
    token_bonds = data.bonds

    # Token core features
    token_index = torch.arange(len(token_data), dtype=torch.long)
    residue_index = from_numpy(token_data["res_idx"]).long()
    asym_id = from_numpy(token_data["asym_id"]).long()
    entity_id = from_numpy(token_data["entity_id"]).long()
    sym_id = from_numpy(token_data["sym_id"]).long()
    mol_type = from_numpy(token_data["mol_type"]).long()
    res_type = from_numpy(token_data["res_type"]).long()
    res_type = one_hot(res_type, num_classes=const.num_tokens)
    disto_center = from_numpy(token_data["disto_coords"])
    modified = from_numpy(token_data["modified"]).long()  # float()
    cyclic_period = from_numpy(token_data["cyclic_period"].copy())
    affinity_mask = from_numpy(token_data["affinity_mask"]).float()

    ## Conditioning features ##
    # Default method is x-ray diffraction unless overridden or recorded.
    method = (
        np.zeros(len(token_data))
        + const.method_types_ids[
            (
                "x-ray diffraction"
                if override_method is None
                else override_method.lower()
            )
        ]
    )
    if data.record is not None:
        if (
            override_method is None
            and data.record.structure.method is not None
            and data.record.structure.method.lower() in const.method_types_ids
        ):
            method = (method * 0) + const.method_types_ids[
                data.record.structure.method.lower()
            ]

    method_feature = from_numpy(method).long()

    # Token mask features
    pad_mask = torch.ones(len(token_data), dtype=torch.float)
    resolved_mask = from_numpy(token_data["resolved_mask"]).float()
    disto_mask = from_numpy(token_data["disto_mask"]).float()

    # Token bond features. The bond matrices are allocated at the padded size
    # directly, so they need no padding at the end of this function.
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        num_tokens = max_tokens if pad_len > 0 else len(token_data)
    else:
        num_tokens = len(token_data)

    tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
    bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
    bonds_type = torch.zeros(num_tokens, num_tokens, dtype=torch.long)
    for token_bond in token_bonds:
        token_1 = tok_to_idx[token_bond["token_1"]]
        token_2 = tok_to_idx[token_bond["token_2"]]
        bonds[token_1, token_2] = 1
        bonds[token_2, token_1] = 1
        bond_type = token_bond["type"]
        bonds_type[token_1, token_2] = bond_type
        bonds_type[token_2, token_1] = bond_type

    bonds = bonds.unsqueeze(-1)

    # Pocket conditioned feature
    contact_conditioning = (
        np.zeros((len(token_data), len(token_data)))
        + const.contact_conditioning_info["UNSELECTED"]
    )
    contact_threshold = np.zeros((len(token_data), len(token_data)))

    if inference_pocket_constraints is not None:
        # NOTE(review): `force` is unpacked but unused in this function.
        for binder, contacts, max_distance, force in inference_pocket_constraints:
            binder_mask = token_data["asym_id"] == binder

            for idx, token in enumerate(token_data):
                # Polymer tokens are matched by residue, ligand tokens by atom.
                if (
                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and (token["asym_id"], token["res_idx"]) in contacts
                ) or (
                    token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                    and (token["asym_id"], token["atom_idx"]) in contacts
                ):
                    contact_conditioning[binder_mask, idx] = (
                        const.contact_conditioning_info["BINDER>POCKET"]
                    )
                    contact_conditioning[idx, binder_mask] = (
                        const.contact_conditioning_info["POCKET>BINDER"]
                    )
                    contact_threshold[binder_mask, idx] = max_distance
                    contact_threshold[idx, binder_mask] = max_distance

    if inference_contact_constraints is not None:
        # NOTE(review): `force` is unpacked but unused in this function.
        for token1, token2, max_distance, force in inference_contact_constraints:
            for idx1, _token1 in enumerate(token_data):
                if (
                    _token1["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and (_token1["asym_id"], _token1["res_idx"]) == token1
                ) or (
                    _token1["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                    and (_token1["asym_id"], _token1["atom_idx"]) == token1
                ):
                    for idx2, _token2 in enumerate(token_data):
                        if (
                            _token2["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                            and (_token2["asym_id"], _token2["res_idx"]) == token2
                        ) or (
                            _token2["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                            and (_token2["asym_id"], _token2["atom_idx"]) == token2
                        ):
                            contact_conditioning[idx1, idx2] = (
                                const.contact_conditioning_info["CONTACT"]
                            )
                            contact_conditioning[idx2, idx1] = (
                                const.contact_conditioning_info["CONTACT"]
                            )
                            contact_threshold[idx1, idx2] = max_distance
                            contact_threshold[idx2, idx1] = max_distance
                            break
                    break

    if binder_pocket_conditioned_prop > 0.0:
        # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
        binder_asym_ids = np.unique(
            token_data["asym_id"][
                token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
            ]
        )

        if len(binder_asym_ids) == 0:
            if not only_ligand_binder_pocket:
                binder_asym_ids = np.unique(token_data["asym_id"])

        # Each iteration samples one binder chain without replacement.
        while random.random() < binder_pocket_conditioned_prop:
            if len(binder_asym_ids) == 0:
                break

            pocket_asym_id = random.choice(binder_asym_ids)
            binder_asym_ids = binder_asym_ids[binder_asym_ids != pocket_asym_id]

            binder_pocket_cutoff = sample_d(
                min_d=binder_pocket_cutoff_min,
                max_d=binder_pocket_cutoff_max,
                n_samples=1,
                random=random,
            )

            binder_mask = token_data["asym_id"] == pocket_asym_id

            binder_coords = []
            for token in token_data:
                if token["asym_id"] == pocket_asym_id:
                    _coords = data.structure.atoms["coords"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]
                    _is_present = data.structure.atoms["is_present"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]
                    binder_coords.append(_coords[_is_present])
            binder_coords = np.concatenate(binder_coords, axis=0)

            # find the tokens in the pocket
            token_dist = np.zeros(len(token_data)) + 1000
            for i, token in enumerate(token_data):
                if (
                    token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                    and token["asym_id"] != pocket_asym_id
                    and token["resolved_mask"] == 1
                ):
                    token_coords = data.structure.atoms["coords"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]
                    token_is_present = data.structure.atoms["is_present"][
                        token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                    ]
                    token_coords = token_coords[token_is_present]

                    # find chain and apply chain transformation
                    # NOTE(review): `chain` is looked up but never used below;
                    # the "transformation" mentioned above is not applied.
                    for chain in data.structure.chains:
                        if chain["asym_id"] == token["asym_id"]:
                            break

                    token_dist[i] = np.min(
                        np.linalg.norm(
                            token_coords[:, None, :] - binder_coords[None, :, :],
                            axis=-1,
                        )
                    )

            pocket_mask = token_dist < binder_pocket_cutoff

            if np.sum(pocket_mask) > 0:
                if binder_pocket_sampling_geometric_p > 0.0:
                    # select a subset of the pocket, according
                    # to a geometric distribution with one as minimum
                    pocket_mask = select_subset_from_mask(
                        pocket_mask,
                        binder_pocket_sampling_geometric_p,
                        random,
                    )

                contact_conditioning[np.ix_(binder_mask, pocket_mask)] = (
                    const.contact_conditioning_info["BINDER>POCKET"]
                )
                contact_conditioning[np.ix_(pocket_mask, binder_mask)] = (
                    const.contact_conditioning_info["POCKET>BINDER"]
                )
                contact_threshold[np.ix_(binder_mask, pocket_mask)] = (
                    binder_pocket_cutoff
                )
                contact_threshold[np.ix_(pocket_mask, binder_mask)] = (
                    binder_pocket_cutoff
                )

    # Contact conditioning feature
    if contact_conditioned_prop > 0.0:
        while random.random() < contact_conditioned_prop:
            contact_cutoff = sample_d(
                min_d=binder_pocket_cutoff_min,
                max_d=binder_pocket_cutoff_max,
                n_samples=1,
                random=random,
            )
            if only_pp_contact:
                chain_asym_ids = np.unique(
                    token_data["asym_id"][
                        token_data["mol_type"] == const.chain_type_ids["PROTEIN"]
                    ]
                )
            else:
                chain_asym_ids = np.unique(token_data["asym_id"])

            if len(chain_asym_ids) > 1:
                chain_asym_id = random.choice(chain_asym_ids)

                chain_coords = []
                for token in token_data:
                    if token["asym_id"] == chain_asym_id:
                        _coords = data.structure.atoms["coords"][
                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                        ]
                        _is_present = data.structure.atoms["is_present"][
                            token["atom_idx"] : token["atom_idx"] + token["atom_num"]
                        ]
                        chain_coords.append(_coords[_is_present])
                chain_coords = np.concatenate(chain_coords, axis=0)

                # find contacts in other chains
                possible_other_chains = []
                for other_chain_id in chain_asym_ids[chain_asym_ids != chain_asym_id]:
                    for token in token_data:
                        if token["asym_id"] == other_chain_id:
                            _coords = data.structure.atoms["coords"][
                                token["atom_idx"] : token["atom_idx"]
                                + token["atom_num"]
                            ]
                            _is_present = data.structure.atoms["is_present"][
                                token["atom_idx"] : token["atom_idx"]
                                + token["atom_num"]
                            ]
                            if _is_present.sum() == 0:
                                continue
                            token_coords = _coords[_is_present]

                            # check minimum distance
                            if (
                                np.min(cdist(chain_coords, token_coords))
                                < contact_cutoff
                            ):
                                possible_other_chains.append(other_chain_id)
                                break

                if len(possible_other_chains) > 0:
                    other_chain_id = random.choice(possible_other_chains)

                    pairs = []
                    for token_1 in token_data:
                        if token_1["asym_id"] == chain_asym_id:
                            _coords = data.structure.atoms["coords"][
                                token_1["atom_idx"] : token_1["atom_idx"]
                                + token_1["atom_num"]
                            ]
                            _is_present = data.structure.atoms["is_present"][
                                token_1["atom_idx"] : token_1["atom_idx"]
                                + token_1["atom_num"]
                            ]
                            if _is_present.sum() == 0:
                                continue
                            token_1_coords = _coords[_is_present]

                            for token_2 in token_data:
                                if token_2["asym_id"] == other_chain_id:
                                    _coords = data.structure.atoms["coords"][
                                        token_2["atom_idx"] : token_2["atom_idx"]
                                        + token_2["atom_num"]
                                    ]
                                    _is_present = data.structure.atoms["is_present"][
                                        token_2["atom_idx"] : token_2["atom_idx"]
                                        + token_2["atom_num"]
                                    ]
                                    if _is_present.sum() == 0:
                                        continue
                                    token_2_coords = _coords[_is_present]

                                    if (
                                        np.min(cdist(token_1_coords, token_2_coords))
                                        < contact_cutoff
                                    ):
                                        pairs.append(
                                            (token_1["token_idx"], token_2["token_idx"])
                                        )

                    # A contact must exist: other_chain_id was selected above
                    # precisely because it has an atom within the cutoff.
                    assert len(pairs) > 0

                    pair = random.choice(pairs)
                    token_1_mask = token_data["token_idx"] == pair[0]
                    token_2_mask = token_data["token_idx"] == pair[1]

                    contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
                        const.contact_conditioning_info["CONTACT"]
                    )
                    contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
                        const.contact_conditioning_info["CONTACT"]
                    )

            elif not only_pp_contact:
                # only one chain, find contacts within the chain with minimum residue distance
                pairs = []
                for token_1 in token_data:
                    _coords = data.structure.atoms["coords"][
                        token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
                    ]
                    _is_present = data.structure.atoms["is_present"][
                        token_1["atom_idx"] : token_1["atom_idx"] + token_1["atom_num"]
                    ]
                    if _is_present.sum() == 0:
                        continue
                    token_1_coords = _coords[_is_present]

                    for token_2 in token_data:
                        # Skip residue pairs closer than 9 positions in sequence.
                        if np.abs(token_1["res_idx"] - token_2["res_idx"]) <= 8:
                            continue

                        _coords = data.structure.atoms["coords"][
                            token_2["atom_idx"] : token_2["atom_idx"]
                            + token_2["atom_num"]
                        ]
                        _is_present = data.structure.atoms["is_present"][
                            token_2["atom_idx"] : token_2["atom_idx"]
                            + token_2["atom_num"]
                        ]
                        if _is_present.sum() == 0:
                            continue
                        token_2_coords = _coords[_is_present]

                        if (
                            np.min(cdist(token_1_coords, token_2_coords))
                            < contact_cutoff
                        ):
                            pairs.append((token_1["token_idx"], token_2["token_idx"]))

                if len(pairs) > 0:
                    pair = random.choice(pairs)
                    token_1_mask = token_data["token_idx"] == pair[0]
                    token_2_mask = token_data["token_idx"] == pair[1]

                    contact_conditioning[np.ix_(token_1_mask, token_2_mask)] = (
                        const.contact_conditioning_info["CONTACT"]
                    )
                    contact_conditioning[np.ix_(token_2_mask, token_1_mask)] = (
                        const.contact_conditioning_info["CONTACT"]
                    )

    # If nothing was selected, relabel the whole matrix as UNSPECIFIED.
    if np.all(contact_conditioning == const.contact_conditioning_info["UNSELECTED"]):
        contact_conditioning = (
            contact_conditioning
            - const.contact_conditioning_info["UNSELECTED"]
            + const.contact_conditioning_info["UNSPECIFIED"]
        )
    contact_conditioning = from_numpy(contact_conditioning).long()
    contact_conditioning = one_hot(
        contact_conditioning, num_classes=len(const.contact_conditioning_info)
    )
    contact_threshold = from_numpy(contact_threshold).float()

    # compute cyclic polymer mask
    # A chain is cyclic when an intra-chain bond joins residue 0 to the last
    # residue; its residue count is recorded as the cyclic period.
    cyclic_ids = {}
    for idx_chain, asym_id_iter in enumerate(data.structure.chains["asym_id"]):
        for connection in data.structure.bonds:
            if (
                idx_chain == connection["chain_1"] == connection["chain_2"]
                and data.structure.chains[connection["chain_1"]]["res_num"] > 2
                and connection["res_1"]
                != connection["res_2"]  # Avoid same residue bonds!
            ):
                if (
                    data.structure.chains[connection["chain_1"]]["res_num"]
                    == (connection["res_2"] + 1)
                    and connection["res_1"] == 0
                ) or (
                    data.structure.chains[connection["chain_1"]]["res_num"]
                    == (connection["res_1"] + 1)
                    and connection["res_2"] == 0
                ):
                    cyclic_ids[asym_id_iter] = data.structure.chains[
                        connection["chain_1"]
                    ]["res_num"]
    cyclic = from_numpy(
        np.array(
            [
                (cyclic_ids[asym_id_iter] if asym_id_iter in cyclic_ids else 0)
                for asym_id_iter in token_data["asym_id"]
            ]
        )
    ).float()

    # cyclic period is either computed from the bonds or given as input flag
    cyclic_period = torch.maximum(cyclic, cyclic_period)

    # Pad to max tokens if given
    if max_tokens is not None:
        pad_len = max_tokens - len(token_data)
        if pad_len > 0:
            token_index = pad_dim(token_index, 0, pad_len)
            residue_index = pad_dim(residue_index, 0, pad_len)
            asym_id = pad_dim(asym_id, 0, pad_len)
            entity_id = pad_dim(entity_id, 0, pad_len)
            sym_id = pad_dim(sym_id, 0, pad_len)
            mol_type = pad_dim(mol_type, 0, pad_len)
            res_type = pad_dim(res_type, 0, pad_len)
            disto_center = pad_dim(disto_center, 0, pad_len)
            pad_mask = pad_dim(pad_mask, 0, pad_len)
            resolved_mask = pad_dim(resolved_mask, 0, pad_len)
            disto_mask = pad_dim(disto_mask, 0, pad_len)
            contact_conditioning = pad_dim(contact_conditioning, 0, pad_len)
            contact_conditioning = pad_dim(contact_conditioning, 1, pad_len)
            contact_threshold = pad_dim(contact_threshold, 0, pad_len)
            contact_threshold = pad_dim(contact_threshold, 1, pad_len)
            method_feature = pad_dim(method_feature, 0, pad_len)
            modified = pad_dim(modified, 0, pad_len)
            cyclic_period = pad_dim(cyclic_period, 0, pad_len)
            affinity_mask = pad_dim(affinity_mask, 0, pad_len)

    token_features = {
        "token_index": token_index,
        "residue_index": residue_index,
        "asym_id": asym_id,
        "entity_id": entity_id,
        "sym_id": sym_id,
        "mol_type": mol_type,
        "res_type": res_type,
        "disto_center": disto_center,
        "token_bonds": bonds,
        "type_bonds": bonds_type,
        "token_pad_mask": pad_mask,
        "token_resolved_mask": resolved_mask,
        "token_disto_mask": disto_mask,
        "contact_conditioning": contact_conditioning,
        "contact_threshold": contact_threshold,
        "method_feature": method_feature,
        "modified": modified,
        "cyclic_period": cyclic_period,
        "affinity_token_mask": affinity_mask,
    }

    return token_features
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
def process_atom_features(
    data: Tokenized,
    random: np.random.Generator,
    ensemble_features: dict,
    molecules: dict[str, Mol],
    atoms_per_window_queries: int = 32,
    min_dist: float = 2.0,
    max_dist: float = 22.0,
    num_bins: int = 64,
    max_atoms: Optional[int] = None,
    max_tokens: Optional[int] = None,
    disto_use_ensemble: Optional[bool] = False,
    override_bfactor: bool = False,
    compute_frames: bool = False,
    override_coords: Optional[Tensor] = None,
    bfactor_md_correction: bool = False,
) -> dict[str, Tensor]:
    """Get the atom features.

    Builds per-atom reference features (element, charge, chirality, conformer
    positions from RDKit molecules), atom/token index maps, ensemble ground
    truth coordinates, and the (binned) distogram target.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        RNG used to sample one RDKit conformer per residue.
    ensemble_features : dict
        Must contain "ensemble_ref_idxs", the sampled ensemble indices.
    molecules : dict[str, Mol]
        Map from residue name to its RDKit reference molecule.
    atoms_per_window_queries : int
        Window size the padded atom count is rounded up to.
    min_dist, max_dist, num_bins : float, float, int
        Distogram binning parameters (num_bins - 1 boundaries).
    max_atoms : int, optional
        The maximum number of atoms; must be divisible by
        atoms_per_window_queries when given.
    max_tokens : int, optional
        If given, token-indexed features are padded to this length.
    disto_use_ensemble : bool, optional
        Use all ensemble members (not just the sampled ones) for the distogram.
    override_bfactor : bool
        Zero out the b-factor feature.
    compute_frames : bool
        Also compute per-token frame indices and their resolved mask.
    override_coords : Tensor, optional
        If a Tensor, replaces the centered ground-truth coordinates.
    bfactor_md_correction : bool
        For MD structures, convert RMSF values to b-factors (8*pi^2*rmsf^2).

    Returns
    -------
    dict[str, Tensor]
        The atom features.

    """
    # Filter to tokens' atoms
    atom_data = []
    atom_name = []
    atom_element = []
    atom_charge = []
    atom_conformer = []
    atom_chirality = []
    ref_space_uid = []
    coord_data = []
    if compute_frames:
        frame_data = []
        resolved_frame_data = []
    atom_to_token = []
    token_to_rep_atom = []  # index on cropped atom table
    r_set_to_rep_atom = []
    disto_coords_ensemble = []
    backbone_feat_index = []
    token_to_center_atom = []

    e_offsets = data.structure.ensemble["atom_coord_idx"]
    atom_idx = 0

    # Start atom idx in full atom table for structures chosen. Up to num_ensembles points.
    ensemble_atom_starts = [
        data.structure.ensemble[idx]["atom_coord_idx"]
        for idx in ensemble_features["ensemble_ref_idxs"]
    ]

    # Set unk chirality id
    unk_chirality = const.chirality_type_ids[const.unk_chirality_type]

    # (chain_idx, res_idx) -> dense residue id (ref_space_uid) / sampled conformer id.
    chain_res_ids = {}
    res_index_to_conf_id = {}
    for token_id, token in enumerate(data.tokens):
        # Get the chain residue ids
        chain_idx, res_id = token["asym_id"], token["res_idx"]
        chain = data.structure.chains[chain_idx]

        if (chain_idx, res_id) not in chain_res_ids:
            new_idx = len(chain_res_ids)
            chain_res_ids[(chain_idx, res_id)] = new_idx
        else:
            new_idx = chain_res_ids[(chain_idx, res_id)]

        # Get the molecule and conformer
        mol = molecules[token["res_name"]]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}

        # Sample a random conformer (one per residue, shared across its tokens)
        if (chain_idx, res_id) not in res_index_to_conf_id:
            conf_ids = [int(conf.GetId()) for conf in mol.GetConformers()]
            conf_id = int(random.choice(conf_ids))
            res_index_to_conf_id[(chain_idx, res_id)] = conf_id

        conf_id = res_index_to_conf_id[(chain_idx, res_id)]
        conformer = mol.GetConformer(conf_id)

        # Map atoms to token indices
        ref_space_uid.extend([new_idx] * token["atom_num"])
        atom_to_token.extend([token_id] * token["atom_num"])

        # Add atom data
        start = token["atom_idx"]
        end = token["atom_idx"] + token["atom_num"]
        token_atoms = data.structure.atoms[start:end]

        # Add atom ref data
        # element, charge, conformer, chirality
        token_atom_name = np.array([convert_atom_name(a["name"]) for a in token_atoms])
        token_atoms_ref = np.array([atom_name_to_ref[a["name"]] for a in token_atoms])
        token_atoms_element = np.array([a.GetAtomicNum() for a in token_atoms_ref])
        token_atoms_charge = np.array([a.GetFormalCharge() for a in token_atoms_ref])
        token_atoms_conformer = np.array(
            [
                (
                    conformer.GetAtomPosition(a.GetIdx()).x,
                    conformer.GetAtomPosition(a.GetIdx()).y,
                    conformer.GetAtomPosition(a.GetIdx()).z,
                )
                for a in token_atoms_ref
            ]
        )
        token_atoms_chirality = np.array(
            [
                const.chirality_type_ids.get(a.GetChiralTag().name, unk_chirality)
                for a in token_atoms_ref
            ]
        )

        # Map token to representative atom
        token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
        token_to_center_atom.append(atom_idx + token["center_idx"] - start)
        if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
            "resolved_mask"
        ]:
            r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)

        # Backbone feature: 0 = not backbone; protein backbone atoms come first,
        # nucleic backbone atoms are offset past the protein ones.
        if chain["mol_type"] == const.chain_type_ids["PROTEIN"]:
            backbone_index = [
                (
                    const.protein_backbone_atom_index[atom_name] + 1
                    if atom_name in const.protein_backbone_atom_index
                    else 0
                )
                for atom_name in token_atoms["name"]
            ]
        elif (
            chain["mol_type"] == const.chain_type_ids["DNA"]
            or chain["mol_type"] == const.chain_type_ids["RNA"]
        ):
            backbone_index = [
                (
                    const.nucleic_backbone_atom_index[atom_name]
                    + 1
                    + len(const.protein_backbone_atom_index)
                    if atom_name in const.nucleic_backbone_atom_index
                    else 0
                )
                for atom_name in token_atoms["name"]
            ]
        else:
            backbone_index = [0] * token["atom_num"]
        backbone_feat_index.extend(backbone_index)

        # Get token coordinates across sampled ensembles and apply transforms
        token_coords = np.array(
            [
                data.structure.coords[
                    ensemble_atom_start + start : ensemble_atom_start + end
                ]["coords"]
                for ensemble_atom_start in ensemble_atom_starts
            ]
        )
        coord_data.append(token_coords)

        if compute_frames:
            # Get frame data: pick the three frame atoms (a, b, c) for this token.
            res_type = const.tokens[token["res_type"]]
            res_name = str(token["res_name"])

            if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
                # Too few atoms (or non-residue token) to define a frame.
                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
                mask_frame = False
            elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
                res_name in const.ref_atoms
            ):
                # Standard protein residue: N-CA-C frame from the reference atom order.
                idx_frame_a, idx_frame_b, idx_frame_c = (
                    const.ref_atoms[res_name].index("N"),
                    const.ref_atoms[res_name].index("CA"),
                    const.ref_atoms[res_name].index("C"),
                )
                mask_frame = (
                    token_atoms["is_present"][idx_frame_a]
                    and token_atoms["is_present"][idx_frame_b]
                    and token_atoms["is_present"][idx_frame_c]
                )
            elif (
                token["mol_type"] == const.chain_type_ids["DNA"]
                or token["mol_type"] == const.chain_type_ids["RNA"]
            ) and (res_name in const.ref_atoms):
                # Standard nucleotide: C1'-C3'-C4' frame.
                idx_frame_a, idx_frame_b, idx_frame_c = (
                    const.ref_atoms[res_name].index("C1'"),
                    const.ref_atoms[res_name].index("C3'"),
                    const.ref_atoms[res_name].index("C4'"),
                )
                mask_frame = (
                    token_atoms["is_present"][idx_frame_a]
                    and token_atoms["is_present"][idx_frame_b]
                    and token_atoms["is_present"][idx_frame_c]
                )
            elif token["mol_type"] == const.chain_type_ids["PROTEIN"]:
                # Try to look for the atom nams in the modified residue
                is_ca = token_atoms["name"] == "CA"
                idx_frame_a = is_ca.argmax()
                ca_present = (
                    token_atoms[idx_frame_a]["is_present"] if is_ca.any() else False
                )

                is_n = token_atoms["name"] == "N"
                idx_frame_b = is_n.argmax()
                n_present = (
                    token_atoms[idx_frame_b]["is_present"] if is_n.any() else False
                )

                is_c = token_atoms["name"] == "C"
                idx_frame_c = is_c.argmax()
                c_present = (
                    token_atoms[idx_frame_c]["is_present"] if is_c.any() else False
                )
                mask_frame = ca_present and n_present and c_present

            elif (token["mol_type"] == const.chain_type_ids["DNA"]) or (
                token["mol_type"] == const.chain_type_ids["RNA"]
            ):
                # Try to look for the atom nams in the modified residue
                is_c1 = token_atoms["name"] == "C1'"
                idx_frame_a = is_c1.argmax()
                c1_present = (
                    token_atoms[idx_frame_a]["is_present"] if is_c1.any() else False
                )

                is_c3 = token_atoms["name"] == "C3'"
                idx_frame_b = is_c3.argmax()
                c3_present = (
                    token_atoms[idx_frame_b]["is_present"] if is_c3.any() else False
                )

                is_c4 = token_atoms["name"] == "C4'"
                idx_frame_c = is_c4.argmax()
                c4_present = (
                    token_atoms[idx_frame_c]["is_present"] if is_c4.any() else False
                )
                mask_frame = c1_present and c3_present and c4_present
            else:
                idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
                mask_frame = False
            frame_data.append(
                [
                    idx_frame_a + atom_idx,
                    idx_frame_b + atom_idx,
                    idx_frame_c + atom_idx,
                ]
            )
            resolved_frame_data.append(mask_frame)

        # Get distogram coordinates
        disto_coords_ensemble_tok = data.structure.coords[
            e_offsets + token["disto_idx"]
        ]["coords"]
        disto_coords_ensemble.append(disto_coords_ensemble_tok)

        # Update atom data. This is technically never used again (we rely on coord_data),
        # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
        token_atoms = token_atoms.copy()
        token_atoms["coords"] = token_coords[
            0
        ]  # atom has a copy of first coords in ensemble
        atom_data.append(token_atoms)
        atom_name.append(token_atom_name)
        atom_element.append(token_atoms_element)
        atom_charge.append(token_atoms_charge)
        atom_conformer.append(token_atoms_conformer)
        atom_chirality.append(token_atoms_chirality)
        atom_idx += len(token_atoms)

    disto_coords_ensemble = np.array(disto_coords_ensemble)  # (N_TOK, N_ENS, 3)

    # Compute ensemble distogram
    L = len(data.tokens)

    if disto_use_ensemble:
        # Use all available structures to create distogram
        idx_list = range(disto_coords_ensemble.shape[1])
    else:
        # Only use a sampled structures to create distogram
        idx_list = ensemble_features["ensemble_ref_idxs"]

    # Create distogram
    disto_target = torch.zeros(L, L, len(idx_list), num_bins)  # TODO1

    # disto_target = torch.zeros(L, L, num_bins)
    for i, e_idx in enumerate(idx_list):
        t_center = torch.Tensor(disto_coords_ensemble[:, e_idx, :])
        t_dists = torch.cdist(t_center, t_center)
        boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
        # Bin index = number of boundaries the pairwise distance exceeds.
        distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
        # disto_target += one_hot(distogram, num_classes=num_bins)
        disto_target[:, :, i, :] = one_hot(distogram, num_classes=num_bins)  # TODO1

    # Normalize distogram
    # disto_target = disto_target / disto_target.sum(-1)[..., None]  # remove TODO1
    atom_data = np.concatenate(atom_data)
    atom_name = np.concatenate(atom_name)
    atom_element = np.concatenate(atom_element)
    atom_charge = np.concatenate(atom_charge)
    atom_conformer = np.concatenate(atom_conformer)
    atom_chirality = np.concatenate(atom_chirality)
    coord_data = np.concatenate(coord_data, axis=1)
    ref_space_uid = np.array(ref_space_uid)

    # Compute features
    disto_coords_ensemble = from_numpy(disto_coords_ensemble)
    disto_coords_ensemble = disto_coords_ensemble[
        :, ensemble_features["ensemble_ref_idxs"]
    ].permute(1, 0, 2)
    backbone_feat_index = from_numpy(np.asarray(backbone_feat_index)).long()
    ref_atom_name_chars = from_numpy(atom_name).long()
    ref_element = from_numpy(atom_element).long()
    ref_charge = from_numpy(atom_charge).float()
    ref_pos = from_numpy(atom_conformer).float()
    ref_space_uid = from_numpy(ref_space_uid)
    ref_chirality = from_numpy(atom_chirality).long()
    coords = from_numpy(coord_data.copy())
    resolved_mask = from_numpy(atom_data["is_present"])
    pad_mask = torch.ones(len(atom_data), dtype=torch.float)
    atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
    token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
    r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
    token_to_center_atom = torch.tensor(token_to_center_atom, dtype=torch.long)
    bfactor = from_numpy(atom_data["bfactor"].copy())
    plddt = from_numpy(atom_data["plddt"].copy())
    if override_bfactor:
        bfactor = bfactor * 0.0

    if bfactor_md_correction and data.record.structure.method.lower() == "md":
        # MD bfactor was computed as RMSF
        # Convert to b-factor
        bfactor = 8 * (np.pi**2) * (bfactor**2)

    # We compute frames within ensemble
    if compute_frames:
        frames = []
        frame_resolved_mask = []
        for i in range(coord_data.shape[0]):
            frame_data_, resolved_frame_data_ = compute_frames_nonpolymer(
                data,
                coord_data[i],
                atom_data["is_present"],
                atom_to_token,
                frame_data,
                resolved_frame_data,
            )  # Compute frames for NONPOLYMER tokens
            frames.append(frame_data_.copy())
            frame_resolved_mask.append(resolved_frame_data_.copy())
        frames = from_numpy(np.stack(frames))  # (N_ENS, N_TOK, 3)
        frame_resolved_mask = from_numpy(np.stack(frame_resolved_mask))

    # Convert to one-hot
    backbone_feat_index = one_hot(
        backbone_feat_index,
        num_classes=1
        + len(const.protein_backbone_atom_index)
        + len(const.nucleic_backbone_atom_index),
    )
    ref_atom_name_chars = one_hot(ref_atom_name_chars, num_classes=64)
    ref_element = one_hot(ref_element, num_classes=const.num_elements)
    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
    token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
    r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
    token_to_center_atom = one_hot(token_to_center_atom, num_classes=len(atom_data))

    # Center the ground truth coordinates
    center = (coords * resolved_mask[None, :, None]).sum(dim=1)
    center = center / resolved_mask.sum().clamp(min=1)
    coords = coords - center[:, None]

    if isinstance(override_coords, Tensor):
        coords = override_coords.unsqueeze(0)

    # Apply random roto-translation to the input conformers
    # NOTE(review): range(max) excludes the largest ref_space_uid, so the last
    # residue group is never augmented — looks like an off-by-one; confirm intent.
    for i in range(torch.max(ref_space_uid)):
        included = ref_space_uid == i
        if torch.sum(included) > 0 and torch.any(resolved_mask[included]):
            ref_pos[included] = center_random_augmentation(
                ref_pos[included][None], resolved_mask[included][None], centering=True
            )[0]

    # Compute padding and apply
    if max_atoms is not None:
        assert max_atoms % atoms_per_window_queries == 0
        pad_len = max_atoms - len(atom_data)
    else:
        # Round the atom count up to a multiple of the query window size.
        pad_len = (
            (len(atom_data) - 1) // atoms_per_window_queries + 1
        ) * atoms_per_window_queries - len(atom_data)

    if pad_len > 0:
        pad_mask = pad_dim(pad_mask, 0, pad_len)
        ref_pos = pad_dim(ref_pos, 0, pad_len)
        resolved_mask = pad_dim(resolved_mask, 0, pad_len)
        ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
        ref_element = pad_dim(ref_element, 0, pad_len)
        ref_charge = pad_dim(ref_charge, 0, pad_len)
        ref_chirality = pad_dim(ref_chirality, 0, pad_len)
        backbone_feat_index = pad_dim(backbone_feat_index, 0, pad_len)
        ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
        coords = pad_dim(coords, 1, pad_len)
        atom_to_token = pad_dim(atom_to_token, 0, pad_len)
        token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
        token_to_center_atom = pad_dim(token_to_center_atom, 1, pad_len)
        r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
        bfactor = pad_dim(bfactor, 0, pad_len)
        plddt = pad_dim(plddt, 0, pad_len)

    if max_tokens is not None:
        pad_len = max_tokens - token_to_rep_atom.shape[0]
        if pad_len > 0:
            atom_to_token = pad_dim(atom_to_token, 1, pad_len)
            token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
            r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
            token_to_center_atom = pad_dim(token_to_center_atom, 0, pad_len)
            disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
            disto_coords_ensemble = pad_dim(disto_coords_ensemble, 1, pad_len)

            if compute_frames:
                frames = pad_dim(frames, 1, pad_len)
                frame_resolved_mask = pad_dim(frame_resolved_mask, 1, pad_len)

    atom_features = {
        "ref_pos": ref_pos,
        "atom_resolved_mask": resolved_mask,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_element": ref_element,
        "ref_charge": ref_charge,
        "ref_chirality": ref_chirality,
        "atom_backbone_feat": backbone_feat_index,
        "ref_space_uid": ref_space_uid,
        "coords": coords,
        "atom_pad_mask": pad_mask,
        "atom_to_token": atom_to_token,
        "token_to_rep_atom": token_to_rep_atom,
        "r_set_to_rep_atom": r_set_to_rep_atom,
        "token_to_center_atom": token_to_center_atom,
        "disto_target": disto_target,
        "disto_coords_ensemble": disto_coords_ensemble,
        "bfactor": bfactor,
        "plddt": plddt,
    }

    if compute_frames:
        atom_features["frames_idx"] = frames
        atom_features["frame_resolved_mask"] = frame_resolved_mask

    return atom_features
|
| 1569 |
+
|
| 1570 |
+
|
| 1571 |
+
def process_msa_features(
    data: Tokenized,
    random: np.random.Generator,
    max_seqs_batch: int,
    max_seqs: int,
    max_tokens: Optional[int] = None,
    pad_to_max_seqs: bool = False,
    msa_sampling: bool = False,
    affinity: bool = False,
) -> dict[str, Tensor]:
    """Get the MSA features.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    max_seqs_batch : int
        The number of MSA sequences to construct (before padding).
    max_seqs : int
        The maximum number of MSA sequences.
    max_tokens : int
        The maximum number of tokens.
    pad_to_max_seqs : bool
        Whether to pad to the maximum number of sequences.
    msa_sampling : bool
        Whether to sample the MSA.
    affinity : bool
        If True, return only the profile/deletion-mean features
        (suffixed "_affinity") instead of the full MSA feature set.

    Returns
    -------
    dict[str, Tensor]
        The MSA features.

    """
    # Created paired MSA
    msa, deletion, paired = construct_paired_msa(
        data=data,
        random=random,
        max_seqs=max_seqs_batch,
        random_subset=msa_sampling,
    )
    msa, deletion, paired = (
        msa.transpose(1, 0),
        deletion.transpose(1, 0),
        paired.transpose(1, 0),
    )  # (N_MSA, N_RES, N_AA)

    # Prepare features
    assert torch.all(msa >= 0) and torch.all(msa < const.num_tokens)
    msa_one_hot = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
    msa_mask = torch.ones_like(msa)
    # Per-position residue-type profile, averaged over the MSA dimension.
    profile = msa_one_hot.float().mean(dim=0)
    has_deletion = deletion > 0
    # Squash raw deletion counts into a bounded value via arctan.
    deletion = np.pi / 2 * np.arctan(deletion / 3)
    deletion_mean = deletion.mean(axis=0)

    # Pad in the MSA dimension (dim=0)
    if pad_to_max_seqs:
        pad_len = max_seqs - msa.shape[0]
        if pad_len > 0:
            # Pad MSA rows with the gap token; masks/aux features pad with zeros.
            msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 0, pad_len)
            msa_mask = pad_dim(msa_mask, 0, pad_len)
            has_deletion = pad_dim(has_deletion, 0, pad_len)
            deletion = pad_dim(deletion, 0, pad_len)

    # Pad in the token dimension (dim=1)
    if max_tokens is not None:
        pad_len = max_tokens - msa.shape[1]
        if pad_len > 0:
            msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
            paired = pad_dim(paired, 1, pad_len)
            msa_mask = pad_dim(msa_mask, 1, pad_len)
            has_deletion = pad_dim(has_deletion, 1, pad_len)
            deletion = pad_dim(deletion, 1, pad_len)
            # profile / deletion_mean are token-indexed on dim 0.
            profile = pad_dim(profile, 0, pad_len)
            deletion_mean = pad_dim(deletion_mean, 0, pad_len)
    if affinity:
        return {
            "deletion_mean_affinity": deletion_mean,
            "profile_affinity": profile,
        }
    else:
        return {
            "msa": msa,
            "msa_paired": paired,
            "deletion_value": deletion,
            "has_deletion": has_deletion,
            "deletion_mean": deletion_mean,
            "profile": profile,
            "msa_mask": msa_mask,
        }
|
| 1662 |
+
|
| 1663 |
+
|
| 1664 |
+
def load_dummy_templates_features(tdim: int, num_tokens: int) -> dict:
    """Load dummy templates for v2.

    Returns a template feature dict of the v2 layout in which every
    feature is zero-filled, for `tdim` templates over `num_tokens` tokens.
    """

    def _zeros(shape: tuple, dtype) -> torch.Tensor:
        # Zero-filled tensor with the exact numpy dtype of the real features.
        return torch.from_numpy(np.zeros(shape, dtype=dtype))

    # Residue types are zero ids, then expanded to one-hot over the vocabulary.
    restype_ids = _zeros((tdim, num_tokens), np.int64)
    restype_one_hot = one_hot(restype_ids, num_classes=const.num_tokens)

    return {
        "template_restype": restype_one_hot,
        "template_frame_rot": _zeros((tdim, num_tokens, 3, 3), np.float32),
        "template_frame_t": _zeros((tdim, num_tokens, 3), np.float32),
        "template_cb": _zeros((tdim, num_tokens, 3), np.float32),
        "template_ca": _zeros((tdim, num_tokens, 3), np.float32),
        "template_mask_cb": _zeros((tdim, num_tokens), np.float32),
        "template_mask_frame": _zeros((tdim, num_tokens), np.float32),
        "template_mask": _zeros((tdim, num_tokens), np.float32),
        "query_to_template": _zeros((tdim, num_tokens), np.int64),
        "visibility_ids": _zeros((tdim, num_tokens), np.float32),
    }
|
| 1694 |
+
|
| 1695 |
+
|
| 1696 |
+
def compute_template_features(
    query_tokens: Tokenized,
    tmpl_tokens: list[dict],
    num_tokens: int,
) -> dict:
    """Compute the template features.

    Parameters
    ----------
    query_tokens : Tokenized
        The tokenized query structure.
    tmpl_tokens : list[dict]
        Template token records; each dict carries "q_idx" (query token index),
        "pdb_id" (template identifier), and "token" (the template token row).
    num_tokens : int
        Length of the feature arrays (query token count / padded size).

    Returns
    -------
    dict
        Template feature tensors, each indexed by query token position.

    """
    # Allocate features (zero = "not covered by any template")
    res_type = np.zeros((num_tokens,), dtype=np.int64)
    frame_rot = np.zeros((num_tokens, 3, 3), dtype=np.float32)
    frame_t = np.zeros((num_tokens, 3), dtype=np.float32)
    cb_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    ca_coords = np.zeros((num_tokens, 3), dtype=np.float32)
    frame_mask = np.zeros((num_tokens,), dtype=np.float32)
    cb_mask = np.zeros((num_tokens,), dtype=np.float32)
    template_mask = np.zeros((num_tokens,), dtype=np.float32)
    query_to_template = np.zeros((num_tokens,), dtype=np.int64)
    visibility_ids = np.zeros((num_tokens,), dtype=np.float32)

    # Now create features per token
    asym_id_to_pdb_id = {}

    for token_dict in tmpl_tokens:
        idx = token_dict["q_idx"]
        pdb_id = token_dict["pdb_id"]
        token = token_dict["token"]
        query_token = query_tokens.tokens[idx]
        # Remember which template covers this query chain.
        asym_id_to_pdb_id[query_token["asym_id"]] = pdb_id
        res_type[idx] = token["res_type"]
        frame_rot[idx] = token["frame_rot"].reshape(3, 3)
        frame_t[idx] = token["frame_t"]
        cb_coords[idx] = token["disto_coords"]
        ca_coords[idx] = token["center_coords"]
        cb_mask[idx] = token["disto_mask"]
        frame_mask[idx] = token["frame_mask"]
        template_mask[idx] = 1.0

    # Set visibility_id for templated chains
    for asym_id, pdb_id in asym_id_to_pdb_id.items():
        indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
        visibility_ids[indices] = pdb_id

    # Set visibility for non templated chain + olygomerics
    for asym_id in np.unique(query_tokens.structure.chains["asym_id"]):
        if asym_id not in asym_id_to_pdb_id:
            # We hack the chain id to be negative to not overlap with the above
            indices = (query_tokens.tokens["asym_id"] == asym_id).nonzero()
            visibility_ids[indices] = -1 - asym_id

    # Convert to one-hot
    res_type = torch.from_numpy(res_type)
    res_type = one_hot(res_type, num_classes=const.num_tokens)

    return {
        "template_restype": res_type,
        "template_frame_rot": torch.from_numpy(frame_rot),
        "template_frame_t": torch.from_numpy(frame_t),
        "template_cb": torch.from_numpy(cb_coords),
        "template_ca": torch.from_numpy(ca_coords),
        "template_mask_cb": torch.from_numpy(cb_mask),
        "template_mask_frame": torch.from_numpy(frame_mask),
        "template_mask": torch.from_numpy(template_mask),
        "query_to_template": torch.from_numpy(query_to_template),
        "visibility_ids": torch.from_numpy(visibility_ids),
    }
|
| 1760 |
+
|
| 1761 |
+
|
| 1762 |
+
def process_template_features(
    data: Tokenized,
    max_tokens: int,
) -> dict[str, torch.Tensor]:
    """Load the given input data.

    Builds one row of template features per distinct template name and
    stacks them along a new leading (template) dimension.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    max_tokens : int
        The maximum number of tokens.

    Returns
    -------
    dict[str, torch.Tensor]
        The loaded template features.

    """
    # Group templates by name
    name_to_templates: dict[str, list[TemplateInfo]] = {}
    for template_info in data.record.templates:
        name_to_templates.setdefault(template_info.name, []).append(template_info)

    # Map chain name to asym_id
    chain_name_to_asym_id = {}
    for chain in data.structure.chains:
        chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

    # Compute the offset
    template_features = []
    for template_id, (template_name, templates) in enumerate(name_to_templates.items()):
        row_tokens = []
        template_structure = data.templates[template_name]
        template_tokens = data.template_tokens[template_name]
        tmpl_chain_name_to_asym_id = {}
        for chain in template_structure.chains:
            tmpl_chain_name_to_asym_id[chain["name"]] = chain["asym_id"]

        for template in templates:
            # Residue-index offset between template and query numbering.
            offset = template.template_st - template.query_st

            # Get query and template tokens to map residues
            query_tokens = data.tokens
            chain_id = chain_name_to_asym_id[template.query_chain]
            q_tokens = query_tokens[query_tokens["asym_id"] == chain_id]
            q_indices = dict(zip(q_tokens["res_idx"], q_tokens["token_idx"]))

            # Get the template tokens at the query residues
            chain_id = tmpl_chain_name_to_asym_id[template.template_chain]
            toks = template_tokens[template_tokens["asym_id"] == chain_id]
            toks = [t for t in toks if t["res_idx"] - offset in q_indices]
            for t in toks:
                q_idx = q_indices[t["res_idx"] - offset]
                row_tokens.append(
                    {
                        "token": t,
                        "pdb_id": template_id,
                        "q_idx": q_idx,
                    }
                )

        # Compute template features for each row
        # NOTE(review): `template` here is the last item of the inner loop, so
        # force/threshold come from the final TemplateInfo of this name —
        # confirm that is intended when one name maps to several templates.
        row_features = compute_template_features(data, row_tokens, max_tokens)
        row_features["template_force"] = torch.tensor(template.force)
        row_features["template_force_threshold"] = torch.tensor(
            template.threshold if template.threshold is not None else float("inf"),
            dtype=torch.float32,
        )
        template_features.append(row_features)

    # Stack each feature
    # NOTE(review): raises IndexError if data.record.templates is empty —
    # presumably callers only invoke this when templates exist; verify.
    out = {}
    for k in template_features[0]:
        out[k] = torch.stack([f[k] for f in template_features])
    return out
|
| 1838 |
+
|
| 1839 |
+
|
| 1840 |
+
def process_symmetry_features(
    cropped: Tokenized, symmetries: dict
) -> dict[str, Tensor]:
    """Get the symmetry features.

    Combines chain-level, amino-acid-level and ligand-level symmetry
    features into a single dictionary.

    Parameters
    ----------
    cropped : Tokenized
        The input to the model.
    symmetries : dict
        Precomputed ligand symmetry data.

    Returns
    -------
    dict[str, Tensor]
        The symmetry features.

    """
    chain_feats = get_chain_symmetries(cropped)
    amino_acid_feats = get_amino_acids_symmetries(cropped)
    ligand_feats = get_ligand_symmetries(cropped, symmetries)

    # Merge the three feature groups; key sets are disjoint.
    return {**chain_feats, **amino_acid_feats, **ligand_feats}
|
| 1861 |
+
|
| 1862 |
+
|
| 1863 |
+
def process_ensemble_features(
    data: Tokenized,
    random: np.random.Generator,
    num_ensembles: int,
    ensemble_sample_replacement: bool,
    fix_single_ensemble: bool,
) -> dict[str, Tensor]:
    """Get the ensemble features.

    Selects which conformers of the structure ensemble to use.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    random : np.random.Generator
        The random number generator.
    num_ensembles : int
        The maximum number of ensembles to sample.
    ensemble_sample_replacement : bool
        Whether to sample with replacement (training behavior).
    fix_single_ensemble : bool
        Always take the first conformer (requires num_ensembles == 1).

    Returns
    -------
    dict[str, Tensor]
        The ensemble features.

    """
    assert num_ensembles > 0, "Number of conformers sampled must be greater than 0."

    # Number of available conformers in the structure.
    available = len(data.structure.ensemble)

    if fix_single_ensemble:
        # Always take the first conformer for train and validation.
        assert num_ensembles == 1, (
            "Number of conformers sampled must be 1 with fix_single_ensemble=True."
        )
        selected = np.array([0])
    elif ensemble_sample_replacement:
        # Training: sample conformer indices with replacement.
        selected = random.integers(0, available, (num_ensembles,))
    elif available < num_ensembles:
        # Validation with too few conformers: take them all.
        selected = np.arange(0, available)
    else:
        # Validation: sample without replacement.
        selected = random.choice(available, num_ensembles, replace=False)

    return {
        "ensemble_ref_idxs": torch.Tensor(selected).long(),
    }
|
| 1921 |
+
|
| 1922 |
+
|
| 1923 |
+
def process_residue_constraint_features(data: Tokenized) -> dict[str, Tensor]:
    """Build residue-level geometric constraint tensors.

    Converts the RDKit bounds, chiral-atom, stereo-bond and planarity
    constraints attached to the input into index/mask/bound tensors.
    When no residue constraints are present, returns empty tensors with
    the same shapes and dtypes.

    Parameters
    ----------
    data : Tokenized
        The input to the model.

    Returns
    -------
    dict[str, Tensor]
        The residue constraint features.

    """
    rc = data.residue_constraints

    if rc is not None:
        # Small converters: index arrays become (k, N) long tensors,
        # flags become bool tensors, bounds become float tensors.
        def to_index(arr) -> Tensor:
            return torch.tensor(arr.copy(), dtype=torch.long).T

        def to_mask(arr) -> Tensor:
            return torch.tensor(arr.copy(), dtype=torch.bool)

        def to_float(arr) -> Tensor:
            return torch.tensor(arr.copy(), dtype=torch.float)

        rdkit = rc.rdkit_bounds_constraints
        chiral = rc.chiral_atom_constraints
        stereo = rc.stereo_bond_constraints

        rdkit_bounds_index = to_index(rdkit["atom_idxs"])
        rdkit_bounds_bond_mask = to_mask(rdkit["is_bond"])
        rdkit_bounds_angle_mask = to_mask(rdkit["is_angle"])
        rdkit_upper_bounds = to_float(rdkit["upper_bound"])
        rdkit_lower_bounds = to_float(rdkit["lower_bound"])

        chiral_atom_index = to_index(chiral["atom_idxs"])
        chiral_reference_mask = to_mask(chiral["is_reference"])
        chiral_atom_orientations = to_mask(chiral["is_r"])

        stereo_bond_index = to_index(stereo["atom_idxs"])
        stereo_reference_mask = to_mask(stereo["is_reference"])
        stereo_bond_orientations = to_mask(stereo["is_e"])

        planar_bond_index = to_index(rc.planar_bond_constraints["atom_idxs"])
        planar_ring_5_index = to_index(rc.planar_ring_5_constraints["atom_idxs"])
        planar_ring_6_index = to_index(rc.planar_ring_6_constraints["atom_idxs"])
    else:
        # No constraints: empty tensors, first dim matches the tuple
        # arity of each constraint type.
        rdkit_bounds_index = torch.empty((2, 0), dtype=torch.long)
        rdkit_bounds_bond_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_bounds_angle_mask = torch.empty((0,), dtype=torch.bool)
        rdkit_upper_bounds = torch.empty((0,), dtype=torch.float)
        rdkit_lower_bounds = torch.empty((0,), dtype=torch.float)
        chiral_atom_index = torch.empty((4, 0), dtype=torch.long)
        chiral_reference_mask = torch.empty((0,), dtype=torch.bool)
        chiral_atom_orientations = torch.empty((0,), dtype=torch.bool)
        stereo_bond_index = torch.empty((4, 0), dtype=torch.long)
        stereo_reference_mask = torch.empty((0,), dtype=torch.bool)
        stereo_bond_orientations = torch.empty((0,), dtype=torch.bool)
        planar_bond_index = torch.empty((6, 0), dtype=torch.long)
        planar_ring_5_index = torch.empty((5, 0), dtype=torch.long)
        planar_ring_6_index = torch.empty((6, 0), dtype=torch.long)

    return {
        "rdkit_bounds_index": rdkit_bounds_index,
        "rdkit_bounds_bond_mask": rdkit_bounds_bond_mask,
        "rdkit_bounds_angle_mask": rdkit_bounds_angle_mask,
        "rdkit_upper_bounds": rdkit_upper_bounds,
        "rdkit_lower_bounds": rdkit_lower_bounds,
        "chiral_atom_index": chiral_atom_index,
        "chiral_reference_mask": chiral_reference_mask,
        "chiral_atom_orientations": chiral_atom_orientations,
        "stereo_bond_index": stereo_bond_index,
        "stereo_reference_mask": stereo_reference_mask,
        "stereo_bond_orientations": stereo_bond_orientations,
        "planar_bond_index": planar_bond_index,
        "planar_ring_5_index": planar_ring_5_index,
        "planar_ring_6_index": planar_ring_6_index,
    }
|
| 2016 |
+
|
| 2017 |
+
|
| 2018 |
+
def process_chain_feature_constraints(data: Tokenized) -> dict[str, Tensor]:
    """Build chain-level constraint features.

    Collects inter-chain covalent bonds and pairs of chains that share
    an entity (and are therefore symmetric copies).

    Parameters
    ----------
    data : Tokenized
        The input to the model.

    Returns
    -------
    dict[str, Tensor]
        Connected chain/atom index tensors and symmetric chain pairs.

    """
    structure = data.structure

    # Inter-chain bonds only; intra-chain bonds carry no chain constraint.
    chain_pairs: list = []
    atom_pairs: list = []
    if structure.bonds.shape[0] > 0:
        for bond in structure.bonds:
            if bond["chain_1"] == bond["chain_2"]:
                continue
            chain_pairs.append([bond["chain_1"], bond["chain_2"]])
            atom_pairs.append([bond["atom_1"], bond["atom_2"]])

    if chain_pairs:
        connected_chain_index = torch.tensor(chain_pairs, dtype=torch.long).T
        connected_atom_index = torch.tensor(atom_pairs, dtype=torch.long).T
    else:
        connected_chain_index = torch.empty((2, 0), dtype=torch.long)
        connected_atom_index = torch.empty((2, 0), dtype=torch.long)

    # Unordered pairs (i < j) of chains with the same entity id.
    sym_pairs = [
        [i, j]
        for i, chain_i in enumerate(structure.chains)
        for j, chain_j in enumerate(structure.chains)
        if j > i and chain_i["entity_id"] == chain_j["entity_id"]
    ]
    if sym_pairs:
        symmetric_chain_index = torch.tensor(sym_pairs, dtype=torch.long).T
    else:
        symmetric_chain_index = torch.empty((2, 0), dtype=torch.long)

    return {
        "connected_chain_index": connected_chain_index,
        "connected_atom_index": connected_atom_index,
        "symmetric_chain_index": symmetric_chain_index,
    }
|
| 2057 |
+
|
| 2058 |
+
|
| 2059 |
+
def process_contact_feature_constraints(
    data: Tokenized,
    inference_pocket_constraints: list[tuple[int, list[tuple[int, int]], float, bool]],
    inference_contact_constraints: list[tuple[tuple[int, int], tuple[int, int], float, bool]],
) -> dict[str, Tensor]:
    """Build atom-pair contact constraint features for inference.

    Each *forced* constraint contributes a group of candidate atom pairs;
    all pairs in a group share one ``union_idx``, i.e. the constraint is
    satisfied if ANY pair in the group is within its threshold distance.

    Parameters
    ----------
    data : Tokenized
        The input to the model.
    inference_pocket_constraints : list
        Tuples of (binder chain index, contact residues/atoms as
        (asym_id, res_idx-or-atom_idx) pairs, max distance, force flag).
        Note: the original annotation omitted the trailing bool; the loop
        below unpacks four items.
    inference_contact_constraints : list
        Tuples of (token1 key, token2 key, max distance, force flag),
        where each token key is (asym_id, res_idx) for polymers or
        (asym_id, atom_idx) for non-polymers.

    Returns
    -------
    dict[str, Tensor]
        Pair indices, union group indices, negation mask and thresholds.

    """
    token_data = data.tokens
    union_idx = 0
    pair_index, union_index, negation_mask, thresholds = [], [], [], []

    # Pocket constraints: one union group per constraint, pairing every
    # binder-chain atom with every atom of every matched contact token.
    for binder, contacts, max_distance, force in inference_pocket_constraints:
        if not force:
            continue

        binder_chain = data.structure.chains[binder]
        for token in token_data:
            # Polymer tokens are matched by residue index, non-polymer
            # (ligand) tokens by atom index.
            if (
                token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                and (token["asym_id"], token["res_idx"]) in contacts
            ) or (
                token["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                and (token["asym_id"], token["atom_idx"]) in contacts
            ):
                # All (binder atom, token atom) combinations.
                atom_idx_pairs = torch.cartesian_prod(
                    torch.arange(
                        binder_chain["atom_idx"],
                        binder_chain["atom_idx"] + binder_chain["atom_num"],
                    ),
                    torch.arange(
                        token["atom_idx"], token["atom_idx"] + token["atom_num"]
                    ),
                ).T
                pair_index.append(atom_idx_pairs)
                union_index.append(torch.full((atom_idx_pairs.shape[1],), union_idx))
                # negation_mask is all ones here: no pair is negated.
                negation_mask.append(
                    torch.ones((atom_idx_pairs.shape[1],), dtype=torch.bool)
                )
                thresholds.append(torch.full((atom_idx_pairs.shape[1],), max_distance))
        union_idx += 1

    # Pairwise contact constraints: find the first token matching each
    # key, pair all their atoms, then break out of both scans.
    for token1, token2, max_distance, force in inference_contact_constraints:
        if not force:
            continue

        for idx1, _token1 in enumerate(token_data):
            if (
                _token1["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                and (_token1["asym_id"], _token1["res_idx"]) == token1
            ) or (
                _token1["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                and (_token1["asym_id"], _token1["atom_idx"]) == token1
            ):
                for idx2, _token2 in enumerate(token_data):
                    if (
                        _token2["mol_type"] != const.chain_type_ids["NONPOLYMER"]
                        and (_token2["asym_id"], _token2["res_idx"]) == token2
                    ) or (
                        _token2["mol_type"] == const.chain_type_ids["NONPOLYMER"]
                        and (_token2["asym_id"], _token2["atom_idx"]) == token2
                    ):
                        atom_idx_pairs = torch.cartesian_prod(
                            torch.arange(
                                _token1["atom_idx"],
                                _token1["atom_idx"] + _token1["atom_num"],
                            ),
                            torch.arange(
                                _token2["atom_idx"],
                                _token2["atom_idx"] + _token2["atom_num"],
                            ),
                        ).T
                        pair_index.append(atom_idx_pairs)
                        union_index.append(
                            torch.full((atom_idx_pairs.shape[1],), union_idx)
                        )
                        negation_mask.append(
                            torch.ones((atom_idx_pairs.shape[1],), dtype=torch.bool)
                        )
                        thresholds.append(
                            torch.full((atom_idx_pairs.shape[1],), max_distance)
                        )
                        union_idx += 1
                        break
                break

    if len(pair_index) > 0:
        pair_index = torch.cat(pair_index, dim=1)
        union_index = torch.cat(union_index)
        negation_mask = torch.cat(negation_mask)
        thresholds = torch.cat(thresholds)
    else:
        # No forced constraints: empty tensors with matching dtypes.
        pair_index = torch.empty((2, 0), dtype=torch.long)
        union_index = torch.empty((0,), dtype=torch.long)
        negation_mask = torch.empty((0,), dtype=torch.bool)
        thresholds = torch.empty((0,), dtype=torch.float32)

    return {
        "contact_pair_index": pair_index,
        "contact_union_index": union_index,
        "contact_negation_mask": negation_mask,
        "contact_thresholds": thresholds,
    }
|
| 2158 |
+
|
| 2159 |
+
|
| 2160 |
+
class Boltz2Featurizer:
    """Boltz2 featurizer: turns a tokenized input into model feature tensors."""

    def process(
        self,
        data: Tokenized,
        random: np.random.Generator,
        molecules: dict[str, Mol],
        training: bool,
        max_seqs: int,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        num_ensembles: int = 1,
        ensemble_sample_replacement: bool = False,
        disto_use_ensemble: Optional[bool] = False,
        fix_single_ensemble: Optional[bool] = True,
        max_tokens: Optional[int] = None,
        max_atoms: Optional[int] = None,
        pad_to_max_seqs: bool = False,
        compute_symmetries: bool = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        contact_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff_min: Optional[float] = 4.0,
        binder_pocket_cutoff_max: Optional[float] = 20.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        only_ligand_binder_pocket: Optional[bool] = False,
        only_pp_contact: Optional[bool] = False,
        single_sequence_prop: Optional[float] = 0.0,
        msa_sampling: bool = False,
        override_bfactor: bool = False,
        override_method: Optional[str] = None,
        compute_frames: bool = False,
        override_coords: Optional[Tensor] = None,
        bfactor_md_correction: bool = False,
        compute_constraint_features: bool = False,
        inference_pocket_constraints: Optional[
            list[tuple[int, list[tuple[int, int]], float]]
        ] = None,
        inference_contact_constraints: Optional[
            list[tuple[tuple[int, int], tuple[int, int], float]]
        ] = None,
        compute_affinity: bool = False,
    ) -> dict[str, Tensor]:
        """Compute all input features for one example.

        Aggregates token, atom, MSA, template, symmetry, ensemble and
        constraint features into one flat dictionary.

        Parameters
        ----------
        data : Tokenized
            The input to the model.
        random : np.random.Generator
            Random generator driving every sampling decision below.
        molecules : dict[str, Mol]
            Molecules referenced by the input, keyed by name.
        training : bool
            Whether the model is in training mode (enables MSA
            subsampling and single-sequence dropout).
        max_seqs : int, optional
            The maximum number of sequences.
        max_tokens : int, optional
            The maximum number of tokens.
        max_atoms : int, optional
            The maximum number of atoms.
        compute_affinity : bool
            When True, adds single-sequence affinity MSA features and
            the affinity ligand molecular weight, and disables templates.

        Returns
        -------
        dict[str, Tensor]
            The features for model training.

        """
        # Compute random number of sequences: during training, with
        # probability single_sequence_prop the MSA collapses to a single
        # sequence; otherwise a uniform size in [1, max_seqs] is drawn.
        if training and max_seqs is not None:
            if random.random() > single_sequence_prop:
                max_seqs_batch = random.integers(1, max_seqs + 1)
            else:
                max_seqs_batch = 1
        else:
            max_seqs_batch = max_seqs

        # Compute ensemble features (which conformers are used).
        ensemble_features = process_ensemble_features(
            data=data,
            random=random,
            num_ensembles=num_ensembles,
            ensemble_sample_replacement=ensemble_sample_replacement,
            fix_single_ensemble=fix_single_ensemble,
        )

        # Compute token features
        token_features = process_token_features(
            data=data,
            random=random,
            max_tokens=max_tokens,
            binder_pocket_conditioned_prop=binder_pocket_conditioned_prop,
            contact_conditioned_prop=contact_conditioned_prop,
            binder_pocket_cutoff_min=binder_pocket_cutoff_min,
            binder_pocket_cutoff_max=binder_pocket_cutoff_max,
            binder_pocket_sampling_geometric_p=binder_pocket_sampling_geometric_p,
            only_ligand_binder_pocket=only_ligand_binder_pocket,
            only_pp_contact=only_pp_contact,
            override_method=override_method,
            inference_pocket_constraints=inference_pocket_constraints,
            inference_contact_constraints=inference_contact_constraints,
        )

        # Compute atom features
        atom_features = process_atom_features(
            data=data,
            random=random,
            molecules=molecules,
            ensemble_features=ensemble_features,
            atoms_per_window_queries=atoms_per_window_queries,
            min_dist=min_dist,
            max_dist=max_dist,
            num_bins=num_bins,
            max_atoms=max_atoms,
            max_tokens=max_tokens,
            disto_use_ensemble=disto_use_ensemble,
            override_bfactor=override_bfactor,
            compute_frames=compute_frames,
            override_coords=override_coords,
            bfactor_md_correction=bfactor_md_correction,
        )

        # Compute MSA features
        msa_features = process_msa_features(
            data=data,
            random=random,
            max_seqs_batch=max_seqs_batch,
            max_seqs=max_seqs,
            max_tokens=max_tokens,
            pad_to_max_seqs=pad_to_max_seqs,
            msa_sampling=training and msa_sampling,
        )

        # Compute single-sequence MSA features for the affinity head.
        msa_features_affinity = {}
        if compute_affinity:
            msa_features_affinity = process_msa_features(
                data=data,
                random=random,
                max_seqs_batch=1,
                max_seqs=1,
                max_tokens=max_tokens,
                pad_to_max_seqs=pad_to_max_seqs,
                msa_sampling=training and msa_sampling,
                affinity=True,
            )

        # Compute affinity ligand Molecular Weight
        ligand_to_mw = {}
        if compute_affinity:
            ligand_to_mw["affinity_mw"] = data.record.affinity.mw

        # Compute template features; templates are skipped entirely in
        # affinity mode and replaced with a single dummy template.
        num_tokens = data.tokens.shape[0] if max_tokens is None else max_tokens
        if data.templates and not compute_affinity:
            template_features = process_template_features(
                data=data,
                max_tokens=num_tokens,
            )
        else:
            template_features = load_dummy_templates_features(
                tdim=1,
                num_tokens=num_tokens,
            )

        # Compute symmetry features
        # NOTE(review): get_symmetries is called here with the molecules
        # dict — confirm this matches its signature (a file-path variant
        # with the same name exists in data.feature.symmetry).
        symmetry_features = {}
        if compute_symmetries:
            symmetries = get_symmetries(molecules)
            symmetry_features = process_symmetry_features(data, symmetries)

        # Compute constraint features (residue-, chain- and contact-level).
        residue_constraint_features = {}
        chain_constraint_features = {}
        contact_constraint_features = {}
        if compute_constraint_features:
            residue_constraint_features = process_residue_constraint_features(data)
            chain_constraint_features = process_chain_feature_constraints(data)
            contact_constraint_features = process_contact_feature_constraints(
                data=data,
                inference_pocket_constraints=inference_pocket_constraints if inference_pocket_constraints else [],
                inference_contact_constraints=inference_contact_constraints if inference_contact_constraints else [],
            )

        # Merge all feature groups into one flat dict.
        return {
            **token_features,
            **atom_features,
            **msa_features,
            **msa_features_affinity,
            **template_features,
            **symmetry_features,
            **ensemble_features,
            **residue_constraint_features,
            **chain_constraint_features,
            **contact_constraint_features,
            **ligand_to_mw,
        }
|
protify/FastPLMs/boltz/src/boltz/data/feature/symmetry.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import pickle
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from boltz.data import const
|
| 10 |
+
from boltz.data.pad import pad_dim
|
| 11 |
+
from boltz.model.loss.confidence import lddt_dist
|
| 12 |
+
from boltz.model.loss.validation import weighted_minimum_rmsd_single
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def convert_atom_name(name: str) -> tuple[int, int, int, int]:
    """Convert an atom name to a standard format.

    Strips surrounding whitespace, maps each character to its ASCII code
    minus 32 (so a space would be 0), and zero-pads to four entries.

    Parameters
    ----------
    name : str
        The atom name.

    Returns
    -------
    Tuple[int, int, int, int]
        The converted atom name.

    """
    trimmed = name.strip()
    codes = [ord(ch) - 32 for ch in trimmed]
    codes.extend([0] * (4 - len(codes)))
    return tuple(codes)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_symmetries(path: str) -> dict:
    """Create a dictionary for the ligand symmetries.

    Loads a pickled mapping of molecules, and for each molecule that
    carries a hex-encoded "symmetries" property, records the decoded
    symmetry permutations together with its converted atom names.
    Molecules without a usable property are silently skipped.

    Parameters
    ----------
    path : str
        The path to the ligand symmetries.

    Returns
    -------
    dict
        The ligand symmetries.

    """
    with Path(path).open("rb") as handle:
        molecules: dict = pickle.load(handle)  # noqa: S301

    symmetries = {}
    for key, mol in molecules.items():
        try:
            raw = bytes.fromhex(mol.GetProp("symmetries"))
            sym = pickle.loads(raw)  # noqa: S301
            atom_names = [
                convert_atom_name(atom.GetProp("name")) for atom in mol.GetAtoms()
            ]
        except Exception:  # noqa: BLE001, PERF203, S110
            # Best-effort: skip molecules with missing/corrupt properties.
            continue
        symmetries[key] = (sym, atom_names)

    return symmetries
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_symmetry_idx_dictionary(data):
    """Assign start offsets to chains/tokens and gather all atom coordinates.

    Side effect: sets ``chain.start_idx`` (global atom offset of the chain)
    and ``token.start_idx`` (atom offset relative to its chain) on the
    input objects. Returns the flat list of [x, y, z] coordinates.
    """
    running = 0
    coords = []
    for chain in data.chains:
        chain.start_idx = running
        for token in chain.tokens:
            token.start_idx = running - chain.start_idx
            for atom in token.atoms:
                coords.append([atom.coords.x, atom.coords.y, atom.coords.z])
            running += len(token.atoms)
    return coords
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_current_idx_list(data):
    """Return global atom indices for atoms whose chain AND token are in the crop.

    Relies on ``start_idx`` attributes previously set on chains and tokens
    (see compute_symmetry_idx_dictionary).
    """
    return [
        chain.start_idx + token.start_idx + offset
        for chain in data.chains
        if chain.in_crop
        for token in chain.tokens
        if token.in_crop
        for offset in range(len(token.atoms))
    ]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def all_different_after_swap(l):
    """Return True when the last elements of the given sequences are pairwise distinct."""
    seen = set()
    for seq in l:
        last = seq[-1]
        if last in seen:
            return False
        seen.add(last)
    return True
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def minimum_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Resolve structural symmetries to minimize RMSD against predictions.

    For one batch element, searches over (1) symmetric chain swaps,
    (2) amino-acid / nucleic-acid side-chain atom symmetries (greedy, the
    alignment is NOT recomputed for efficiency) and (3) ligand /
    non-standard residue atom symmetries (greedy, alignment recomputed),
    keeping the ground-truth coordinate assignment with the lowest
    weighted RMSD to the predicted ``coords``.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted atom coordinates.
        # assumes a leading batch dim of size 1 — TODO confirm against callers
    feats : dict
        Must provide ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``,
        ``atom_to_token`` and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    tuple
        ``(best_true_coords, best_rmsd, best_true_resolved_mask)`` with the
        mask given a leading batch dimension.
    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Check best symmetry on chain swap
    best_true_coords = None
    best_rmsd = float("inf")
    best_align_weights = None
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        # Pad to the predicted tensor's atom dimension.
        true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
        true_resolved_mask = pad_dim(
            true_resolved_mask,
            0,
            coords.shape[1] - true_resolved_mask.shape[0],
        )
        try:
            rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
                coords,
                true_coords,
                atom_mask=true_resolved_mask,
                atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
                mol_type=feats["mol_type"][index_batch : index_batch + 1],
                **args_rmsd,
            )
        except Exception:  # was a bare `except:`; narrowed so Ctrl-C still works
            print("Warning: error in rmsd computation inside symmetry code")
            continue
        rmsd = rmsd.item()

        if rmsd < best_rmsd:
            best_rmsd = rmsd
            best_true_coords = aligned_coords
            best_align_weights = align_weights
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino in amino_acids_symmetries:
        for c in symmetric_amino:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            # compute squared distance, for efficiency we do not recompute the alignment
            best_mse_loss = torch.sum(
                ((coords - best_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * best_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * best_true_resolved_mask, dim=-1)
            new_mse_loss = torch.sum(
                ((coords - new_true_coords) ** 2).sum(dim=-1)
                * best_align_weights
                * new_true_resolved_mask,
                dim=-1,
            ) / torch.sum(best_align_weights * new_true_resolved_mask, dim=-1)

            if best_mse_loss > new_mse_loss:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

        # greedily update best coordinates after each amino acid
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recomputing alignment
    rmsd, true_coords, best_align_weights = weighted_minimum_rmsd_single(
        coords,
        true_coords,
        atom_mask=true_resolved_mask,
        atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
        mol_type=feats["mol_type"][index_batch : index_batch + 1],
        **args_rmsd,
    )
    best_rmsd = rmsd.item()

    # atom symmetries (ligand and non-standard), resolved greedily recomputing alignment
    for symmetric_ligand in ligand_symmetries:
        for c in symmetric_ligand:
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, j] = true_coords[:, i]
                new_true_resolved_mask[j] = true_resolved_mask[i]
            try:
                # TODO if this is too slow maybe we can get away with not recomputing alignment
                rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
                    coords,
                    new_true_coords,
                    atom_mask=new_true_resolved_mask,
                    atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
                    mol_type=feats["mol_type"][index_batch : index_batch + 1],
                    **args_rmsd,
                )
            except Exception:
                # Errors here propagate (the original re-raised before an
                # unreachable print/continue; the dead lines were removed).
                raise
            rmsd = rmsd.item()
            if rmsd < best_rmsd:
                best_true_coords = aligned_coords
                best_rmsd = rmsd
                best_true_resolved_mask = new_true_resolved_mask

        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    return best_true_coords, best_rmsd, best_true_resolved_mask.unsqueeze(0)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
    **args_rmsd,
):
    """Resolve structural symmetries by maximizing lDDT against predictions.

    For one batch element, first picks the best symmetric chain-swap
    assignment by global lDDT, then greedily resolves residue- and
    ligand-level atom symmetries using subset lDDT scores, and finally
    recomputes a weighted RMSD alignment on the winning assignment.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted atom coordinates.
        # assumes a leading batch dim of size 1 — TODO confirm against callers
    feats : dict
        Must provide ``all_coords``, ``all_resolved_mask``,
        ``crop_to_all_atom_map``, ``chain_symmetries``,
        ``amino_acids_symmetries``, ``ligand_symmetries``,
        ``atom_to_token`` and ``mol_type``.
    index_batch : int
        Index of the batch element to process.
    **args_rmsd
        Forwarded to ``weighted_minimum_rmsd_single``.

    Returns
    -------
    tuple
        ``(true_coords, best_rmsd, true_resolved_mask)`` — aligned true
        coordinates, final RMSD (1000 when the final alignment fails), and
        the resolved mask with a leading batch dimension.
    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_symmetries"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Predicted pairwise distances over the cropped atoms only.
    dmat_predicted = torch.cdist(
        coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
    )

    # Check best symmetry on chain swap
    best_true_coords = None
    best_lddt = 0
    for c in chain_symmetries:
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Pair mask: both atoms resolved, diagonal (self-pairs) excluded.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        if lddt > best_lddt:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            indices = []
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]
                indices.append(i)

            indices = (
                torch.from_numpy(np.asarray(indices)).to(new_true_coords.device).long()
            )

            # Only the columns touched by the swap need rescoring.
            pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            sub_dmat_pred = torch.cdist(
                coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
            )
            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            # Zero out self-pairs among the swapped atoms.
            sub_true_pair_lddt[indices] = (
                sub_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]
            new_lddt = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )[0]

            if new_lddt > lddt:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask

            # greedily update best coordinates after each amino acid
            true_coords = best_true_coords.clone()
            true_resolved_mask = best_true_resolved_mask.clone()

    # Recomputing alignment
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    try:
        rmsd, true_coords, _ = weighted_minimum_rmsd_single(
            coords,
            true_coords,
            atom_mask=true_resolved_mask,
            atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
            mol_type=feats["mol_type"][index_batch : index_batch + 1],
            **args_rmsd,
        )
        best_rmsd = rmsd.item()
    except Exception as e:
        # NOTE(review): message says "returning inf" but the sentinel is 1000.
        print("Failed proper RMSD computation, returning inf. Error: ", e)
        best_rmsd = 1000

    return true_coords, best_rmsd, true_resolved_mask.unsqueeze(0)
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def compute_all_coords_mask(structure):
    """Gather per-atom coordinates plus crop/resolved masks for a structure.

    Also annotates ``structure`` in place: each chain gets ``start_idx``
    (global atom offset) and each token gets ``start_idx`` (atom offset
    relative to its chain).

    Parameters
    ----------
    structure
        Object with ``chains``; each chain has ``tokens``; each token has
        ``atoms`` (each with ``coords.x/y/z``), ``in_crop`` and
        ``is_present`` flags.

    Returns
    -------
    tuple[list, list, list]
        Per-atom ``[x, y, z]`` coordinates, per-atom crop mask, and
        per-atom resolved mask. The three lists have equal length by
        construction (the original dead `if ... pass` length check was
        removed).
    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            all_coords.extend(
                [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
            )
            # Crop / resolved flags are token-level; broadcast them to every atom.
            all_coords_crop_mask.extend([token.in_crop] * len(token.atoms))
            all_resolved_mask.extend([token.is_present] * len(token.atoms))
            total_count += len(token.atoms)
    return all_coords, all_coords_crop_mask, all_resolved_mask
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Build chain-swap symmetry features for a cropped structure.

    Collects all atom coordinates and resolved masks chain by chain,
    computes a crop-to-all-atom index map, and enumerates combinations of
    chain swaps between chains sharing an ``entity_id`` and atom count
    (bounded and randomly subsampled to ``max_n_symmetries``).

    Parameters
    ----------
    cropped
        Cropped structure wrapper exposing ``structure`` and ``tokens``.
    max_n_symmetries : int
        Maximum number of swap combinations to keep.

    Returns
    -------
    dict
        ``all_coords``, ``all_resolved_mask``, ``crop_to_all_atom_map``
        tensors and the ``chain_symmetries`` combination list.
    """
    # get all coordinates and resolved mask
    structure = cropped.structure
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx, atom_num = (
            chain["atom_idx"],
            chain["atom_num"],
        )

        # compute coordinates and resolved mask
        resolved_mask = structure.atoms["is_present"][atom_idx : atom_idx + atom_num]

        # ensemble_atom_starts = [structure.ensemble[idx]["atom_coord_idx"] for idx in cropped.ensemble_ref_idxs]
        # coords = np.array(
        #     [structure.coords[ensemble_atom_start + atom_idx: ensemble_atom_start + atom_idx + atom_num]["coords"] for
        #     ensemble_atom_start in ensemble_atom_starts])

        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        # A chain is "in crop" when any cropped token belongs to it.
        in_crop = False
        for token in cropped.tokens:
            if token["asym_id"] == chain["asym_id"]:
                in_crop = True
                break

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(in_crop)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    # Compute backmapping from token to all coords
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        # Shift the token's structure-level atom index into the packed
        # coordinate array built above.
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))

    # Compute the symmetries between chains
    swaps = []
    for i, chain in enumerate(structure.chains):
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]
        if chain_in_crop[i]:
            possible_swaps = []
            for j, chain2 in enumerate(structure.chains):
                start2 = chain_atom_idx[j]
                end2 = start2 + chain_atom_num[j]
                # Chains are swappable when they share an entity and have
                # the same number of atoms.
                if (
                    chain["entity_id"] == chain2["entity_id"]
                    and end - start == end2 - start2
                ):
                    possible_swaps.append((start, end, start2, end2, i, j))
            swaps.append(possible_swaps)
    combinations = itertools.product(*swaps)
    # to avoid combinatorial explosion, bound the number of combinations even considered
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # filter for all chains getting a different assignment
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    # Always provide at least the identity (empty) assignment.
    if len(combinations) == 0:
        combinations.append([])

    features = {}
    features["all_coords"] = torch.Tensor(
        np.concatenate(all_coords, axis=0)
    )  # axis=1 with ensemble

    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(
        np.concatenate(crop_to_all_atom_map, axis=0)
    )
    features["chain_symmetries"] = combinations

    return features
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def get_amino_acids_symmetries(cropped):
    """Collect atom-swap symmetries for standard amino-acid residues.

    Walks the cropped tokens in order, shifting each residue's reference
    symmetries (index pairs from ``const.ref_symmetries``) by the token's
    atom offset within the crop.

    Returns a dict with key ``"amino_acids_symmetries"``: a list (one entry
    per token that has symmetries) of lists of ``(i, j)`` index-pair swaps.
    """
    collected = []
    offset = 0
    for token in cropped.tokens:
        res_name = const.tokens[token["res_type"]]
        reference_syms = const.ref_symmetries.get(res_name, [])
        if len(reference_syms) > 0:
            shifted = [
                [(a + offset, b + offset) for a, b in sym]
                for sym in reference_syms
            ]
            collected.append(shifted)
        offset += token["atom_num"]

    return {"amino_acids_symmetries": collected}
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def get_ligand_symmetries(cropped, symmetries):
    """Build atom-swap symmetries for ligands and non-standard residues.

    Parameters
    ----------
    cropped
        Cropped structure wrapper exposing ``structure`` and ``tokens``.
    symmetries : dict
        CCD symmetry table mapping molecule name to
        ``(symmetry_permutations, atom_names)``.

    Returns
    -------
    dict
        ``{"ligand_symmetries": ...}`` — per-molecule lists of ``(i, j)``
        index-pair swaps in crop coordinates.

    Raises
    ------
    Exception
        If a mapped symmetry's length disagrees with the molecule's atom
        count in the crop.
    """
    # Compute ligand and non-standard amino-acids symmetries
    structure = cropped.structure

    added_molecules = {}
    index_mols = []
    atom_count = 0
    for token in cropped.tokens:
        # check if molecule is already added by identifying it through asym_id and res_idx
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules.keys():
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # get the molecule type and indices
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        mol_name = structure.residues[residue_idx]["name"]
        atom_idx = structure.residues[residue_idx]["atom_idx"]
        mol_atom_names = structure.atoms[
            atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
        ]["name"]
        mol_atom_names = [tuple(m) for m in mol_atom_names]
        # Standard residues are handled by get_amino_acids_symmetries instead.
        if mol_name not in const.ref_symmetries.keys():
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    # for each molecule, get the symmetries
    molecule_symmetries = []
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if not mol_name in symmetries:
            continue
        else:
            swaps = []
            syms_ccd, mol_atom_names_ccd = symmetries[mol_name]
            # Get indices of mol_atom_names_ccd that are in mol_atom_names
            ccd_to_valid_ids = {
                mol_atom_names_ccd.index(name): i
                for i, name in enumerate(mol_atom_names)
            }
            ccd_valid_ids = set(ccd_to_valid_ids.keys())

            syms = []
            # Get syms
            for sym_ccd in syms_ccd:
                sym_dict = {}
                bool_add = True
                # Keep a CCD permutation only if it maps present atoms to
                # present atoms; otherwise discard it entirely.
                for i, j in enumerate(sym_ccd):
                    if i in ccd_valid_ids:
                        if j in ccd_valid_ids:
                            i_true = ccd_to_valid_ids[i]
                            j_true = ccd_to_valid_ids[j]
                            sym_dict[i_true] = j_true
                        else:
                            bool_add = False
                            break
                if bool_add:
                    syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])

            for sym in syms:
                if len(sym) != added_molecules[mol_id]:
                    raise Exception(
                        f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                    )
                # assert (
                #     len(sym) == added_molecules[mol_id]
                # ), f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                # Record only non-identity mappings, shifted to crop indices.
                sym_new_idx = []
                for i, j in enumerate(sym):
                    if i != int(j):
                        sym_new_idx.append((i + start_mol, int(j) + start_mol))
                if len(sym_new_idx) > 0:
                    swaps.append(sym_new_idx)
            if len(swaps) > 0:
                molecule_symmetries.append(swaps)

    features = {"ligand_symmetries": molecule_symmetries}

    return features
|
protify/FastPLMs/boltz/src/boltz/data/filter/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/date.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from boltz.data.types import Record
|
| 5 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DateFilter(DynamicFilter):
    """A filter that filters complexes based on their date.

    The date can be the deposition, release, or revision date.
    If the date is not available, the previous date is used.

    If no date is available, the complex is rejected.

    """

    def __init__(
        self,
        date: str,
        ref: Literal["deposited", "revised", "released"],
    ) -> None:
        """Initialize the filter.

        Parameters
        ----------
        date : str
            The maximum (ISO-format) date of PDB entries to keep.
        ref : Literal["deposited", "revised", "released"]
            The reference date to use.

        Raises
        ------
        ValueError
            If ``ref`` is not a supported reference date.

        """
        self.filter_date = datetime.fromisoformat(date)
        self.ref = ref

        if ref not in ["deposited", "revised", "released"]:
            # Bug fix: trailing commas previously made `msg` a tuple of two
            # strings, producing an unreadable ValueError payload.
            msg = (
                "Invalid reference date. Must be "
                "deposited, revised, or released"
            )
            raise ValueError(msg)

    def filter(self, record: Record) -> bool:
        """Filter a record based on its date.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            Whether the record should be kept.

        """
        structure = record.structure

        if self.ref == "deposited":
            date = structure.deposited
        elif self.ref == "released":
            date = structure.released
            # Fall back to the deposition date when no release date exists.
            if not date:
                date = structure.deposited
        elif self.ref == "revised":
            date = structure.revised
            # Fall back to release, then deposition, when unavailable.
            if not date and structure.released:
                date = structure.released
            elif not date:
                date = structure.deposited

        # No usable date at all: reject the complex.
        if date is None or date == "":
            return False

        date = datetime.fromisoformat(date)
        return date <= self.filter_date
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/filter.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
from boltz.data.types import Record
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DynamicFilter(ABC):
    """Abstract interface for record-level data filters."""

    @abstractmethod
    def filter(self, record: Record) -> bool:
        """Decide whether a record should be kept.

        Parameters
        ----------
        record : Record
            The record under consideration.

        Returns
        -------
        bool
            True to keep the record, False to drop it.

        """
        raise NotImplementedError
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/max_residues.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from boltz.data.types import Record
|
| 2 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MaxResiduesFilter(DynamicFilter):
    """A filter that filters structures based on their residue count."""

    def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_residues : int
            The minimum number of residues allowed.
        max_residues : int
            The maximum number of residues allowed.

        """
        self.min_residues = min_residues
        self.max_residues = max_residues

    def filter(self, record: Record) -> bool:
        """Filter structures based on their total residue count.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            Whether the record should be kept.

        """
        # Sum residues over all chains; both bounds are inclusive.
        num_residues = sum(chain.num_residues for chain in record.chains)
        return num_residues <= self.max_residues and num_residues >= self.min_residues
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/resolution.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from boltz.data.types import Record
|
| 2 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ResolutionFilter(DynamicFilter):
    """Keep complexes whose experimental resolution is good enough."""

    def __init__(self, resolution: float = 9.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        resolution : float, optional
            The maximum allowed resolution (inclusive).

        """
        self.resolution = resolution

    def filter(self, record: Record) -> bool:
        """Decide whether a complex passes the resolution cutoff.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            True when the structure's resolution is within the cutoff.

        """
        return record.structure.resolution <= self.resolution
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/size.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from boltz.data.types import Record
|
| 2 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class SizeFilter(DynamicFilter):
    """A filter that filters structures based on their size."""

    def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_chains : int
            The minimum number of chains allowed.
        max_chains : int
            The maximum number of chains allowed.

        """
        self.min_chains = min_chains
        self.max_chains = max_chains

    def filter(self, record: Record) -> bool:
        """Filter structures based on their chain counts.

        Note the asymmetry: the upper bound applies to all chains in the
        structure, while the lower bound applies only to valid chains.

        Parameters
        ----------
        record : Record
            The record to filter.

        Returns
        -------
        bool
            Whether the record should be kept.

        """
        num_chains = record.structure.num_chains
        num_valid = sum(1 for chain in record.chains if chain.valid)
        return num_chains <= self.max_chains and num_valid >= self.min_chains
|
protify/FastPLMs/boltz/src/boltz/data/filter/dynamic/subset.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
from boltz.data.types import Record
|
| 4 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SubsetFilter(DynamicFilter):
    """Keep (or exclude) records whose id appears in a fixed subset file."""

    def __init__(self, subset: str, reverse: bool = False) -> None:
        """Initialize the filter.

        Parameters
        ----------
        subset : str
            Path to a file listing one record id per line.
        reverse : bool
            When True, keep records NOT in the subset instead.

        """
        entries = Path(subset).read_text().splitlines()
        # Case-insensitive membership: store lowercase ids.
        self.subset = {entry.lower() for entry in entries}
        self.reverse = reverse

    def filter(self, record: Record) -> bool:
        """Decide whether a record passes the subset membership test.

        Parameters
        ----------
        record : Record
            The record under consideration.

        Returns
        -------
        bool
            True if the record passes the filter, False otherwise.

        """
        member = record.id.lower() in self.subset
        return not member if self.reverse else member
|
protify/FastPLMs/boltz/src/boltz/data/filter/static/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/filter/static/filter.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from boltz.data.types import Structure
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class StaticFilter(ABC):
    """Abstract interface for chain-level structure filters."""

    @abstractmethod
    def filter(self, structure: Structure) -> np.ndarray:
        """Select which chains of a structure to keep.

        Parameters
        ----------
        structure : Structure
            The structure whose chains are being filtered.

        Returns
        -------
        np.ndarray
            Boolean mask over chains; True marks a chain to keep.

        """
        raise NotImplementedError
|
protify/FastPLMs/boltz/src/boltz/data/filter/static/ligand.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from boltz.data import const
|
| 4 |
+
from boltz.data.filter.static.filter import StaticFilter
|
| 5 |
+
from boltz.data.types import Structure
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ExcludedLigands(StaticFilter):
    """Filter excluded ligands."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter excluded ligands.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)
        nonpolymer = const.chain_type_ids["NONPOLYMER"]

        for idx, chain in enumerate(structure.chains):
            # Only non-polymer (ligand) chains are candidates for exclusion.
            if chain["mol_type"] != nonpolymer:
                continue

            # Gather the residues belonging to this chain.
            start = chain["res_idx"]
            stop = start + chain["res_num"]
            chain_residues = structure.residues[start:stop]

            # Drop the chain if any residue is on the exclusion list.
            excluded = any(
                residue["name"] in const.ligand_exclusion
                for residue in chain_residues
            )
            if excluded:
                keep[idx] = 0

        return keep
|
protify/FastPLMs/boltz/src/boltz/data/filter/static/polymer.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.neighbors import KDTree
|
| 6 |
+
|
| 7 |
+
from boltz.data import const
|
| 8 |
+
from boltz.data.filter.static.filter import StaticFilter
|
| 9 |
+
from boltz.data.types import Structure
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MinimumLengthFilter(StaticFilter):
    """Filter polymers based on their length.

    Both bounds are compared against the number of resolved residues
    in the chain.

    NOTE(review): the original docstring stated that the maximum bound
    uses the full sequence length, but the implementation compares the
    resolved-residue count against both bounds — confirm which is
    intended.
    """

    def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
        """Initialize the filter.

        Parameters
        ----------
        min_len : int, optional
            The minimum allowed length.
        max_len : int, optional
            The maximum allowed length.

        """
        self._min = min_len
        self._max = max_len

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter chains based on their length.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)
        nonpolymer = const.chain_type_ids["NONPOLYMER"]

        for idx, chain in enumerate(structure.chains):
            # Length bounds only apply to polymer chains.
            if chain["mol_type"] == nonpolymer:
                continue

            # Count resolved residues in this chain.
            start = chain["res_idx"]
            stop = start + chain["res_num"]
            num_resolved = structure.residues[start:stop]["is_present"].sum()

            if not (self._min <= num_resolved <= self._max):
                keep[idx] = 0

        return keep
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class UnknownFilter(StaticFilter):
    """Filter proteins with all unknown residues."""

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter out polymer chains made entirely of unknown residues.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        keep = np.ones(len(structure.chains), dtype=bool)

        # Map each polymer chain type to its unknown-residue token id.
        unk_by_type = {
            const.chain_type_ids[name]: const.unk_token_ids[name]
            for name in ("PROTEIN", "DNA", "RNA")
        }

        for idx, chain in enumerate(structure.chains):
            # Non-polymer chains have no unknown token and are skipped.
            if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
                continue

            start = chain["res_idx"]
            stop = start + chain["res_num"]
            res_types = structure.residues[start:stop]["res_type"]

            # Drop the chain if every residue is the unknown token.
            if np.all(res_types == unk_by_type[chain["mol_type"]]):
                keep[idx] = 0

        return keep
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ConsecutiveCA(StaticFilter):
    """Filter protein chains whose consecutive CA atoms are too far apart."""

    def __init__(self, max_dist: float = 10.0) -> None:
        """Initialize the filter.

        Parameters
        ----------
        max_dist : float, optional
            The maximum allowed distance between consecutive CA atoms.

        """
        # Fixed: the parameter was annotated `int` but is a distance with a
        # float default (the docstring already documented it as float).
        self._max_dist = max_dist

    def filter(self, structure: Structure) -> np.ndarray:
        """Filter protein chains with consecutive CA atoms above a threshold.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        valid = np.ones(len(structure.chains), dtype=bool)

        # Remove chain if consecutive CA atoms are above threshold
        for i, chain in enumerate(structure.chains):
            # Skip non-protein chains: only proteins have CA atoms.
            if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
                continue

            # Get residues
            res_start = chain["res_idx"]
            res_end = res_start + chain["res_num"]
            residues = structure.residues[res_start:res_end]

            # Get the c-alpha atom of each residue.
            ca_ids = residues["atom_center"]
            ca_atoms = structure.atoms[ca_ids]

            # A CA position is trusted only when both the residue and
            # the atom itself are resolved.
            res_valid = residues["is_present"]
            ca_valid = ca_atoms["is_present"] & res_valid
            ca_coords = ca_atoms["coords"]

            # Compute distances between consecutive atoms, then keep
            # only the pairs where both endpoints are resolved.
            dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)
            dist = dist > self._max_dist
            dist = dist[ca_valid[1:] & ca_valid[:-1]]

            # Remove the chain if any valid pair is above threshold
            if np.any(dist):
                valid[i] = 0

        return valid
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@dataclass(frozen=True)
class Clash:
    """A clash between two chains.

    Records how many (present) atoms of ``chain`` clash with ``other``.
    """

    # Index of the clashing chain.
    chain: int
    # Index of the chain it clashes with.
    other: int
    # Number of present atoms in `chain`.
    num_atoms: int
    # Number of atoms in `chain` that clash with `other`.
    num_clashes: int
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class ClashingChainsFilter(StaticFilter):
    """A filter that filters clashing chains.

    Clashing chains are defined as those with >30% of atoms
    within 1.7 Å of an atom in another chain. If two chains
    are clashing with each other, the chain with the greater
    percentage of clashing atoms will be removed. If the same
    fraction of atoms are clashing, the chain with fewer total
    atoms is removed. If the chains have the same number of
    atoms, then the chain with the larger chain id is removed.

    """

    def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
        """Initialize the filter.

        Parameters
        ----------
        dist : float, optional
            The maximum distance for a clash.
        freq : float, optional
            The maximum allowed frequency of clashes.

        """
        self._dist = dist
        self._freq = freq

    def filter(self, structure: Structure) -> np.ndarray:  # noqa: PLR0912, C901
        """Filter out clashing chains.

        Parameters
        ----------
        structure : Structure
            The structure to filter chains from.

        Returns
        -------
        np.ndarray
            The chains to keep, as a boolean mask.

        """
        # With fewer than two chains no clash is possible.
        num_chains = len(structure.chains)
        if num_chains < 2:  # noqa: PLR2004
            return np.ones(num_chains, dtype=bool)

        # Get unique chain pairs
        pairs = itertools.combinations(range(num_chains), 2)

        # Compute clashes for each unordered pair of chains.
        clashes: list[Clash] = []
        for i, j in pairs:
            # Get the chains
            c1 = structure.chains[i]
            c2 = structure.chains[j]

            # Get the atoms from each chain
            c1_start = c1["atom_idx"]
            c2_start = c2["atom_idx"]
            c1_end = c1_start + c1["atom_num"]
            c2_end = c2_start + c2["atom_num"]

            # Only resolved (present) atoms participate in clash counts.
            atoms1 = structure.atoms[c1_start:c1_end]
            atoms2 = structure.atoms[c2_start:c2_end]
            atoms1 = atoms1[atoms1["is_present"]]
            atoms2 = atoms2[atoms2["is_present"]]

            # Skip if either chain has no atoms
            if len(atoms1) == 0 or len(atoms2) == 0:
                continue

            # Count clashing atoms via a radius query on a KD-tree built
            # over chain 1's coordinates: query[k] lists the chain-1 atoms
            # within self._dist of chain-2 atom k.
            tree = KDTree(atoms1["coords"], metric="euclidean")
            query = tree.query_radius(atoms2["coords"], self._dist)

            # Chain-2 atoms with at least one close chain-1 neighbor, and
            # the set of distinct chain-1 atoms appearing in any neighborhood.
            c2_clashes = sum(len(neighbors) > 0 for neighbors in query)
            c1_clashes = len(set(itertools.chain.from_iterable(query)))

            # Record a (directed) clash when the clashing fraction of a
            # chain exceeds the frequency threshold.
            if (c1_clashes / len(atoms1)) > self._freq:
                clashes.append(Clash(i, j, len(atoms1), c1_clashes))
            if (c2_clashes / len(atoms2)) > self._freq:
                clashes.append(Clash(j, i, len(atoms2), c2_clashes))

        # Index clashes by (chain, other) so we can find the reciprocal
        # clash of a pair in O(1).
        removed = set()
        ids_to_clash = {(c.chain, c.other): c for c in clashes}

        # Filter out chains according to ruleset. NOTE: the outcome depends
        # on the iteration order of `clashes` (earlier removals suppress
        # later checks), so this loop's order must not be changed.
        for clash in clashes:
            # If either is already removed, skip
            if clash.chain in removed or clash.other in removed:
                continue

            # Check if the two chains clash with each other
            other_clash = ids_to_clash.get((clash.other, clash.chain))
            if other_clash is not None:
                # Remove the chain with the greater clashing fraction.
                clash1_freq = clash.num_clashes / clash.num_atoms
                clash2_freq = other_clash.num_clashes / other_clash.num_atoms
                if clash1_freq > clash2_freq:
                    removed.add(clash.chain)
                elif clash1_freq < clash2_freq:
                    removed.add(clash.other)

                # If same, remove the chain with fewer atoms
                elif clash.num_atoms < other_clash.num_atoms:
                    removed.add(clash.chain)
                elif clash.num_atoms > other_clash.num_atoms:
                    removed.add(clash.other)

                # If same, remove the chain with the larger chain id
                else:
                    removed.add(max(clash.chain, clash.other))

            # Otherwise, just remove the chain directly
            else:
                removed.add(clash.chain)

        # Build the keep-mask from the removed set.
        valid = np.ones(len(structure.chains), dtype=bool)
        for i in removed:
            valid[i] = 0

        return valid
|
protify/FastPLMs/boltz/src/boltz/data/module/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/module/inference.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pytorch_lightning as pl
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
|
| 10 |
+
from boltz.data import const
|
| 11 |
+
from boltz.data.feature.featurizer import BoltzFeaturizer
|
| 12 |
+
from boltz.data.pad import pad_to_max
|
| 13 |
+
from boltz.data.tokenize.boltz import BoltzTokenizer
|
| 14 |
+
from boltz.data.types import (
|
| 15 |
+
MSA,
|
| 16 |
+
Connection,
|
| 17 |
+
Input,
|
| 18 |
+
Manifest,
|
| 19 |
+
Record,
|
| 20 |
+
ResidueConstraints,
|
| 21 |
+
Structure,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.
    constraints_dir : Optional[Path]
        The path to the residue constraints directory, if any.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the raw structure arrays and wrap them in a Structure.
    arrays = np.load(target_dir / f"{record.id}.npz")
    structure = Structure(
        atoms=arrays["atoms"],
        bonds=arrays["bonds"],
        residues=arrays["residues"],
        chains=arrays["chains"],
        connections=arrays["connections"].astype(Connection),
        interfaces=arrays["interfaces"],
        mask=arrays["mask"],
    )

    # Load the MSA for each chain that has one (msa_id == -1 means none).
    msas = {
        chain.chain_id: MSA(**np.load(msa_dir / f"{chain.msa_id}.npz"))
        for chain in record.chains
        if chain.msa_id != -1
    }

    # Residue constraints are optional.
    if constraints_dir is None:
        residue_constraints = None
    else:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    return Input(structure, msas, record, residue_constraints)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Tensor entries are stacked along a new batch dimension, padding to
    the largest shape when samples disagree; entries on the passthrough
    list are kept as plain Python lists.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    Dict[str, Tensor]
        The collated data.

    """
    # Keys whose values are not plain tensors and must not be stacked.
    passthrough = [
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
    ]

    collated = {}
    for key in data[0].keys():
        values = [item[key] for item in data]

        if key not in passthrough:
            if all(v.shape == values[0].shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                # Ragged shapes: pad every tensor up to the max extent.
                values, _ = pad_to_max(values, 0)

        collated[key] = values

    return collated
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class PredictionDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes and featurizes records for inference.

    Any record that fails to load, tokenize, or featurize is replaced by
    the features of record 0 (see ``__getitem__``).
    """

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        constraints_dir : Optional[Path]
            The path to the residue constraints directory, if any.

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir
        self.tokenizer = BoltzTokenizer()
        self.featurizer = BoltzFeaturizer()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        On any failure the sample at index 0 is returned instead.
        NOTE(review): if record 0 itself fails, this recursion never
        terminates — confirm record 0 is always loadable.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features.

        """
        # Get a sample from the dataset
        record = self.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(
                record,
                self.target_dir,
                self.msa_dir,
                self.constraints_dir,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Inference specific options: only the first pocket constraint
        # (binder, pocket) pair is used, if any are present.
        options = record.inference_options
        if options is None or len(options.pocket_constraints) == 0:
            binder, pocket = None, None
        else:
            binder, pocket = (
                options.pocket_constraints[0][0],
                options.pocket_constraints[0][1],
            )

        # Compute features (inference settings: no training-time cropping
        # or padding, no symmetry computation).
        try:
            features = self.featurizer.process(
                tokenized,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                symmetries={},
                compute_symmetries=False,
                inference_binder=binder,
                inference_pocket=pocket,
                compute_constraint_features=True,
            )
        except Exception as e:  # noqa: BLE001
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Attach the record so downstream code can identify the sample.
        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return len(self.manifest.records)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class BoltzInferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest of records to run inference on.
        target_dir : Path
            The path to the target (structure) directory.
        msa_dir : Path
            The path to the msa directory.
        num_workers : int
            Number of dataloader worker processes.
        constraints_dir : Optional[Path]
            The path to the residue constraints directory, if any.

        """
        super().__init__()
        self.num_workers = num_workers
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.constraints_dir = constraints_dir

    def predict_dataloader(self) -> DataLoader:
        """Get the prediction dataloader.

        Returns
        -------
        DataLoader
            The prediction dataloader (batch size 1, no shuffling).

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
        )
        return DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Tensor entries are moved to ``device``; entries on the
        passthrough list (symmetries, records, ...) are left as-is.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        dict
            The transferred batch.

        """
        for key in batch:
            # Non-tensor entries have no .to() and must stay on CPU.
            if key not in [
                "all_coords",
                "all_resolved_mask",
                "crop_to_all_atom_map",
                "chain_symmetries",
                "amino_acids_symmetries",
                "ligand_symmetries",
                "record",
            ]:
                batch[key] = batch[key].to(device)
        return batch
|
protify/FastPLMs/boltz/src/boltz/data/module/inferencev2.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from boltz.data import const
|
| 12 |
+
from boltz.data.crop.affinity import AffinityCropper
|
| 13 |
+
from boltz.data.feature.featurizerv2 import Boltz2Featurizer
|
| 14 |
+
from boltz.data.mol import load_canonicals, load_molecules
|
| 15 |
+
from boltz.data.pad import pad_to_max
|
| 16 |
+
from boltz.data.tokenize.boltz2 import Boltz2Tokenizer
|
| 17 |
+
from boltz.data.types import (
|
| 18 |
+
MSA,
|
| 19 |
+
Input,
|
| 20 |
+
Manifest,
|
| 21 |
+
Record,
|
| 22 |
+
ResidueConstraints,
|
| 23 |
+
StructureV2,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_input(
    record: Record,
    target_dir: Path,
    msa_dir: Path,
    constraints_dir: Optional[Path] = None,
    template_dir: Optional[Path] = None,
    extra_mols_dir: Optional[Path] = None,
    affinity: bool = False,
) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.
    constraints_dir : Optional[Path]
        The path to the constraints directory.
    template_dir : Optional[Path]
        The path to the template directory.
    extra_mols_dir : Optional[Path]
        The path to the extra molecules directory.
    affinity : bool
        Whether to load the affinity data.

    Returns
    -------
    Input
        The loaded input.

    """
    # Affinity runs read a dedicated per-record structure file.
    if affinity:
        structure_path = target_dir / record.id / f"pre_affinity_{record.id}.npz"
    else:
        structure_path = target_dir / f"{record.id}.npz"
    structure = StructureV2.load(structure_path)

    # Load the MSA for each chain that has one (msa_id == -1 means none).
    msas = {
        chain.chain_id: MSA.load(msa_dir / f"{chain.msa_id}.npz")
        for chain in record.chains
        if chain.msa_id != -1
    }

    # Load templates, if the record declares any and a directory is given.
    templates = None
    if record.templates and template_dir is not None:
        templates = {
            info.name: StructureV2.load(
                template_dir / f"{record.id}_{info.name}.npz"
            )
            for info in record.templates
        }

    # Load residue constraints, if available.
    residue_constraints = None
    if constraints_dir is not None:
        residue_constraints = ResidueConstraints.load(
            constraints_dir / f"{record.id}.npz"
        )

    # Load extra molecules, if available. NOTE: pickle.load is only safe
    # here because these files are produced by the pipeline itself.
    extra_mols = {}
    if extra_mols_dir is not None:
        extra_mol_path = extra_mols_dir / f"{record.id}.pkl"
        if extra_mol_path.exists():
            with extra_mol_path.open("rb") as f:
                extra_mols = pickle.load(f)  # noqa: S301

    return Input(
        structure,
        msas,
        record=record,
        residue_constraints=residue_constraints,
        templates=templates,
        extra_mols=extra_mols,
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : List[Dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    Dict[str, Tensor]
        The collated data.

    """
    # These keys hold ragged / non-tensor values and are kept as plain lists.
    passthrough = (
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
        "record",
        "affinity_mw",
    )

    collated: dict[str, Tensor] = {}
    for key in data[0]:
        values = [sample[key] for sample in data]
        if key not in passthrough:
            # Stack directly when shapes agree, otherwise pad to the max
            # shape first.
            first_shape = values[0].shape
            if all(v.shape == first_shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                values, _ = pad_to_max(values, 0)
        collated[key] = values

    return collated
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class PredictionDataset(torch.utils.data.Dataset):
    """Map-style dataset that featurizes one manifest record per item for inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the prediction dataset.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the moldir.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            Method name forwarded to the featurizer; None uses its default.
        affinity : bool
            Whether to prepare affinity inputs (enables cropping).

        """
        super().__init__()
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.tokenizer = Boltz2Tokenizer()
        self.featurizer = Boltz2Featurizer()
        # Canonical molecules are loaded once and reused for every item.
        self.canonicals = load_canonicals(self.mol_dir)
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity
        # The cropper only exists in affinity mode; __getitem__ only touches
        # it behind the same self.affinity flag.
        if self.affinity:
            self.cropper = AffinityCropper()

    def __getitem__(self, idx: int) -> dict:
        """Get an item from the dataset.

        On any tokenization/cropping/molecule/featurization failure the item
        is skipped by falling back to item 0.

        NOTE(review): if record 0 itself fails, this fallback recurses
        indefinitely until RecursionError — confirm record 0 is always valid.

        Returns
        -------
        Dict[str, Tensor]
            The sampled data features.

        """
        # Get record
        record = self.manifest.records[idx]

        # Finalize input data
        input_data = load_input(
            record=record,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            affinity=self.affinity,
        )

        # Tokenize structure
        try:
            tokenized = self.tokenizer.tokenize(input_data)
        except Exception as e:  # noqa: BLE001
            print(  # noqa: T201
                f"Tokenizer failed on {record.id} with error {e}. Skipping."
            )
            return self.__getitem__(0)

        # Affinity mode crops to a fixed budget before featurization.
        if self.affinity:
            try:
                tokenized = self.cropper.crop(
                    tokenized,
                    max_tokens=256,
                    max_atoms=2048,
                )
            except Exception as e:  # noqa: BLE001
                print(f"Cropper failed on {record.id} with error {e}. Skipping.")  # noqa: T201
                return self.__getitem__(0)

        # Load conformers: canonicals and extra mols first, then fetch from
        # disk only the residue names not already covered.
        try:
            molecules = {}
            molecules.update(self.canonicals)
            molecules.update(input_data.extra_mols)
            mol_names = set(tokenized.tokens["res_name"].tolist())
            mol_names = mol_names - set(molecules.keys())
            molecules.update(load_molecules(self.mol_dir, mol_names))
        except Exception as e:  # noqa: BLE001
            print(f"Molecule loading failed for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Inference specific options
        options = record.inference_options
        if options is None:
            pocket_constraints, contact_constraints = None, None
        else:
            pocket_constraints, contact_constraints = (
                options.pocket_constraints,
                options.contact_constraints,
            )

        # Fixed seed so featurization is deterministic across runs.
        seed = 42
        random = np.random.default_rng(seed)

        # Compute features
        try:
            features = self.featurizer.process(
                tokenized,
                molecules=molecules,
                random=random,
                training=False,
                max_atoms=None,
                max_tokens=None,
                max_seqs=const.max_msa_seqs,
                pad_to_max_seqs=False,
                single_sequence_prop=0.0,
                compute_frames=True,
                inference_pocket_constraints=pocket_constraints,
                inference_contact_constraints=contact_constraints,
                compute_constraint_features=True,
                override_method=self.override_method,
                compute_affinity=self.affinity,
            )
        except Exception as e:  # noqa: BLE001
            import traceback

            traceback.print_exc()
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")  # noqa: T201
            return self.__getitem__(0)

        # Add record so downstream consumers can identify the sample.
        features["record"] = record
        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset.

        """
        return len(self.manifest.records)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class Boltz2InferenceDataModule(pl.LightningDataModule):
    """DataModule for Boltz2 inference."""

    def __init__(
        self,
        manifest: Manifest,
        target_dir: Path,
        msa_dir: Path,
        mol_dir: Path,
        num_workers: int,
        constraints_dir: Optional[Path] = None,
        template_dir: Optional[Path] = None,
        extra_mols_dir: Optional[Path] = None,
        override_method: Optional[str] = None,
        affinity: bool = False,
    ) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        manifest : Manifest
            The manifest to load data from.
        target_dir : Path
            The path to the target directory.
        msa_dir : Path
            The path to the msa directory.
        mol_dir : Path
            The path to the moldir.
        num_workers : int
            The number of workers to use.
        constraints_dir : Optional[Path]
            The path to the constraints directory.
        template_dir : Optional[Path]
            The path to the template directory.
        extra_mols_dir : Optional[Path]
            The path to the extra molecules directory.
        override_method : Optional[str]
            The method to override.
        affinity : bool
            Whether to prepare affinity inputs.

        """
        super().__init__()
        # Everything is stored as-is; the dataset is built lazily inside
        # predict_dataloader.
        self.manifest = manifest
        self.target_dir = target_dir
        self.msa_dir = msa_dir
        self.mol_dir = mol_dir
        self.num_workers = num_workers
        self.constraints_dir = constraints_dir
        self.template_dir = template_dir
        self.extra_mols_dir = extra_mols_dir
        self.override_method = override_method
        self.affinity = affinity

    def predict_dataloader(self) -> DataLoader:
        """Build the prediction dataloader.

        Returns
        -------
        DataLoader
            A batch-size-1, unshuffled loader over the manifest records.

        """
        dataset = PredictionDataset(
            manifest=self.manifest,
            target_dir=self.target_dir,
            msa_dir=self.msa_dir,
            mol_dir=self.mol_dir,
            constraints_dir=self.constraints_dir,
            template_dir=self.template_dir,
            extra_mols_dir=self.extra_mols_dir,
            override_method=self.override_method,
            affinity=self.affinity,
        )
        loader = DataLoader(
            dataset,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=False,
            collate_fn=collate,
        )
        return loader

    def transfer_batch_to_device(
        self,
        batch: dict,
        device: torch.device,
        dataloader_idx: int,  # noqa: ARG002
    ) -> dict:
        """Transfer a batch to the given device.

        Parameters
        ----------
        batch : Dict
            The batch to transfer.
        device : torch.device
            The device to transfer to.
        dataloader_idx : int
            The dataloader index (unused).

        Returns
        -------
        dict
            The batch with all movable tensors on the target device.

        """
        # Ragged / non-tensor entries stay on the host untouched.
        host_only = (
            "all_coords",
            "all_resolved_mask",
            "crop_to_all_atom_map",
            "chain_symmetries",
            "amino_acids_symmetries",
            "ligand_symmetries",
            "record",
            "affinity_mw",
        )
        for key, value in batch.items():
            if key not in host_only:
                batch[key] = value.to(device)
        return batch
|
protify/FastPLMs/boltz/src/boltz/data/module/training.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from boltz.data.crop.cropper import Cropper
|
| 12 |
+
from boltz.data.feature.featurizer import BoltzFeaturizer
|
| 13 |
+
from boltz.data.feature.symmetry import get_symmetries
|
| 14 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 15 |
+
from boltz.data.pad import pad_to_max
|
| 16 |
+
from boltz.data.sample.sampler import Sample, Sampler
|
| 17 |
+
from boltz.data.tokenize.tokenizer import Tokenizer
|
| 18 |
+
from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class DatasetConfig:
    """Dataset configuration."""

    # Directory with the processed target structures (contains "structures/").
    target_dir: str
    # Directory with the precomputed per-chain MSA .npz files.
    msa_dir: str
    # Sampling probability of this dataset within the training mixture.
    prob: float
    # Sampler yielding records (and chain/interface ids) for training.
    sampler: Sampler
    # Cropper applied to tokenized structures.
    cropper: Cropper
    # Optional per-dataset record filters — presumably DynamicFilter
    # instances; confirm against the consumer.
    filters: Optional[list] = None
    # Optional path to a file listing record ids for the split.
    split: Optional[str] = None
    # Optional explicit manifest path; defaults to target_dir/manifest.json.
    manifest_path: Optional[str] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class DataConfig:
    """Data configuration."""

    # Per-dataset configurations making up the training mixture.
    datasets: list[DatasetConfig]
    # Global record filters applied across datasets.
    filters: list[DynamicFilter]
    # Featurizer instance shared by all datasets.
    featurizer: BoltzFeaturizer
    # Tokenizer instance shared by all datasets.
    tokenizer: Tokenizer
    # Crop/pad budgets.
    max_atoms: int
    max_tokens: int
    max_seqs: int
    # Virtual epoch length for the training dataset.
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    random_seed: int
    pin_memory: bool
    # Path passed to get_symmetries to load symmetry data.
    symmetries: str
    # Featurizer geometry parameters.
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    # If set, restrict each dataset to its first `overfit` records.
    overfit: Optional[int] = None
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    # Whether validation items are cropped like training items.
    crop_validation: bool = False
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    # Pocket-conditioning probabilities for train/val featurization.
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    # Must be 1 — validation assumes unbatched items (asserted downstream).
    val_batch_size: int = 1
    compute_constraint_features: bool = False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class Dataset:
    """Data holder."""

    # Resolved filesystem locations for this dataset.
    target_dir: Path
    msa_dir: Path
    # Loaded manifest whose records drive sampling / validation indexing.
    manifest: Manifest
    # Sampling probability within the training mixture.
    prob: float
    sampler: Sampler
    cropper: Cropper
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Read the raw per-record structure archive.
    raw = np.load(target_dir / "structures" / f"{record.id}.npz")

    # Older archives predate the "cyclic_period" chain field; backfill it
    # with zeros so downstream code can rely on its presence.
    chains = raw["chains"]
    if "cyclic_period" not in chains.dtype.names:
        widened_dtype = chains.dtype.descr + [("cyclic_period", "i4")]
        widened = np.empty(chains.shape, dtype=widened_dtype)
        for field in chains.dtype.names:
            widened[field] = chains[field]
        widened["cyclic_period"] = 0
        chains = widened

    structure = Structure(
        atoms=raw["atoms"],
        bonds=raw["bonds"],
        residues=raw["residues"],
        chains=chains,  # possibly widened with cyclic_period
        connections=raw["connections"].astype(Connection),
        interfaces=raw["interfaces"],
        mask=raw["mask"],
    )

    # Collect MSAs for chains that have one (-1 or "" marks "no MSA").
    msas = {
        chain.chain_id: MSA(**np.load(msa_dir / f"{chain.msa_id}.npz"))
        for chain in record.chains
        if chain.msa_id not in (-1, "")
    }

    return Input(structure, msas)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # These keys hold ragged values and are kept as plain python lists.
    passthrough = (
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
    )

    collated: dict[str, Tensor] = {}
    for key in data[0]:
        values = [sample[key] for sample in data]
        if key not in passthrough:
            # Stack when every sample agrees on shape; pad otherwise.
            first_shape = values[0].shape
            if all(v.shape == first_shape for v in values):
                values = torch.stack(values, dim=0)
            else:
                values, _ = pad_to_max(values, 0)
        collated[key] = values

    return collated
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class TrainingDataset(torch.utils.data.Dataset):
    """Map-style dataset that mixes several datasets by sampling probability."""

    def __init__(
        self,
        datasets: list[Dataset],
        samples_per_epoch: int,
        symmetries: dict,
        max_atoms: int,
        max_tokens: int,
        max_seqs: int,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        return_symmetries: Optional[bool] = False,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the training dataset.

        Parameters mirror DataConfig; `datasets` supplies per-dataset
        sampler/cropper/tokenizer/featurizer, the rest are featurization
        and padding budgets forwarded to the featurizer.
        """
        super().__init__()
        self.datasets = datasets
        # Mixture weights, one per dataset, in the same order as `datasets`.
        # NOTE(review): assumed to sum to 1 — np.random.choice requires it.
        self.probs = [d.prob for d in datasets]
        self.samples_per_epoch = samples_per_epoch
        self.symmetries = symmetries
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.max_atoms = max_atoms
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
        self.return_symmetries = return_symmetries
        self.compute_constraint_features = compute_constraint_features
        # One live sampler iterator per dataset; __getitem__ draws from the
        # iterator of whichever dataset it picks, ignoring `idx`.
        # NOTE(review): with num_workers > 0 each worker holds its own copy
        # of these iterators — confirm the sampler is robust to that.
        self.samples = []
        for dataset in datasets:
            records = dataset.manifest.records
            if overfit is not None:
                records = records[:overfit]
            iterator = dataset.sampler.sample(records, np.random)
            self.samples.append(iterator)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        The index is only used to retry on failure; the actual sample comes
        from a randomly chosen dataset's sampler iterator.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick a random dataset according to the mixture weights.
        dataset_idx = np.random.choice(
            len(self.datasets),
            p=self.probs,
        )
        dataset = self.datasets[dataset_idx]

        # Get a sample from the dataset
        sample: Sample = next(self.samples[dataset_idx])

        # Get the structure; failures retry (a new random draw each time).
        try:
            input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(
                f"Failed to load input for {sample.record.id} with error {e}. Skipping."
            )
            return self.__getitem__(idx)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Compute crop (only when a token budget is set).
        try:
            if self.max_tokens is not None:
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_atoms=self.max_atoms,
                    max_tokens=self.max_tokens,
                    random=np.random,
                    chain_id=sample.chain_id,
                    interface_id=sample.interface_id,
                )
        except Exception as e:
            print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Check if there are tokens
        # NOTE(review): unlike the other failure modes this raises instead of
        # retrying — an empty crop aborts the dataloader worker. Confirm
        # intentional.
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            features = dataset.featurizer.process(
                tokenized,
                training=True,
                max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
                max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset (a virtual epoch size, not the number
            of distinct records).

        """
        return self.samples_per_epoch
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class ValidationDataset(torch.utils.data.Dataset):
    """Map-style dataset that walks all validation records across datasets."""

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        compute_constraint_features: bool = False,
    ) -> None:
        """Initialize the validation dataset.

        Parameters mirror DataConfig; cropping/padding only applies when
        `crop_validation` is set.
        """
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        # In overfit mode reuse the global RNG; otherwise a fixed-seed RNG
        # keeps validation crops reproducible across runs.
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.compute_constraint_features = compute_constraint_features

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        The global index is mapped onto the concatenation of all dataset
        record lists; failures fall back to item 0.

        NOTE(review): if item 0 itself fails this recurses indefinitely —
        confirm the first record is always loadable.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Walk datasets, subtracting each one's (possibly overfit-capped)
        # size until idx lands inside one of them.
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop (only when validation cropping is enabled).
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens
        # NOTE(review): raises instead of falling back to item 0 — confirm
        # intentional.
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            # Padding only makes sense when items were cropped to a budget.
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
                compute_constraint_features=self.compute_constraint_features,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset (sum of all dataset record counts,
            each capped by `overfit` when set).

        """
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class BoltzTrainingDataModule(pl.LightningDataModule):
    """DataModule for boltz.

    Builds one :class:`TrainingDataset` and one :class:`ValidationDataset`
    from the per-dataset configs in ``cfg.datasets`` and exposes them through
    the standard Lightning dataloader hooks.
    """

    def __init__(self, cfg: DataConfig) -> None:
        """Initialize the DataModule.

        Parameters
        ----------
        config : DataConfig
            The data configuration.

        """
        super().__init__()
        self.cfg = cfg

        # Downstream validation logic relies on unbatched examples.
        assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."

        # Load symmetries (shared by train and val featurization).
        symmetries = get_symmetries(cfg.symmetries)

        # Load datasets
        train: list[Dataset] = []
        val: list[Dataset] = []

        for data_config in cfg.datasets:
            # Set target_dir
            target_dir = Path(data_config.target_dir)
            msa_dir = Path(data_config.msa_dir)

            # Load manifest: an explicit path wins, otherwise fall back to
            # the conventional location inside the target directory.
            if data_config.manifest_path is not None:
                path = Path(data_config.manifest_path)
            else:
                path = target_dir / "manifest.json"
            manifest: Manifest = Manifest.load(path)

            # Split records if given. The split file lists one record id per
            # line; matching is case-insensitive, and matched records go to
            # validation while all others go to training.
            if data_config.split is not None:
                with Path(data_config.split).open("r") as f:
                    split = {x.lower() for x in f.read().splitlines()}

                train_records = []
                val_records = []
                for record in manifest.records:
                    if record.id.lower() in split:
                        val_records.append(record)
                    else:
                        train_records.append(record)
            else:
                train_records = manifest.records
                val_records = []

            # Filter training records with the global (config-wide) filters.
            train_records = [
                record
                for record in train_records
                if all(f.filter(record) for f in cfg.filters)
            ]
            # Filter training records again with dataset-specific filters,
            # when this dataset defines any.
            if data_config.filters is not None:
                train_records = [
                    record
                    for record in train_records
                    if all(f.filter(record) for f in data_config.filters)
                ]

            # Create train dataset
            train_manifest = Manifest(train_records)
            train.append(
                Dataset(
                    target_dir,
                    msa_dir,
                    train_manifest,
                    data_config.prob,
                    data_config.sampler,
                    data_config.cropper,
                    cfg.tokenizer,
                    cfg.featurizer,
                )
            )

            # Create validation dataset (only when a split produced records).
            if val_records:
                val_manifest = Manifest(val_records)
                val.append(
                    Dataset(
                        target_dir,
                        msa_dir,
                        val_manifest,
                        data_config.prob,
                        data_config.sampler,
                        data_config.cropper,
                        cfg.tokenizer,
                        cfg.featurizer,
                    )
                )

        # Print dataset sizes
        for dataset in train:
            dataset: Dataset
            print(f"Training dataset size: {len(dataset.manifest.records)}")

        for dataset in val:
            dataset: Dataset
            print(f"Validation dataset size: {len(dataset.manifest.records)}")

        # Create wrapper datasets
        self._train_set = TrainingDataset(
            datasets=train,
            samples_per_epoch=cfg.samples_per_epoch,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
            return_symmetries=cfg.return_train_symmetries,
            compute_constraint_features=cfg.compute_constraint_features,
        )
        # When overfitting, validate on the training datasets themselves.
        self._val_set = ValidationDataset(
            datasets=train if cfg.overfit is not None else val,
            seed=cfg.random_seed,
            max_atoms=cfg.max_atoms,
            max_tokens=cfg.max_tokens,
            max_seqs=cfg.max_seqs,
            pad_to_max_atoms=cfg.pad_to_max_atoms,
            pad_to_max_tokens=cfg.pad_to_max_tokens,
            pad_to_max_seqs=cfg.pad_to_max_seqs,
            symmetries=symmetries,
            atoms_per_window_queries=cfg.atoms_per_window_queries,
            min_dist=cfg.min_dist,
            max_dist=cfg.max_dist,
            num_bins=cfg.num_bins,
            overfit=cfg.overfit,
            crop_validation=cfg.crop_validation,
            return_symmetries=cfg.return_val_symmetries,
            binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
            binder_pocket_cutoff=cfg.binder_pocket_cutoff,
            compute_constraint_features=cfg.compute_constraint_features,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        All construction happens in ``__init__``, so this is a no-op.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'.

        """
        return

    def train_dataloader(self) -> DataLoader:
        """Get the training dataloader.

        Returns
        -------
        DataLoader
            The training dataloader.

        """
        return DataLoader(
            self._train_set,
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )

    def val_dataloader(self) -> DataLoader:
        """Get the validation dataloader.

        Returns
        -------
        DataLoader
            The validation dataloader.

        """
        return DataLoader(
            self._val_set,
            batch_size=self.cfg.val_batch_size,
            num_workers=self.cfg.num_workers,
            pin_memory=self.cfg.pin_memory,
            shuffle=False,
            collate_fn=collate,
        )
|
protify/FastPLMs/boltz/src/boltz/data/module/trainingv2.py
ADDED
|
@@ -0,0 +1,660 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
|
| 11 |
+
from boltz.data.crop.cropper import Cropper
|
| 12 |
+
from boltz.data.feature.featurizer import BoltzFeaturizer
|
| 13 |
+
from boltz.data.feature.symmetry import get_symmetries
|
| 14 |
+
from boltz.data.filter.dynamic.filter import DynamicFilter
|
| 15 |
+
from boltz.data.pad import pad_to_max
|
| 16 |
+
from boltz.data.sample.sampler import Sample, Sampler
|
| 17 |
+
from boltz.data.tokenize.tokenizer import Tokenizer
|
| 18 |
+
from boltz.data.types import MSA, Connection, Input, Manifest, Record, Structure
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class DatasetConfig:
    """Dataset configuration.

    One entry per data source in ``DataConfig.datasets``.
    """

    # Directory holding processed targets; structures are read from
    # ``<target_dir>/structures/<record_id>.npz`` and, absent an explicit
    # manifest_path, the manifest from ``<target_dir>/manifest.json``.
    target_dir: str
    # Directory holding per-chain MSA files, read as ``<msa_dir>/<msa_id>.npz``.
    msa_dir: str
    # Sampling probability of this dataset relative to the other datasets.
    prob: float
    # Sampler used to draw training records from this dataset's manifest.
    sampler: Sampler
    # Cropper applied to tokenized structures from this dataset.
    cropper: Cropper
    # Optional dataset-specific dynamic filters, applied to training records
    # in addition to the global ``DataConfig.filters``.
    filters: Optional[list] = None
    # Optional path to a file listing record ids (one per line) to route to
    # validation; all other records go to training.
    split: Optional[str] = None
    # Optional explicit manifest path, overriding ``<target_dir>/manifest.json``.
    manifest_path: Optional[str] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class DataConfig:
    """Data configuration.

    Top-level configuration consumed by :class:`BoltzTrainingDataModule`.
    """

    # Per-source dataset configurations.
    datasets: list[DatasetConfig]
    # Global dynamic filters applied to every dataset's training records.
    filters: list[DynamicFilter]
    # Featurizer shared by all datasets.
    featurizer: BoltzFeaturizer
    # Tokenizer shared by all datasets.
    tokenizer: Tokenizer
    # Crop/pad limits forwarded to the training and validation datasets.
    max_atoms: int
    max_tokens: int
    max_seqs: int
    # Number of samples that make up one training epoch.
    samples_per_epoch: int
    batch_size: int
    num_workers: int
    # Seeds the validation dataset's RNG (ignored when overfitting).
    random_seed: int
    pin_memory: bool
    # Passed to get_symmetries; presumably a path to symmetry data — confirm.
    symmetries: str
    # Featurizer windowing / distogram parameters.
    atoms_per_window_queries: int
    min_dist: float
    max_dist: float
    num_bins: int
    # When set, each dataset is truncated to this many records and
    # validation runs on the training datasets.
    overfit: Optional[int] = None
    pad_to_max_tokens: bool = False
    pad_to_max_atoms: bool = False
    pad_to_max_seqs: bool = False
    # Whether validation examples are cropped (and then eligible for padding).
    crop_validation: bool = False
    return_train_symmetries: bool = False
    return_val_symmetries: bool = True
    # Binder-pocket conditioning probabilities for train/val featurization.
    train_binder_pocket_conditioned_prop: float = 0.0
    val_binder_pocket_conditioned_prop: float = 0.0
    binder_pocket_cutoff: float = 6.0
    binder_pocket_sampling_geometric_p: float = 0.0
    # Must be 1: the data module asserts this at construction time.
    val_batch_size: int = 1
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@dataclass
class Dataset:
    """Data holder.

    Bundles one manifest together with the directories and components
    (sampler, cropper, tokenizer, featurizer) needed to turn its records
    into model features.
    """

    # Directory containing ``structures/<record_id>.npz`` files.
    target_dir: Path
    # Directory containing ``<msa_id>.npz`` files.
    msa_dir: Path
    # Records belonging to this dataset (already split/filtered).
    manifest: Manifest
    # Relative sampling probability among datasets.
    prob: float
    sampler: Sampler
    cropper: Cropper
    tokenizer: Tokenizer
    featurizer: BoltzFeaturizer
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
    """Load the given input data.

    Parameters
    ----------
    record : Record
        The record to load.
    target_dir : Path
        The path to the data directory.
    msa_dir : Path
        The path to msa directory.

    Returns
    -------
    Input
        The loaded input.

    """
    # Load the structure. np.load on an .npz returns a lazy NpzFile that
    # keeps the file handle open; use it as a context manager so the handle
    # is closed promptly (this runs per-sample in DataLoader workers, where
    # leaked descriptors accumulate).
    with np.load(target_dir / "structures" / f"{record.id}.npz") as data:
        structure = Structure(
            atoms=data["atoms"],
            bonds=data["bonds"],
            residues=data["residues"],
            chains=data["chains"],
            connections=data["connections"].astype(Connection),
            interfaces=data["interfaces"],
            mask=data["mask"],
        )

    msas = {}
    for chain in record.chains:
        msa_id = chain.msa_id
        # Load the MSA for this chain, if any (-1 / "" both mean "no MSA").
        if msa_id not in (-1, ""):
            with np.load(msa_dir / f"{msa_id}.npz") as msa:
                msas[chain.chain_id] = MSA(**msa)

    return Input(structure, msas)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
    """Collate the data.

    Tensor-valued entries are stacked along a new batch dimension (padding
    with :func:`pad_to_max` when shapes differ); the symmetry / all-atom
    bookkeeping keys listed below are kept as plain lists.

    Parameters
    ----------
    data : list[dict[str, Tensor]]
        The data to collate.

    Returns
    -------
    dict[str, Tensor]
        The collated data.

    """
    # Keys whose values are passed through as lists rather than stacked.
    unbatched_keys = {
        "all_coords",
        "all_resolved_mask",
        "crop_to_all_atom_map",
        "chain_symmetries",
        "amino_acids_symmetries",
        "ligand_symmetries",
    }

    collated: dict[str, Tensor] = {}
    for key in data[0]:
        values = [item[key] for item in data]

        if key in unbatched_keys:
            collated[key] = values
            continue

        # Stack directly when shapes agree; otherwise pad to a common shape.
        first_shape = values[0].shape
        if any(v.shape != first_shape for v in values):
            values, _ = pad_to_max(values, 0)
        else:
            values = torch.stack(values, dim=0)
        collated[key] = values

    return collated
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class TrainingDataset(torch.utils.data.Dataset):
    """Base iterable dataset.

    Draws one example per ``__getitem__`` by sampling a dataset (weighted by
    each dataset's ``prob``), pulling a record from that dataset's sampler,
    then loading, tokenizing, cropping, and featurizing it. Any step that
    raises is logged and retried with a fresh draw.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        samples_per_epoch: int,
        symmetries: dict,
        max_atoms: int,
        max_tokens: int,
        max_seqs: int,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
        binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
        return_symmetries: Optional[bool] = False,
    ) -> None:
        """Initialize the training dataset."""
        super().__init__()
        self.datasets = datasets
        # Per-dataset selection probabilities used by np.random.choice below;
        # assumed to sum to 1 — confirm against the configs.
        self.probs = [d.prob for d in datasets]
        self.samples_per_epoch = samples_per_epoch
        self.symmetries = symmetries
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.max_atoms = max_atoms
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff
        self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
        self.return_symmetries = return_symmetries
        # One (possibly infinite) sample iterator per dataset; when
        # overfitting, each dataset is truncated to its first `overfit`
        # records before sampling.
        self.samples = []
        for dataset in datasets:
            records = dataset.manifest.records
            if overfit is not None:
                records = records[:overfit]
            iterator = dataset.sampler.sample(records, np.random)
            self.samples.append(iterator)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Note that `idx` does not select the example: the example is drawn
        from the dataset samplers, so repeated calls with the same index
        yield different data.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick a random dataset, weighted by the configured probabilities.
        dataset_idx = np.random.choice(
            len(self.datasets),
            p=self.probs,
        )
        dataset = self.datasets[dataset_idx]

        # Get a sample from the dataset
        sample: Sample = next(self.samples[dataset_idx])

        # Get the structure. On failure, log and recurse to draw a fresh
        # sample. NOTE(review): a persistently failing pipeline would
        # recurse without bound — confirm that failures are transient.
        try:
            input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(
                f"Failed to load input for {sample.record.id} with error {e}. Skipping."
            )
            return self.__getitem__(idx)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Compute crop (only when a token budget is configured).
        try:
            if self.max_tokens is not None:
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_atoms=self.max_atoms,
                    max_tokens=self.max_tokens,
                    random=np.random,
                    chain_id=sample.chain_id,
                    interface_id=sample.interface_id,
                )
        except Exception as e:
            print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        # Check if there are tokens. NOTE(review): unlike the failures
        # above, this raises instead of skipping, so an empty crop aborts
        # the dataloader worker — confirm that is intended.
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features
        try:
            features = dataset.featurizer.process(
                tokenized,
                training=True,
                max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
                max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
            )
        except Exception as e:
            print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
            return self.__getitem__(idx)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset. This is the configured
            ``samples_per_epoch``, independent of the number of records.

        """
        return self.samples_per_epoch
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
class ValidationDataset(torch.utils.data.Dataset):
    """Base iterable dataset.

    Deterministically maps each index to one manifest record (datasets are
    concatenated in order). Loading, tokenization, optional cropping, and
    featurization mirror :class:`TrainingDataset`, but with
    ``training=False`` and validation-specific pocket settings.
    """

    def __init__(
        self,
        datasets: list[Dataset],
        seed: int,
        symmetries: dict,
        max_atoms: Optional[int] = None,
        max_tokens: Optional[int] = None,
        max_seqs: Optional[int] = None,
        pad_to_max_atoms: bool = False,
        pad_to_max_tokens: bool = False,
        pad_to_max_seqs: bool = False,
        atoms_per_window_queries: int = 32,
        min_dist: float = 2.0,
        max_dist: float = 22.0,
        num_bins: int = 64,
        overfit: Optional[int] = None,
        crop_validation: bool = False,
        return_symmetries: Optional[bool] = False,
        binder_pocket_conditioned_prop: Optional[float] = 0.0,
        binder_pocket_cutoff: Optional[float] = 6.0,
    ) -> None:
        """Initialize the validation dataset."""
        super().__init__()
        self.datasets = datasets
        self.max_atoms = max_atoms
        self.max_tokens = max_tokens
        self.max_seqs = max_seqs
        self.seed = seed
        self.symmetries = symmetries
        # Seeded RNG for reproducible validation crops; when overfitting,
        # the unseeded global RNG is used instead (matches training).
        self.random = np.random if overfit else np.random.RandomState(self.seed)
        self.pad_to_max_tokens = pad_to_max_tokens
        self.pad_to_max_atoms = pad_to_max_atoms
        self.pad_to_max_seqs = pad_to_max_seqs
        self.overfit = overfit
        self.crop_validation = crop_validation
        self.atoms_per_window_queries = atoms_per_window_queries
        self.min_dist = min_dist
        self.max_dist = max_dist
        self.num_bins = num_bins
        self.return_symmetries = return_symmetries
        self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
        self.binder_pocket_cutoff = binder_pocket_cutoff

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get an item from the dataset.

        Parameters
        ----------
        idx : int
            The data index.

        Returns
        -------
        dict[str, Tensor]
            The sampled data features.

        """
        # Pick dataset based on idx: walk datasets in order, subtracting
        # each one's (possibly overfit-capped) size until idx falls inside.
        for dataset in self.datasets:
            size = len(dataset.manifest.records)
            if self.overfit is not None:
                size = min(size, self.overfit)
            if idx < size:
                break
            idx -= size

        # Get a sample from the dataset
        record = dataset.manifest.records[idx]

        # Get the structure. On failure, log and fall back to example 0.
        # NOTE(review): if record 0 itself fails, this recurses forever.
        try:
            input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
        except Exception as e:
            print(f"Failed to load input for {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Tokenize structure
        try:
            tokenized = dataset.tokenizer.tokenize(input_data)
        except Exception as e:
            print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Compute crop (validation crops only when explicitly enabled and a
        # token budget exists).
        try:
            if self.crop_validation and (self.max_tokens is not None):
                tokenized = dataset.cropper.crop(
                    tokenized,
                    max_tokens=self.max_tokens,
                    random=self.random,
                    max_atoms=self.max_atoms,
                )
        except Exception as e:
            print(f"Cropper failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        # Check if there are tokens
        if len(tokenized.tokens) == 0:
            msg = "No tokens in cropped structure."
            raise ValueError(msg)

        # Compute features; padding only makes sense once examples share a
        # crop budget, hence the crop_validation gate.
        try:
            pad_atoms = self.crop_validation and self.pad_to_max_atoms
            pad_tokens = self.crop_validation and self.pad_to_max_tokens

            features = dataset.featurizer.process(
                tokenized,
                training=False,
                max_atoms=self.max_atoms if pad_atoms else None,
                max_tokens=self.max_tokens if pad_tokens else None,
                max_seqs=self.max_seqs,
                pad_to_max_seqs=self.pad_to_max_seqs,
                symmetries=self.symmetries,
                atoms_per_window_queries=self.atoms_per_window_queries,
                min_dist=self.min_dist,
                max_dist=self.max_dist,
                num_bins=self.num_bins,
                compute_symmetries=self.return_symmetries,
                binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
                binder_pocket_cutoff=self.binder_pocket_cutoff,
                binder_pocket_sampling_geometric_p=1.0,  # this will only sample a single pocket token
                only_ligand_binder_pocket=True,
            )
        except Exception as e:
            print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
            return self.__getitem__(0)

        return features

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns
        -------
        int
            The length of the dataset: total records across all datasets,
            capped per-dataset when overfitting.

        """
        if self.overfit is not None:
            length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
        else:
            length = sum(len(d.manifest.records) for d in self.datasets)

        return length
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
class BoltzTrainingDataModule(pl.LightningDataModule):
|
| 468 |
+
"""DataModule for boltz."""
|
| 469 |
+
|
| 470 |
+
def __init__(self, cfg: DataConfig) -> None:
|
| 471 |
+
"""Initialize the DataModule.
|
| 472 |
+
|
| 473 |
+
Parameters
|
| 474 |
+
----------
|
| 475 |
+
config : DataConfig
|
| 476 |
+
The data configuration.
|
| 477 |
+
|
| 478 |
+
"""
|
| 479 |
+
super().__init__()
|
| 480 |
+
self.cfg = cfg
|
| 481 |
+
|
| 482 |
+
assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."
|
| 483 |
+
|
| 484 |
+
# Load symmetries
|
| 485 |
+
symmetries = get_symmetries(cfg.symmetries)
|
| 486 |
+
|
| 487 |
+
# Load datasets
|
| 488 |
+
train: list[Dataset] = []
|
| 489 |
+
val: list[Dataset] = []
|
| 490 |
+
|
| 491 |
+
for data_config in cfg.datasets:
|
| 492 |
+
# Set target_dir
|
| 493 |
+
target_dir = Path(data_config.target_dir)
|
| 494 |
+
msa_dir = Path(data_config.msa_dir)
|
| 495 |
+
|
| 496 |
+
# Load manifest
|
| 497 |
+
if data_config.manifest_path is not None:
|
| 498 |
+
path = Path(data_config.manifest_path)
|
| 499 |
+
else:
|
| 500 |
+
path = target_dir / "manifest.json"
|
| 501 |
+
manifest: Manifest = Manifest.load(path)
|
| 502 |
+
|
| 503 |
+
# Split records if given
|
| 504 |
+
if data_config.split is not None:
|
| 505 |
+
with Path(data_config.split).open("r") as f:
|
| 506 |
+
split = {x.lower() for x in f.read().splitlines()}
|
| 507 |
+
|
| 508 |
+
train_records = []
|
| 509 |
+
val_records = []
|
| 510 |
+
for record in manifest.records:
|
| 511 |
+
if record.id.lower() in split:
|
| 512 |
+
val_records.append(record)
|
| 513 |
+
else:
|
| 514 |
+
train_records.append(record)
|
| 515 |
+
else:
|
| 516 |
+
train_records = manifest.records
|
| 517 |
+
val_records = []
|
| 518 |
+
|
| 519 |
+
# Filter training records
|
| 520 |
+
train_records = [
|
| 521 |
+
record
|
| 522 |
+
for record in train_records
|
| 523 |
+
if all(f.filter(record) for f in cfg.filters)
|
| 524 |
+
]
|
| 525 |
+
# Filter training records
|
| 526 |
+
if data_config.filters is not None:
|
| 527 |
+
train_records = [
|
| 528 |
+
record
|
| 529 |
+
for record in train_records
|
| 530 |
+
if all(f.filter(record) for f in data_config.filters)
|
| 531 |
+
]
|
| 532 |
+
|
| 533 |
+
# Create train dataset
|
| 534 |
+
train_manifest = Manifest(train_records)
|
| 535 |
+
train.append(
|
| 536 |
+
Dataset(
|
| 537 |
+
target_dir,
|
| 538 |
+
msa_dir,
|
| 539 |
+
train_manifest,
|
| 540 |
+
data_config.prob,
|
| 541 |
+
data_config.sampler,
|
| 542 |
+
data_config.cropper,
|
| 543 |
+
cfg.tokenizer,
|
| 544 |
+
cfg.featurizer,
|
| 545 |
+
)
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# Create validation dataset
|
| 549 |
+
if val_records:
|
| 550 |
+
val_manifest = Manifest(val_records)
|
| 551 |
+
val.append(
|
| 552 |
+
Dataset(
|
| 553 |
+
target_dir,
|
| 554 |
+
msa_dir,
|
| 555 |
+
val_manifest,
|
| 556 |
+
data_config.prob,
|
| 557 |
+
data_config.sampler,
|
| 558 |
+
data_config.cropper,
|
| 559 |
+
cfg.tokenizer,
|
| 560 |
+
cfg.featurizer,
|
| 561 |
+
)
|
| 562 |
+
)
|
| 563 |
+
|
| 564 |
+
# Print dataset sizes
|
| 565 |
+
for dataset in train:
|
| 566 |
+
dataset: Dataset
|
| 567 |
+
print(f"Training dataset size: {len(dataset.manifest.records)}")
|
| 568 |
+
|
| 569 |
+
for dataset in val:
|
| 570 |
+
dataset: Dataset
|
| 571 |
+
print(f"Validation dataset size: {len(dataset.manifest.records)}")
|
| 572 |
+
|
| 573 |
+
# Create wrapper datasets
|
| 574 |
+
self._train_set = TrainingDataset(
|
| 575 |
+
datasets=train,
|
| 576 |
+
samples_per_epoch=cfg.samples_per_epoch,
|
| 577 |
+
max_atoms=cfg.max_atoms,
|
| 578 |
+
max_tokens=cfg.max_tokens,
|
| 579 |
+
max_seqs=cfg.max_seqs,
|
| 580 |
+
pad_to_max_atoms=cfg.pad_to_max_atoms,
|
| 581 |
+
pad_to_max_tokens=cfg.pad_to_max_tokens,
|
| 582 |
+
pad_to_max_seqs=cfg.pad_to_max_seqs,
|
| 583 |
+
symmetries=symmetries,
|
| 584 |
+
atoms_per_window_queries=cfg.atoms_per_window_queries,
|
| 585 |
+
min_dist=cfg.min_dist,
|
| 586 |
+
max_dist=cfg.max_dist,
|
| 587 |
+
num_bins=cfg.num_bins,
|
| 588 |
+
overfit=cfg.overfit,
|
| 589 |
+
binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
|
| 590 |
+
binder_pocket_cutoff=cfg.binder_pocket_cutoff,
|
| 591 |
+
binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
|
| 592 |
+
return_symmetries=cfg.return_train_symmetries,
|
| 593 |
+
)
|
| 594 |
+
self._val_set = ValidationDataset(
|
| 595 |
+
datasets=train if cfg.overfit is not None else val,
|
| 596 |
+
seed=cfg.random_seed,
|
| 597 |
+
max_atoms=cfg.max_atoms,
|
| 598 |
+
max_tokens=cfg.max_tokens,
|
| 599 |
+
max_seqs=cfg.max_seqs,
|
| 600 |
+
pad_to_max_atoms=cfg.pad_to_max_atoms,
|
| 601 |
+
pad_to_max_tokens=cfg.pad_to_max_tokens,
|
| 602 |
+
pad_to_max_seqs=cfg.pad_to_max_seqs,
|
| 603 |
+
symmetries=symmetries,
|
| 604 |
+
atoms_per_window_queries=cfg.atoms_per_window_queries,
|
| 605 |
+
min_dist=cfg.min_dist,
|
| 606 |
+
max_dist=cfg.max_dist,
|
| 607 |
+
num_bins=cfg.num_bins,
|
| 608 |
+
overfit=cfg.overfit,
|
| 609 |
+
crop_validation=cfg.crop_validation,
|
| 610 |
+
return_symmetries=cfg.return_val_symmetries,
|
| 611 |
+
binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
|
| 612 |
+
binder_pocket_cutoff=cfg.binder_pocket_cutoff,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
    def setup(self, stage: Optional[str] = None) -> None:
        """Run the setup for the DataModule.

        Intentionally a no-op: the train/validation datasets appear to be
        fully constructed earlier (presumably in ``__init__``) — TODO confirm
        against the class constructor, which is outside this view.

        Parameters
        ----------
        stage : str, optional
            The stage, one of 'fit', 'validate', 'test'. Ignored.

        """
        return
|
| 625 |
+
|
| 626 |
+
def train_dataloader(self) -> DataLoader:
|
| 627 |
+
"""Get the training dataloader.
|
| 628 |
+
|
| 629 |
+
Returns
|
| 630 |
+
-------
|
| 631 |
+
DataLoader
|
| 632 |
+
The training dataloader.
|
| 633 |
+
|
| 634 |
+
"""
|
| 635 |
+
return DataLoader(
|
| 636 |
+
self._train_set,
|
| 637 |
+
batch_size=self.cfg.batch_size,
|
| 638 |
+
num_workers=self.cfg.num_workers,
|
| 639 |
+
pin_memory=self.cfg.pin_memory,
|
| 640 |
+
shuffle=False,
|
| 641 |
+
collate_fn=collate,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
def val_dataloader(self) -> DataLoader:
|
| 645 |
+
"""Get the validation dataloader.
|
| 646 |
+
|
| 647 |
+
Returns
|
| 648 |
+
-------
|
| 649 |
+
DataLoader
|
| 650 |
+
The validation dataloader.
|
| 651 |
+
|
| 652 |
+
"""
|
| 653 |
+
return DataLoader(
|
| 654 |
+
self._val_set,
|
| 655 |
+
batch_size=self.cfg.val_batch_size,
|
| 656 |
+
num_workers=self.cfg.num_workers,
|
| 657 |
+
pin_memory=self.cfg.pin_memory,
|
| 658 |
+
shuffle=False,
|
| 659 |
+
collate_fn=collate,
|
| 660 |
+
)
|
protify/FastPLMs/boltz/src/boltz/data/mol.py
ADDED
|
@@ -0,0 +1,900 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import pickle
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from rdkit.Chem import Mol
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from boltz.data import const
|
| 12 |
+
from boltz.data.pad import pad_dim
|
| 13 |
+
from boltz.model.loss.confidence import lddt_dist
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_molecules(moldir: str, molecules: list[str]) -> dict[str, "Mol"]:
    """Load the requested pickled CCD molecules from a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.
    molecules : list[str]
        The molecules to load.

    Returns
    -------
    dict[str, Mol]
        The loaded molecules, keyed by component name.

    Raises
    ------
    ValueError
        If any requested component has no pickle file in ``moldir``.

    """
    base = Path(moldir)
    loaded_mols = {}
    for molecule in molecules:
        mol_path = base / f"{molecule}.pkl"
        if not mol_path.exists():
            raise ValueError(f"CCD component {molecule} not found!")
        with mol_path.open("rb") as handle:
            # Trusted, internally generated pickle files.
            loaded_mols[molecule] = pickle.load(handle)  # noqa: S301
    return loaded_mols
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_canonicals(moldir: str) -> dict[str, "Mol"]:
    """Load the canonical (standard residue) molecules from a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory.

    Returns
    -------
    dict[str, Mol]
        The loaded canonical molecules, keyed by component name.

    """
    canonical_names = const.canonical_tokens
    return load_molecules(moldir, canonical_names)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_all_molecules(moldir: str) -> dict[str, "Mol"]:
    """Load every pickled molecule found in a directory.

    Parameters
    ----------
    moldir : str
        The path to the molecules directory; every ``*.pkl`` file in it is
        loaded and keyed by its file stem (the component name).

    Returns
    -------
    dict[str, Mol]
        The loaded molecules.

    """
    loaded_mols = {}
    files = list(Path(moldir).glob("*.pkl"))
    for path in tqdm(files, total=len(files), desc="Loading molecules", leave=False):
        with path.open("rb") as f:
            # Trusted, internally generated pickle files.
            loaded_mols[path.stem] = pickle.load(f)  # noqa: S301
    return loaded_mols
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _unpickle_prop(mol: "Mol", prop: str):
    """Decode a hex-encoded, pickled property stored on an RDKit molecule."""
    # Properties are written by the CCD processing pipeline as hex-encoded
    # pickle blobs; this is trusted, internally generated data.
    return pickle.loads(bytes.fromhex(mol.GetProp(prop)))  # noqa: S301


def get_symmetries(mols: dict[str, "Mol"]) -> dict:  # noqa: PLR0912
    """Create a dictionary of symmetry/geometry metadata per ligand.

    Parameters
    ----------
    mols : dict[str, Mol]
        Mapping from CCD component name to its RDKit molecule. Each molecule
        must carry a hex-pickled "symmetries" property; the remaining
        geometry properties are optional and default to empty arrays.

    Returns
    -------
    dict
        Mapping from component name to the 16-tuple
        (symmetries, atom_names, edge_index, lower_bounds, upper_bounds,
        bond_mask, angle_mask, chiral_atom_index, chiral_check_mask,
        chiral_atom_orientations, stereo_bond_index, stereo_check_mask,
        stereo_bond_orientations, aromatic_5_ring_index,
        aromatic_6_ring_index, planar_double_bond_index).
        Molecules whose properties fail to decode are silently skipped.

    """
    symmetries = {}
    for key, mol in mols.items():
        try:
            sym = _unpickle_prop(mol, "symmetries")

            # Pairwise bound constraints (bond/angle geometry), if present.
            if mol.HasProp("pb_edge_index"):
                edge_index = _unpickle_prop(mol, "pb_edge_index").astype(np.int64)
                lower_bounds = _unpickle_prop(mol, "pb_lower_bounds")
                upper_bounds = _unpickle_prop(mol, "pb_upper_bounds")
                bond_mask = _unpickle_prop(mol, "pb_bond_mask")
                angle_mask = _unpickle_prop(mol, "pb_angle_mask")
            else:
                edge_index = np.empty((2, 0), dtype=np.int64)
                lower_bounds = np.array([], dtype=np.float32)
                upper_bounds = np.array([], dtype=np.float32)
                bond_mask = np.array([], dtype=np.float32)
                angle_mask = np.array([], dtype=np.float32)

            # Chiral-center bookkeeping, if present.
            if mol.HasProp("chiral_atom_index"):
                chiral_atom_index = _unpickle_prop(
                    mol, "chiral_atom_index"
                ).astype(np.int64)
                chiral_check_mask = _unpickle_prop(
                    mol, "chiral_check_mask"
                ).astype(np.int64)
                chiral_atom_orientations = _unpickle_prop(
                    mol, "chiral_atom_orientations"
                )
            else:
                chiral_atom_index = np.empty((4, 0), dtype=np.int64)
                chiral_check_mask = np.array([], dtype=bool)
                chiral_atom_orientations = np.array([], dtype=bool)

            # Stereo double-bond bookkeeping, if present.
            if mol.HasProp("stereo_bond_index"):
                stereo_bond_index = _unpickle_prop(
                    mol, "stereo_bond_index"
                ).astype(np.int64)
                stereo_check_mask = _unpickle_prop(
                    mol, "stereo_check_mask"
                ).astype(np.int64)
                stereo_bond_orientations = _unpickle_prop(
                    mol, "stereo_bond_orientations"
                )
            else:
                stereo_bond_index = np.empty((4, 0), dtype=np.int64)
                stereo_check_mask = np.array([], dtype=bool)
                stereo_bond_orientations = np.array([], dtype=bool)

            # Planarity constraints for aromatic rings and double bonds.
            if mol.HasProp("aromatic_5_ring_index"):
                aromatic_5_ring_index = _unpickle_prop(
                    mol, "aromatic_5_ring_index"
                ).astype(np.int64)
            else:
                aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
            if mol.HasProp("aromatic_6_ring_index"):
                aromatic_6_ring_index = _unpickle_prop(
                    mol, "aromatic_6_ring_index"
                ).astype(np.int64)
            else:
                aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
            if mol.HasProp("planar_double_bond_index"):
                planar_double_bond_index = _unpickle_prop(
                    mol, "planar_double_bond_index"
                ).astype(np.int64)
            else:
                planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

            atom_names = [atom.GetProp("name") for atom in mol.GetAtoms()]
            symmetries[key] = (
                sym,
                atom_names,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            )
        except Exception:  # noqa: BLE001, PERF203, S110
            # Deliberate best-effort: a molecule with missing or malformed
            # metadata is skipped rather than aborting the whole load.
            pass

    return symmetries
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def compute_symmetry_idx_dictionary(data):
    """Annotate chains/tokens with atom start offsets and gather all coords.

    Side effects: sets ``chain.start_idx`` (global atom offset of the chain)
    and ``token.start_idx`` (atom offset of the token within its chain).

    Returns
    -------
    list[list[float]]
        The [x, y, z] coordinates of every atom, in traversal order.
    """
    running_total = 0
    all_coords = []
    for chain in data.chains:
        chain.start_idx = running_total
        for token in chain.tokens:
            token.start_idx = running_total - chain.start_idx
            for atom in token.atoms:
                all_coords.append([atom.coords.x, atom.coords.y, atom.coords.z])
            running_total += len(token.atoms)
    return all_coords
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_current_idx_list(data):
    """Return the global atom indices of every in-crop token in in-crop chains."""
    selected = []
    for chain in data.chains:
        if not chain.in_crop:
            continue
        for token in chain.tokens:
            if not token.in_crop:
                continue
            base = chain.start_idx + token.start_idx
            selected.extend(range(base, base + len(token.atoms)))
    return selected
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def all_different_after_swap(l):
    """Return True if the last element of every nested sequence is distinct.

    Used to verify that a candidate chain-swap assignment maps each chain
    to a unique target.
    """
    seen = set()
    for swap in l:
        tail = swap[-1]
        if tail in seen:
            return False
        seen.add(tail)
    return True
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def minimum_lddt_symmetry_coords(
    coords: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve chain- and atom-level symmetries to best match a prediction.

    Greedily picks, for one batch element, the symmetry-equivalent ground
    truth (chain swaps first, then per-residue/ligand atom swaps) that
    maximizes lDDT against the predicted coordinates.

    Parameters
    ----------
    coords : torch.Tensor
        Predicted coordinates; assumed (1, num_atoms_padded, 3) — TODO confirm
        against the caller.
    feats : dict
        Feature dictionary holding the symmetry metadata produced by
        `get_chain_symmetries` / `get_amino_acids_symmetries` /
        `get_ligand_symmetries`.
    index_batch : int
        Which batch element to resolve.

    Returns
    -------
    tuple
        The symmetry-resolved true coordinates (padded to `coords`' atom
        dimension) and the corresponding resolved mask with a leading
        batch dimension.

    """
    all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
    all_resolved_mask = (
        feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
    )
    crop_to_all_atom_map = (
        feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
    )
    chain_symmetries = feats["chain_swaps"][index_batch]
    amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
    ligand_symmetries = feats["ligand_symmetries"][index_batch]

    # Pairwise distances of the predicted structure, restricted to the crop.
    dmat_predicted = torch.cdist(
        coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
    )

    # Check best symmetry on chain swap
    best_true_coords = all_coords[:, crop_to_all_atom_map].clone()
    best_true_resolved_mask = all_resolved_mask[crop_to_all_atom_map].clone()
    best_lddt = -1.0
    for c in chain_symmetries:
        # Apply one full chain-swap assignment to the ground truth.
        true_all_coords = all_coords.clone()
        true_all_resolved_mask = all_resolved_mask.clone()
        for start1, end1, start2, end2, chainidx1, chainidx2 in c:
            true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
            true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
        true_coords = true_all_coords[:, crop_to_all_atom_map]
        true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
        dmat_true = torch.cdist(true_coords, true_coords)
        # Valid pairs: both atoms resolved, excluding the diagonal.
        pair_mask = (
            true_resolved_mask[:, None]
            * true_resolved_mask[None, :]
            * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
        )

        lddt = lddt_dist(
            dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
        )[0]
        lddt = lddt.item()

        # Require > 3 resolved atoms so the lDDT is meaningful.
        if lddt > best_lddt and torch.sum(true_resolved_mask) > 3:
            best_lddt = lddt
            best_true_coords = true_coords
            best_true_resolved_mask = true_resolved_mask

    # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
    true_coords = best_true_coords.clone()
    true_resolved_mask = best_true_resolved_mask.clone()
    for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
        best_lddt_improvement = 0.0

        # Collect the atom indices touched by any swap of this residue/ligand
        # so the lDDT comparison can be restricted to the affected columns.
        indices = set()
        for c in symmetric_amino_or_lig:
            for i, j in c:
                indices.add(i)
        indices = sorted(list(indices))
        indices = torch.from_numpy(np.asarray(indices)).to(true_coords.device).long()
        pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
        sub_dmat_pred = torch.cdist(
            coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
        )

        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_true_coords = true_coords.clone()
            new_true_resolved_mask = true_resolved_mask.clone()
            for i, j in c:
                new_true_coords[:, i] = true_coords[:, j]
                new_true_resolved_mask[i] = true_resolved_mask[j]

            true_coords_subset = true_coords[:, indices]
            new_true_coords_subset = new_true_coords[:, indices]

            sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
            sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)

            # Pair masks for the (all atoms) x (subset) distance matrices,
            # with the self-pairs inside the subset zeroed out.
            sub_true_pair_lddt = (
                true_resolved_mask[:, None] * true_resolved_mask[None, indices]
            )
            sub_true_pair_lddt[indices] = (
                sub_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            sub_new_true_pair_lddt = (
                new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
            )
            sub_new_true_pair_lddt[indices] = (
                sub_new_true_pair_lddt[indices]
                * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
            )

            lddt, total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_true,
                sub_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )
            new_lddt, new_total = lddt_dist(
                sub_dmat_pred,
                sub_dmat_new_true,
                sub_new_true_pair_lddt,
                cutoff=15.0,
                per_atom=False,
            )

            lddt_improvement = new_lddt - lddt

            if lddt_improvement > best_lddt_improvement:
                best_true_coords = new_true_coords
                best_true_resolved_mask = new_true_resolved_mask
                best_lddt_improvement = lddt_improvement

        # greedily update best coordinates after each amino acid
        true_coords = best_true_coords.clone()
        true_resolved_mask = best_true_resolved_mask.clone()

    # Recomputing alignment: pad back to the full (padded) atom dimension.
    true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
    true_resolved_mask = pad_dim(
        true_resolved_mask,
        0,
        coords.shape[1] - true_resolved_mask.shape[0],
    )

    return true_coords, true_resolved_mask.unsqueeze(0)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def compute_single_distogram_loss(pred, target, mask):
    """Masked cross-entropy between predicted distogram logits and targets.

    Parameters
    ----------
    pred : torch.Tensor
        Distogram logits; softmax is taken over the last (bin) dimension.
    target : torch.Tensor
        Target bin distribution, same shape as ``pred``.
    mask : torch.Tensor
        Pairwise validity mask over the two token dimensions.

    Returns
    -------
    torch.Tensor
        Scalar loss averaged over the batch.
    """
    log_probs = torch.nn.functional.log_softmax(pred, dim=-1)
    per_pair_error = -(target * log_probs).sum(dim=-1)
    # Normalize by the number of valid pairs; epsilon guards empty masks.
    norm = 1e-5 + torch.sum(mask, dim=(-1, -2))
    row_means = (per_pair_error * mask).sum(dim=-1) / norm[..., None]
    per_sample_loss = row_means.sum(dim=-1)
    return per_sample_loss.mean()
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def minimum_lddt_symmetry_dist(
    pred_distogram: torch.Tensor,
    feats: dict,
    index_batch: int,
):
    """Resolve ligand atom symmetries against a predicted distogram.

    For one batch element, greedily applies each ligand's atom swaps when
    doing so lowers the distogram cross-entropy, and writes the resolved
    `disto_target` and `coords` back into `feats` in place.

    Parameters
    ----------
    pred_distogram : torch.Tensor
        Predicted distogram logits for this batch element; assumed
        (tokens, tokens, bins) — TODO confirm against the caller.
    feats : dict
        Feature dictionary; mutated in place.
    index_batch : int
        Which batch element to resolve.

    """
    # Note: for now only ligand symmetries are resolved

    disto_target = feats["disto_target"][index_batch]
    mask = feats["token_disto_mask"][index_batch]
    mask = mask[None, :] * mask[:, None]
    # Exclude self-pairs from the loss comparison.
    mask = mask * (1 - torch.eye(mask.shape[1])).to(disto_target)

    coords = feats["coords"][index_batch]

    ligand_symmetries = feats["ligand_symmetries"][index_batch]
    # One-hot atom->token assignment collapsed to token indices.
    atom_to_token_map = feats["atom_to_token"][index_batch].argmax(dim=-1)

    # atom symmetries, resolved greedily without recomputing alignment
    for symmetric_amino_or_lig in ligand_symmetries:
        best_c, best_disto, best_loss_improvement = None, None, 0.0
        for c in symmetric_amino_or_lig:
            # starting from greedy best, try to swap the atoms
            new_disto_target = disto_target.clone()
            indices = []

            # fix the distogram by replacing first the columns then the rows
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[:, atom_to_token_map[i]] = disto_temp[
                    :, atom_to_token_map[j]
                ]
                indices.append(atom_to_token_map[i].item())
            disto_temp = new_disto_target.clone()
            for i, j in c:
                new_disto_target[atom_to_token_map[i], :] = disto_temp[
                    atom_to_token_map[j], :
                ]

            indices = (
                torch.from_numpy(np.asarray(indices)).to(disto_target.device).long()
            )

            # Compare losses only on the swapped token columns.
            pred_distogram_subset = pred_distogram[:, indices]
            disto_target_subset = disto_target[:, indices]
            new_disto_target_subset = new_disto_target[:, indices]
            mask_subset = mask[:, indices]

            loss = compute_single_distogram_loss(
                pred_distogram_subset, disto_target_subset, mask_subset
            )
            new_loss = compute_single_distogram_loss(
                pred_distogram_subset, new_disto_target_subset, mask_subset
            )
            # Scale by the number of affected columns so improvements on
            # larger swaps weigh proportionally more.
            loss_improvement = (loss - new_loss) * len(indices)

            if loss_improvement > best_loss_improvement:
                best_c = c
                best_disto = new_disto_target
                best_loss_improvement = loss_improvement

        # greedily update best coordinates after each ligand
        if best_loss_improvement > 0:
            disto_target = best_disto.clone()
            old_coords = coords.clone()
            for i, j in best_c:
                coords[:, i] = old_coords[:, j]

    # update features to be used in diffusion and in distogram loss
    feats["disto_target"][index_batch] = disto_target
    feats["coords"][index_batch] = coords
    return
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def compute_all_coords_mask(structure):
    """Collect per-atom coords, crop mask and resolved mask for a structure.

    Side effects: sets ``chain.start_idx`` (global atom offset of the chain)
    and ``token.start_idx`` (atom offset within its chain).

    Parameters
    ----------
    structure
        An object with ``chains``, each with ``tokens`` carrying ``atoms``,
        ``in_crop`` and ``is_present``.

    Returns
    -------
    tuple[list, list, list]
        Parallel per-atom lists: [x, y, z] coordinates, crop mask, and
        resolved mask. All three share the same length by construction
        (each token contributes len(token.atoms) entries to each).
    """
    total_count = 0
    all_coords = []
    all_coords_crop_mask = []
    all_resolved_mask = []
    for chain in structure.chains:
        chain.start_idx = total_count
        for token in chain.tokens:
            token.start_idx = total_count - chain.start_idx
            n_atoms = len(token.atoms)
            all_coords.extend(
                [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
            )
            all_coords_crop_mask.extend([token.in_crop] * n_atoms)
            all_resolved_mask.extend([token.is_present] * n_atoms)
            total_count += n_atoms
    return all_coords, all_coords_crop_mask, all_resolved_mask
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def get_chain_symmetries(cropped, max_n_symmetries=100):
    """Build chain-level symmetry features for a cropped structure.

    Enumerates chain-swap assignments between chains of the same entity and
    equal atom count, bounds the candidate set to ``max_n_symmetries``
    (sampled at random when exceeded — this function is therefore
    nondeterministic), and packages coordinates/masks needed downstream.

    Parameters
    ----------
    cropped
        A crop object with ``structure`` and ``tokens``; chains and tokens
        are assumed to be record-like (indexable by field name) — TODO
        confirm against the tokenizer.
    max_n_symmetries : int
        Maximum number of chain-swap combinations to keep.

    Returns
    -------
    dict
        Features: all_coords, all_resolved_mask, crop_to_all_atom_map,
        chain_symmetries, connections_edge_index, chain_swaps.

    """
    # get all coordinates and resolved mask
    structure = cropped.structure
    all_coords = []
    all_resolved_mask = []
    original_atom_idx = []
    chain_atom_idx = []
    chain_atom_num = []
    chain_in_crop = []
    chain_asym_id = []
    new_atom_idx = 0

    for chain in structure.chains:
        atom_idx, atom_num = (
            chain["atom_idx"],  # Global index of first atom in the chain
            chain["atom_num"],  # Number of atoms in the chain
        )

        # compute coordinates and resolved mask
        resolved_mask = structure.atoms["is_present"][
            atom_idx : atom_idx + atom_num
        ]  # Whether each atom in the chain is actually resolved

        # ensemble_atom_starts = [structure.ensemble[idx]["atom_coord_idx"] for idx in cropped.ensemble_ref_idxs]
        # coords = np.array(
        #     [structure.coords[ensemble_atom_start + atom_idx: ensemble_atom_start + atom_idx + atom_num]["coords"] for
        #     ensemble_atom_start in ensemble_atom_starts])

        coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]

        # A chain is "in crop" if any token of the crop belongs to it.
        in_crop = False
        for token in cropped.tokens:
            if token["asym_id"] == chain["asym_id"]:
                in_crop = True
                break

        all_coords.append(coords)
        all_resolved_mask.append(resolved_mask)
        original_atom_idx.append(atom_idx)
        chain_atom_idx.append(new_atom_idx)
        chain_atom_num.append(atom_num)
        chain_in_crop.append(in_crop)
        chain_asym_id.append(chain["asym_id"])

        new_atom_idx += atom_num

    all_coords = np.concatenate(all_coords, axis=0)
    # Compute backmapping from token to all coords
    crop_to_all_atom_map = []
    for token in cropped.tokens:
        chain_idx = chain_asym_id.index(token["asym_id"])
        # Translate the token's structure-global atom index into the
        # locally re-packed coordinate array.
        start = (
            chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
        )
        crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))
    crop_to_all_atom_map = np.concatenate(crop_to_all_atom_map, axis=0)

    # Compute the connections edge index for covalent bonds
    all_atom_to_crop_map = np.zeros(all_coords.shape[0], dtype=np.int64)
    all_atom_to_crop_map[crop_to_all_atom_map.astype(np.int64)] = np.arange(
        crop_to_all_atom_map.shape[0]
    )
    connections_edge_index = []
    for connection in structure.bonds:
        # Skip intra-residue bonds; only cross-residue/chain links are kept.
        if (connection["chain_1"] == connection["chain_2"]) and (
            connection["res_1"] == connection["res_2"]
        ):
            continue
        connections_edge_index.append([connection["atom_1"], connection["atom_2"]])
    if len(connections_edge_index) > 0:
        connections_edge_index = np.array(connections_edge_index, dtype=np.int64).T
        connections_edge_index = all_atom_to_crop_map[connections_edge_index]
    else:
        connections_edge_index = np.empty((2, 0))

    # Compute the symmetries between chains
    symmetries = []
    swaps = []
    for i, chain in enumerate(structure.chains):
        start = chain_atom_idx[i]
        end = start + chain_atom_num[i]

        if chain_in_crop[i]:
            # All chains of the same entity with the same atom count are
            # candidate swap targets for this in-crop chain.
            possible_swaps = []
            for j, chain2 in enumerate(structure.chains):
                start2 = chain_atom_idx[j]
                end2 = start2 + chain_atom_num[j]
                if (
                    chain["entity_id"] == chain2["entity_id"]
                    and end - start == end2 - start2
                ):
                    possible_swaps.append((start, end, start2, end2, i, j))
            swaps.append(possible_swaps)

        # Group the chain into an existing symmetry class, or open a new one.
        found = False
        for symmetry_idx, symmetry in enumerate(symmetries):
            j = symmetry[0][0]
            chain2 = structure.chains[j]
            start2 = chain_atom_idx[j]
            end2 = start2 + chain_atom_num[j]
            if (
                chain["entity_id"] == chain2["entity_id"]
                and end - start == end2 - start2
            ):
                symmetries[symmetry_idx].append(
                    (i, start, end, chain_in_crop[i], chain["mol_type"])
                )
                found = True
        if not found:
            symmetries.append([(i, start, end, chain_in_crop[i], chain["mol_type"])])

    combinations = itertools.product(*swaps)
    # to avoid combinatorial explosion, bound the number of combinations even considered
    combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
    # filter for all chains getting a different assignment
    combinations = [c for c in combinations if all_different_after_swap(c)]

    if len(combinations) > max_n_symmetries:
        combinations = random.sample(combinations, max_n_symmetries)

    if len(combinations) == 0:
        combinations.append([])

    # Drop symmetry classes with no in-crop member (index 3 is in_crop).
    for i in range(len(symmetries) - 1, -1, -1):
        if not any(chain[3] for chain in symmetries[i]):
            symmetries.pop(i)

    features = {}
    features["all_coords"] = torch.Tensor(all_coords)  # axis=1 with ensemble

    features["all_resolved_mask"] = torch.Tensor(
        np.concatenate(all_resolved_mask, axis=0)
    )
    features["crop_to_all_atom_map"] = torch.Tensor(crop_to_all_atom_map)
    features["chain_symmetries"] = symmetries
    features["connections_edge_index"] = torch.tensor(connections_edge_index)
    features["chain_swaps"] = combinations

    return features
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def get_amino_acids_symmetries(cropped):
    """Collect canonical residue-level atom swap symmetries for the crop.

    Returns a features dict whose "amino_acids_symmetries" entry holds, per
    token with known reference symmetries, a list of swap lists expressed in
    crop-global atom indices.
    """
    swaps = []
    offset = 0
    for token in cropped.tokens:
        res_name = const.tokens[token["res_type"]]
        symmetries = const.ref_symmetries.get(res_name, [])
        if symmetries:
            # Shift the reference (residue-local) indices into crop space.
            residue_swaps = [
                [(i + offset, j + offset) for i, j in sym] for sym in symmetries
            ]
            swaps.append(residue_swaps)
        offset += token["atom_num"]

    return {"amino_acids_symmetries": swaps}
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def slice_valid_index(index, ccd_to_valid_id_array, args=None):
    """Remap CCD atom indices to valid-atom indices, dropping invalid columns.

    Parameters
    ----------
    index : np.ndarray
        Index array of shape (k, m) holding CCD atom indices, one column
        per entry.
    ccd_to_valid_id_array : np.ndarray
        Float lookup array mapping each CCD atom index to its position among
        the atoms actually present, with NaN for absent atoms.
    args : tuple of np.ndarray, optional
        Companion per-column arrays to filter with the same validity mask.

    Returns
    -------
    np.ndarray or tuple[np.ndarray, tuple[np.ndarray, ...]]
        The remapped index restricted to columns where every stacked entry
        resolved to a present atom; when ``args`` is given, the filtered
        companion arrays are returned alongside it.
    """
    # Translate CCD indices; entries for absent atoms become NaN.
    index = ccd_to_valid_id_array[index]
    # Keep only columns where every stacked index is valid (non-NaN).
    valid_index_mask = (~np.isnan(index)).all(axis=0)
    index = index[:, valid_index_mask]
    if args is None:
        return index
    # Materialize as a tuple: a generator could only be consumed once,
    # which silently breaks callers that iterate the result twice.
    args = tuple(arg[valid_index_mask] for arg in args)
    return index, args
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def get_ligand_symmetries(cropped, symmetries, return_physical_metrics=False):
    """Compute symmetry swaps for ligands and non-standard residues in a crop.

    Parameters
    ----------
    cropped
        Cropped input exposing ``tokens`` and the underlying ``structure``.
    symmetries : dict
        Mapping from CCD molecule name to its precomputed tuple of
        (permutations, atom names, edge/bounds arrays, chirality, stereo
        and planarity data).
    return_physical_metrics : bool, optional
        When True, also assemble the geometric-constraint tensors
        (distance bounds, chirality, stereo bonds, aromatic rings).

    Returns
    -------
    dict
        Features containing ``ligand_symmetries`` and, optionally, the
        physical-metric tensors.
    """
    # Compute ligand and non-standard amino-acids symmetries
    structure = cropped.structure

    # Per-molecule atom counts, keyed by (asym_id, res_idx).
    added_molecules = {}
    # One entry per distinct molecule: (name, start atom offset, id, atom names).
    index_mols = []
    atom_count = 0

    for token in cropped.tokens:
        # check if molecule is already added by identifying it through asym_id and res_idx
        atom_count += token["atom_num"]
        mol_id = (token["asym_id"], token["res_idx"])
        if mol_id in added_molecules:
            added_molecules[mol_id] += token["atom_num"]
            continue
        added_molecules[mol_id] = token["atom_num"]

        # get the molecule type and indices
        residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
        mol_name = structure.residues[residue_idx]["name"]
        atom_idx = structure.residues[residue_idx]["atom_idx"]
        mol_atom_names = structure.atoms[
            atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
        ]["name"]
        # Standard residues are handled by get_amino_acids_symmetries; only
        # molecules absent from const.ref_symmetries are processed here.
        if mol_name not in const.ref_symmetries:
            index_mols.append(
                (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
            )

    # for each molecule, get the symmetries
    molecule_symmetries = []
    all_edge_index = []
    all_lower_bounds, all_upper_bounds = [], []
    all_bond_mask, all_angle_mask = [], []
    all_chiral_atom_index, all_chiral_check_mask, all_chiral_atom_orientations = (
        [],
        [],
        [],
    )
    all_stereo_bond_index, all_stereo_check_mask, all_stereo_bond_orientations = (
        [],
        [],
        [],
    )
    (
        all_aromatic_5_ring_index,
        all_aromatic_6_ring_index,
        all_planar_double_bond_index,
    ) = (
        [],
        [],
        [],
    )
    for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
        if not mol_name in symmetries:
            continue
        else:
            swaps = []
            (
                syms_ccd,
                mol_atom_names_ccd,
                edge_index,
                lower_bounds,
                upper_bounds,
                bond_mask,
                angle_mask,
                chiral_atom_index,
                chiral_check_mask,
                chiral_atom_orientations,
                stereo_bond_index,
                stereo_check_mask,
                stereo_bond_orientations,
                aromatic_5_ring_index,
                aromatic_6_ring_index,
                planar_double_bond_index,
            ) = symmetries[mol_name]
            # Get indices of mol_atom_names_ccd that are in mol_atom_names
            ccd_to_valid_ids = {
                mol_atom_names_ccd.index(name): i
                for i, name in enumerate(mol_atom_names)
            }
            # Float lookup: CCD index -> position among present atoms, NaN if absent.
            ccd_to_valid_id_array = np.array(
                [
                    float("nan") if i not in ccd_to_valid_ids else ccd_to_valid_ids[i]
                    for i in range(len(mol_atom_names_ccd))
                ]
            )
            ccd_valid_ids = set(ccd_to_valid_ids.keys())
            syms = []
            # Get syms: keep only permutations that map present atoms to
            # present atoms (a permutation touching an absent atom is dropped).
            for sym_ccd in syms_ccd:
                sym_dict = {}
                bool_add = True
                for i, j in enumerate(sym_ccd):
                    if i in ccd_valid_ids:
                        if j in ccd_valid_ids:
                            i_true = ccd_to_valid_ids[i]
                            j_true = ccd_to_valid_ids[j]
                            sym_dict[i_true] = j_true
                        else:
                            bool_add = False
                            break
                if bool_add:
                    syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])
            for sym in syms:
                # Sanity check: permutation length must match the molecule's atom count.
                if len(sym) != added_molecules[mol_id]:
                    raise Exception(
                        f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                    )
                # assert (
                #     len(sym) == added_molecules[mol_id]
                # ), f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
                # Record only the moved atoms, offset to crop-global indices.
                sym_new_idx = []
                for i, j in enumerate(sym):
                    if i != int(j):
                        sym_new_idx.append((i + start_mol, int(j) + start_mol))
                if len(sym_new_idx) > 0:
                    swaps.append(sym_new_idx)

            if len(swaps) > 0:
                molecule_symmetries.append(swaps)

            if return_physical_metrics:
                # Filter every constraint array down to present atoms and
                # shift its indices to crop-global positions.
                edge_index, (lower_bounds, upper_bounds, bond_mask, angle_mask) = (
                    slice_valid_index(
                        edge_index,
                        ccd_to_valid_id_array,
                        (lower_bounds, upper_bounds, bond_mask, angle_mask),
                    )
                )
                all_edge_index.append(edge_index + start_mol)
                all_lower_bounds.append(lower_bounds)
                all_upper_bounds.append(upper_bounds)
                all_bond_mask.append(bond_mask)
                all_angle_mask.append(angle_mask)

                chiral_atom_index, (chiral_check_mask, chiral_atom_orientations) = (
                    slice_valid_index(
                        chiral_atom_index,
                        ccd_to_valid_id_array,
                        (chiral_check_mask, chiral_atom_orientations),
                    )
                )
                all_chiral_atom_index.append(chiral_atom_index + start_mol)
                all_chiral_check_mask.append(chiral_check_mask)
                all_chiral_atom_orientations.append(chiral_atom_orientations)

                stereo_bond_index, (stereo_check_mask, stereo_bond_orientations) = (
                    slice_valid_index(
                        stereo_bond_index,
                        ccd_to_valid_id_array,
                        (stereo_check_mask, stereo_bond_orientations),
                    )
                )
                all_stereo_bond_index.append(stereo_bond_index + start_mol)
                all_stereo_check_mask.append(stereo_check_mask)
                all_stereo_bond_orientations.append(stereo_bond_orientations)

                aromatic_5_ring_index = slice_valid_index(
                    aromatic_5_ring_index, ccd_to_valid_id_array
                )
                aromatic_6_ring_index = slice_valid_index(
                    aromatic_6_ring_index, ccd_to_valid_id_array
                )
                planar_double_bond_index = slice_valid_index(
                    planar_double_bond_index, ccd_to_valid_id_array
                )
                all_aromatic_5_ring_index.append(aromatic_5_ring_index + start_mol)
                all_aromatic_6_ring_index.append(aromatic_6_ring_index + start_mol)
                all_planar_double_bond_index.append(
                    planar_double_bond_index + start_mol
                )

    if return_physical_metrics:
        if len(all_edge_index) > 0:
            # Index arrays concatenate along columns (axis=1); per-column
            # value/mask arrays concatenate along axis=0.
            all_edge_index = np.concatenate(all_edge_index, axis=1)
            all_lower_bounds = np.concatenate(all_lower_bounds, axis=0)
            all_upper_bounds = np.concatenate(all_upper_bounds, axis=0)
            all_bond_mask = np.concatenate(all_bond_mask, axis=0)
            all_angle_mask = np.concatenate(all_angle_mask, axis=0)

            all_chiral_atom_index = np.concatenate(all_chiral_atom_index, axis=1)
            all_chiral_check_mask = np.concatenate(all_chiral_check_mask, axis=0)
            all_chiral_atom_orientations = np.concatenate(
                all_chiral_atom_orientations, axis=0
            )

            all_stereo_bond_index = np.concatenate(all_stereo_bond_index, axis=1)
            all_stereo_check_mask = np.concatenate(all_stereo_check_mask, axis=0)
            all_stereo_bond_orientations = np.concatenate(
                all_stereo_bond_orientations, axis=0
            )

            all_aromatic_5_ring_index = np.concatenate(
                all_aromatic_5_ring_index, axis=1
            )
            all_aromatic_6_ring_index = np.concatenate(
                all_aromatic_6_ring_index, axis=1
            )
            # NOTE(review): planar double bonds are deliberately emitted empty
            # here despite being accumulated above — confirm before "removing".
            all_planar_double_bond_index = np.empty(
                (6, 0), dtype=np.int64
            )  # TODO remove np.concatenate(all_planar_double_bond_index, axis=1)
        else:
            # No eligible molecules: emit correctly-shaped empty tensors.
            all_edge_index = np.empty((2, 0), dtype=np.int64)
            all_lower_bounds = np.array([], dtype=np.float32)
            all_upper_bounds = np.array([], dtype=np.float32)
            all_bond_mask = np.array([], dtype=bool)
            all_angle_mask = np.array([], dtype=bool)

            all_chiral_atom_index = np.empty((4, 0), dtype=np.int64)
            all_chiral_check_mask = np.array([], dtype=bool)
            all_chiral_atom_orientations = np.array([], dtype=bool)

            all_stereo_bond_index = np.empty((4, 0), dtype=np.int64)
            all_stereo_check_mask = np.array([], dtype=bool)
            all_stereo_bond_orientations = np.array([], dtype=bool)

            all_aromatic_5_ring_index = np.empty((5, 0), dtype=np.int64)
            all_aromatic_6_ring_index = np.empty((6, 0), dtype=np.int64)
            all_planar_double_bond_index = np.empty((6, 0), dtype=np.int64)

        features = {
            "ligand_symmetries": molecule_symmetries,
            "ligand_edge_index": torch.tensor(all_edge_index).long(),
            "ligand_edge_lower_bounds": torch.tensor(all_lower_bounds),
            "ligand_edge_upper_bounds": torch.tensor(all_upper_bounds),
            "ligand_edge_bond_mask": torch.tensor(all_bond_mask),
            "ligand_edge_angle_mask": torch.tensor(all_angle_mask),
            "ligand_chiral_atom_index": torch.tensor(all_chiral_atom_index).long(),
            "ligand_chiral_check_mask": torch.tensor(all_chiral_check_mask),
            "ligand_chiral_atom_orientations": torch.tensor(
                all_chiral_atom_orientations
            ),
            "ligand_stereo_bond_index": torch.tensor(all_stereo_bond_index).long(),
            "ligand_stereo_check_mask": torch.tensor(all_stereo_check_mask),
            "ligand_stereo_bond_orientations": torch.tensor(
                all_stereo_bond_orientations
            ),
            "ligand_aromatic_5_ring_index": torch.tensor(
                all_aromatic_5_ring_index
            ).long(),
            "ligand_aromatic_6_ring_index": torch.tensor(
                all_aromatic_6_ring_index
            ).long(),
            "ligand_planar_double_bond_index": torch.tensor(
                all_planar_double_bond_index
            ).long(),
        }
    else:
        features = {
            "ligand_symmetries": molecule_symmetries,
        }
    return features
|
protify/FastPLMs/boltz/src/boltz/data/msa/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/msa/mmseqs2.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# From https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import tarfile
|
| 7 |
+
import time
|
| 8 |
+
from typing import Optional, Union, Dict
|
| 9 |
+
|
| 10 |
+
import requests
|
| 11 |
+
from requests.auth import HTTPBasicAuth
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
TQDM_BAR_FORMAT = (
|
| 17 |
+
"{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_mmseqs2(  # noqa: PLR0912, D103, C901, PLR0915
    x: Union[str, list[str]],
    prefix: str = "tmp",
    use_env: bool = True,
    use_filter: bool = True,
    use_pairing: bool = False,
    pairing_strategy: str = "greedy",
    host_url: str = "https://api.colabfold.com",
    msa_server_username: Optional[str] = None,
    msa_server_password: Optional[str] = None,
    auth_headers: Optional[Dict[str, str]] = None,
) -> tuple[list[str], list[str]]:
    """Query the remote MMseqs2 (ColabFold) API for MSAs of the input sequences.

    Submits the deduplicated sequences, polls until the job completes,
    downloads and extracts the result archive, then returns one A3M string
    per input sequence (in input order).

    Parameters
    ----------
    x : str or list[str]
        One protein sequence or a list of sequences.
    prefix : str
        Prefix for the local working directory (results are cached there).
    use_env : bool
        Include the environmental databases.
    use_filter : bool
        Use the server-side filtered search mode.
    use_pairing : bool
        Request paired MSAs instead of per-sequence MSAs.
    pairing_strategy : str
        "greedy" (default) or "complete" pairing mode.
    host_url : str
        Base URL of the MMseqs2 API server.
    msa_server_username, msa_server_password : str, optional
        Basic-auth credentials (mutually exclusive with ``auth_headers``).
    auth_headers : dict, optional
        Extra HTTP headers for header/API-key authentication.

    """
    # NOTE(review): the function returns a single list of A3M strings; the
    # tuple[list, list] return annotation looks stale — confirm with callers.
    submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

    # Validate mutually exclusive authentication methods
    has_basic_auth = msa_server_username and msa_server_password
    has_header_auth = auth_headers is not None
    if has_basic_auth and (has_header_auth or auth_headers):
        raise ValueError(
            "Cannot use both basic authentication (username/password) and header/API key authentication. "
            "Please use only one authentication method."
        )

    # Set header agent as boltz
    headers = {}
    headers["User-Agent"] = "boltz"

    # Set up authentication
    auth = None
    if has_basic_auth:
        auth = HTTPBasicAuth(msa_server_username, msa_server_password)
        logger.debug(f"MMSeqs2 server authentication: using basic auth for user '{msa_server_username}'")
    elif has_header_auth:
        headers.update(auth_headers)
        logger.debug("MMSeqs2 server authentication: using header-based authentication")
    else:
        logger.debug("MMSeqs2 server authentication: no credentials provided")

    logger.debug(f"Connecting to MMSeqs2 server at: {host_url}")
    logger.debug(f"Using endpoint: {submission_endpoint}")
    logger.debug(f"Pairing strategy: {pairing_strategy}")
    logger.debug(f"Use environment databases: {use_env}")
    logger.debug(f"Use filtering: {use_filter}")

    def submit(seqs, mode, N=101):
        # Submit a FASTA-formatted batch; sequences are numbered from N so
        # result headers can be mapped back to inputs. Retries on network error.
        n, query = N, ""
        for seq in seqs:
            query += f">{n}\n{seq}\n"
            n += 1

        error_count = 0
        while True:
            try:
                # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
                # "good practice to set connect timeouts to slightly larger than a multiple of 3"
                logger.debug(f"Submitting MSA request to {host_url}/{submission_endpoint}")
                res = requests.post(
                    f"{host_url}/{submission_endpoint}",
                    data={"q": query, "mode": mode},
                    timeout=6.02,
                    headers=headers,
                    auth=auth,
                )
                logger.debug(f"MSA submission response status: {res.status_code}")
            except Exception as e:
                error_count += 1
                logger.warning(
                    f"Error while fetching result from MSA server. Retrying... ({error_count}/5)"
                )
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception(
                        "Too many failed attempts for the MSA generation request."
                    )
                time.sleep(5)
            else:
                break

        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def status(ID):
        # Poll the job's status ticket; retries on network error.
        error_count = 0
        while True:
            try:
                logger.debug(f"Checking MSA job status for ID: {ID}")
                res = requests.get(
                    f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth
                )
                logger.debug(f"MSA status check response status: {res.status_code}")
            except Exception as e:
                error_count += 1
                logger.warning(
                    f"Error while fetching result from MSA server. Retrying... ({error_count}/5)"
                )
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception(
                        "Too many failed attempts for the MSA generation request."
                    )
                time.sleep(5)
            else:
                break
        try:
            out = res.json()
        except ValueError:
            logger.error(f"Server didn't reply with json: {res.text}")
            out = {"status": "ERROR"}
        return out

    def download(ID, path):
        # Fetch the finished job's result archive to `path`; retries on error.
        error_count = 0
        while True:
            try:
                logger.debug(f"Downloading MSA results for ID: {ID}")
                res = requests.get(
                    f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers, auth=auth
                )
                logger.debug(f"MSA download response status: {res.status_code}")
            except Exception as e:
                error_count += 1
                logger.warning(
                    f"Error while fetching result from MSA server. Retrying... ({error_count}/5)"
                )
                logger.warning(f"Error: {e}")
                if error_count > 5:
                    raise Exception(
                        "Too many failed attempts for the MSA generation request."
                    )
                time.sleep(5)
            else:
                break
        with open(path, "wb") as out:
            out.write(res.content)

    # process input x
    seqs = [x] if isinstance(x, str) else x

    # setup mode
    if use_filter:
        mode = "env" if use_env else "all"
    else:
        mode = "env-nofilter" if use_env else "nofilter"

    if use_pairing:
        mode = ""
        # greedy is default, complete was the previous behavior
        if pairing_strategy == "greedy":
            mode = "pairgreedy"
        elif pairing_strategy == "complete":
            mode = "paircomplete"
        if use_env:
            mode = mode + "-env"

    # define path
    path = f"{prefix}_{mode}"
    if not os.path.isdir(path):
        os.mkdir(path)

    # call mmseqs2 api
    tar_gz_file = f"{path}/out.tar.gz"
    N, REDO = 101, True

    # deduplicate and keep track of order
    seqs_unique = []
    # TODO this might be slow for large sets
    [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
    # Ms maps each input sequence to its numbered header in the results.
    Ms = [N + seqs_unique.index(seq) for seq in seqs]
    # lets do it!
    # A pre-existing archive acts as a cache: skip submission entirely.
    if not os.path.isfile(tar_gz_file):
        TIME_ESTIMATE = 150 * len(seqs_unique)
        with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
            while REDO:
                pbar.set_description("SUBMIT")

                # Resubmit job until it goes through
                out = submit(seqs_unique, mode, N)
                while out["status"] in ["UNKNOWN", "RATELIMIT"]:
                    sleep_time = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
                    # resubmit
                    time.sleep(sleep_time)
                    out = submit(seqs_unique, mode, N)

                if out["status"] == "ERROR":
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        " input is a valid protein sequence. If error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

                if out["status"] == "MAINTENANCE":
                    msg = (
                        "MMseqs2 API is undergoing maintenance. "
                        "Please try again in a few minutes."
                    )
                    raise Exception(msg)

                # wait for job to finish
                ID, TIME = out["id"], 0
                logger.debug(f"MSA job submitted successfully with ID: {ID}")
                pbar.set_description(out["status"])
                while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
                    t = 5 + random.randint(0, 5)
                    logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
                    time.sleep(t)
                    out = status(ID)
                    pbar.set_description(out["status"])
                    if out["status"] == "RUNNING":
                        TIME += t
                        pbar.update(n=t)

                if out["status"] == "COMPLETE":
                    logger.debug(f"MSA job completed successfully for ID: {ID}")
                    if TIME < TIME_ESTIMATE:
                        pbar.update(n=(TIME_ESTIMATE - TIME))
                    REDO = False

                if out["status"] == "ERROR":
                    REDO = False
                    msg = (
                        "MMseqs2 API is giving errors. Please confirm your "
                        " input is a valid protein sequence. If error persists, "
                        "please try again an hour later."
                    )
                    raise Exception(msg)

            # Download results
            download(ID, tar_gz_file)

    # prep list of a3m files
    if use_pairing:
        a3m_files = [f"{path}/pair.a3m"]
    else:
        a3m_files = [f"{path}/uniref.a3m"]
        if use_env:
            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

    # extract a3m files
    if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
        with tarfile.open(tar_gz_file) as tar_gz:
            tar_gz.extractall(path)

    # gather a3m lines
    # Results from multiple databases are concatenated per query number M.
    # NUL bytes separate database sections inside a single file.
    a3m_lines = {}
    for a3m_file in a3m_files:
        update_M, M = True, None
        for line in open(a3m_file, "r"):
            if len(line) > 0:
                if "\x00" in line:
                    line = line.replace("\x00", "")
                    update_M = True
                if line.startswith(">") and update_M:
                    M = int(line[1:].rstrip())
                    update_M = False
                if M not in a3m_lines:
                    a3m_lines[M] = []
                a3m_lines[M].append(line)

    a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
    return a3m_lines
|
protify/FastPLMs/boltz/src/boltz/data/pad.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
from torch.nn.functional import pad
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
    """Pad a tensor at the end of one dimension.

    Parameters
    ----------
    data : Tensor
        The tensor to pad.
    dim : int
        Index of the dimension to extend.
    pad_len : float
        Number of entries to append along ``dim``.
    value : float, optional
        Fill value for the appended entries.

    Returns
    -------
    Tensor
        The padded tensor (the input itself when ``pad_len`` is 0).

    """
    if pad_len == 0:
        return data

    # F.pad takes (before, after) pairs ordered from the LAST dimension
    # backwards, so the "after" amount for `dim` sits at the end of the spec.
    pad_spec = [0, 0] * (data.dim() - dim)
    pad_spec[-1] = pad_len
    return pad(data, tuple(pad_spec), value=value)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
    """Stack tensors after padding each to the per-dimension maximum size.

    Parameters
    ----------
    data : list[Tensor]
        Tensors to pad and stack (all with the same number of dimensions).
    value : float
        Fill value for padded positions.

    Returns
    -------
    Tensor
        The stacked, padded tensor.
    Tensor
        The padding mask (1 where data is real, 0 where padded); the
        integer 0 when no padding was needed.

    """
    # Strings pass through untouched.
    if isinstance(data[0], str):
        return data, 0

    # Fast path: identical shapes need no padding at all.
    first_shape = data[0].shape
    if all(t.shape == first_shape for t in data):
        return torch.stack(data, dim=0), 0

    # Target size per dimension is the maximum across all tensors.
    ndim = len(first_shape)
    target = [max(t.shape[axis] for t in data) for axis in range(ndim)]

    padded, masks = [], []
    for t in data:
        # Build the (before, after) spec, ordered from the last dim backwards.
        spec = []
        for axis in reversed(range(ndim)):
            spec.extend((0, target[axis] - t.shape[axis]))
        spec = tuple(spec)
        masks.append(pad(torch.ones_like(t), spec, value=0))
        padded.append(pad(t, spec, value=value))

    return torch.stack(padded, dim=0), torch.stack(masks, dim=0)
|
protify/FastPLMs/boltz/src/boltz/data/parse/__init__.py
ADDED
|
File without changes
|
protify/FastPLMs/boltz/src/boltz/data/parse/a3m.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gzip
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional, TextIO
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from boltz.data import const
|
| 8 |
+
from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _parse_a3m(  # noqa: C901
    lines: TextIO,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Process an MSA file.

    Builds flat arrays of residue tokens and (position, length) deletion
    records, with one (start, end) slice pair per aligned sequence.

    Parameters
    ----------
    lines : TextIO
        The lines of the MSA file.
    taxonomy : dict[str, str]
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    visited = set()      # ungapped upper-case sequences already seen (dedup)
    sequences = []       # per-sequence (idx, tax, res_start, res_end, del_start, del_end)
    deletions = []       # flat (res_idx, deletion_count) records across all sequences
    residues = []        # flat residue token ids across all sequences

    seq_idx = 0
    for line in lines:
        line: str
        line = line.strip()  # noqa: PLW2901
        if not line or line.startswith("#"):
            continue

        # Get taxonomy, if annotated. taxonomy_id set here is consumed by the
        # sequence line that follows the header.
        # NOTE(review): if a sequence line ever precedes the first ">" header,
        # taxonomy_id is referenced before assignment — confirm inputs always
        # start with a header.
        if line.startswith(">"):
            header = line.split()[0]
            if taxonomy and header.startswith(">UniRef100"):
                uniref_id = header.split("_")[1]
                taxonomy_id = taxonomy.get(uniref_id)
                if taxonomy_id is None:
                    taxonomy_id = -1
            else:
                taxonomy_id = -1
            continue

        # Skip if duplicate sequence
        str_seq = line.replace("-", "").upper()
        if str_seq not in visited:
            visited.add(str_seq)
        else:
            continue

        # Process sequence. In A3M, lowercase letters are insertions relative
        # to the query: they are not emitted as residues but counted as a
        # deletion run attached to the next aligned position.
        residue = []
        deletion = []
        count = 0
        res_idx = 0
        for c in line:
            if c != "-" and c.islower():
                count += 1
                continue
            token = const.prot_letter_to_token[c]
            token = const.token_ids[token]
            residue.append(token)
            if count > 0:
                deletion.append((res_idx, count))
                count = 0
            res_idx += 1

        # Record this sequence's slices into the flat arrays.
        res_start = len(residues)
        res_end = res_start + len(residue)

        del_start = len(deletions)
        del_end = del_start + len(deletion)

        sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
        residues.extend(residue)
        deletions.extend(deletion)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    msa = MSA(
        residues=np.array(residues, dtype=MSAResidue),
        deletions=np.array(deletions, dtype=MSADeletion),
        sequences=np.array(sequences, dtype=MSASequence),
    )
    return msa
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def parse_a3m(
    path: Path,
    taxonomy: Optional[dict[str, str]],
    max_seqs: Optional[int] = None,
) -> MSA:
    """Parse an A3M file, transparently handling gzip compression.

    Parameters
    ----------
    path : Path
        The path to the a3m(.gz) file.
    taxonomy : dict[str, str], optional
        The taxonomy database, if available.
    max_seqs : int, optional
        The maximum number of sequences.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Choose a text-mode opener based on the compression suffix, then
    # delegate the actual parsing to _parse_a3m.
    opener = (
        gzip.open(str(path), "rt") if path.suffix == ".gz" else path.open("r")
    )
    with opener as handle:
        return _parse_a3m(handle, taxonomy, max_seqs)
|
protify/FastPLMs/boltz/src/boltz/data/parse/csv.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
from boltz.data import const
|
| 8 |
+
from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_csv(
    path: Path,
    max_seqs: Optional[int] = None,
) -> MSA:
    """Parse a two-column CSV alignment file into an MSA object.

    The file must contain exactly the columns ``sequence`` and ``key``,
    where ``key`` optionally carries a taxonomy identifier for each
    aligned sequence. Duplicate sequences (compared gap-free and
    upper-cased) are skipped.

    Parameters
    ----------
    path : Path
        The path to the CSV file.
    max_seqs : int, optional
        The maximum number of sequences to keep.

    Returns
    -------
    MSA
        The MSA object.

    """
    # Read file
    table = pd.read_csv(path)

    # Check columns
    if tuple(sorted(table.columns)) != ("key", "sequence"):
        msg = "Invalid CSV format, expected columns: ['sequence', 'key']"
        raise ValueError(msg)

    seen = set()
    all_sequences = []
    all_deletions = []
    all_residues = []

    seq_idx = 0
    for raw_line, key in zip(table["sequence"], table["key"]):
        aligned: str = raw_line.strip()
        if not aligned:
            continue

        # Missing / empty keys mean "no taxonomy" (-1). The str() check
        # catches float NaN produced by pandas for blank cells.
        has_key = (str(key) != "nan") and (key is not None) and (key != "")
        taxonomy_id = key if has_key else -1

        # Skip duplicate sequences.
        dedup_key = aligned.replace("-", "").upper()
        if dedup_key in seen:
            continue
        seen.add(dedup_key)

        # Tokenize the aligned sequence. Lowercase letters are insertions
        # relative to the query and are folded into per-position deletion
        # counts attached to the next aligned residue index.
        seq_tokens = []
        seq_deletions = []
        insert_run = 0
        res_idx = 0
        for char in aligned:
            if char != "-" and char.islower():
                insert_run += 1
                continue
            seq_tokens.append(const.token_ids[const.prot_letter_to_token[char]])
            if insert_run > 0:
                seq_deletions.append((res_idx, insert_run))
                insert_run = 0
            res_idx += 1

        # Record this sequence's slices into the flat residue/deletion arrays.
        res_start = len(all_residues)
        del_start = len(all_deletions)
        all_sequences.append(
            (
                seq_idx,
                taxonomy_id,
                res_start,
                res_start + len(seq_tokens),
                del_start,
                del_start + len(seq_deletions),
            )
        )
        all_residues.extend(seq_tokens)
        all_deletions.extend(seq_deletions)

        seq_idx += 1
        if (max_seqs is not None) and (seq_idx >= max_seqs):
            break

    # Create MSA object
    return MSA(
        residues=np.array(all_residues, dtype=MSAResidue),
        deletions=np.array(all_deletions, dtype=MSADeletion),
        sequences=np.array(all_sequences, dtype=MSASequence),
    )
|
protify/FastPLMs/boltz/src/boltz/data/parse/fasta.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Mapping
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from Bio import SeqIO
|
| 5 |
+
from rdkit.Chem.rdchem import Mol
|
| 6 |
+
|
| 7 |
+
from boltz.data.parse.yaml import parse_boltz_schema
|
| 8 |
+
from boltz.data.types import Target
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def parse_fasta(  # noqa: C901, PLR0912
    path: Path,
    ccd: Mapping[str, Mol],
    mol_dir: Path,
    boltz2: bool = False,
) -> Target:
    """Parse a fasta file.

    The name of the fasta file is used as the name of this job.
    We rely on the fasta record id to determine the entity type.

    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    SEQUENCE
    > CHAIN_ID|ENTITY_TYPE|MSA_ID
    ...

    Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
    and CHAIN_ID is the chain identifier, which should be unique.
    The MSA_ID is optional and should only be used on proteins.

    Parameters
    ----------
    path : Path
        Path to the fasta file.
    ccd : Mapping[str, Mol]
        Dictionary of CCD components.
    mol_dir : Path
        Path to the directory containing the molecules.
    boltz2 : bool
        Whether to parse the input for Boltz2.

    Returns
    -------
    Target
        The parsed target.

    Raises
    ------
    ValueError
        If a record id is malformed, a chain id is empty, an entity
        type is not one of protein/dna/rna/ccd/smiles, or an MSA id is
        given for a non-protein entity.

    """
    # Read fasta file
    with path.open("r") as f:
        records = list(SeqIO.parse(f, "fasta"))

    # Validate every record id up front: "CHAIN_ID|ENTITY_TYPE[|MSA_ID]".
    # Raise ValueError (not assert) so validation survives `python -O`.
    for seq_record in records:
        if "|" not in seq_record.id:
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        header = seq_record.id.split("|")
        if len(header) < 2:  # defensive; split on a present "|" yields >= 2
            msg = f"Invalid record id: {seq_record.id}"
            raise ValueError(msg)

        chain_id, entity_type = header[:2]
        if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
            # Also rejects an empty entity type.
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)
        if chain_id == "":
            msg = "Empty chain id in input fasta!"
            raise ValueError(msg)

    # Convert to yaml format
    sequences = []
    for seq_record in records:
        # Get chain id, entity type and sequence
        header = seq_record.id.split("|")
        chain_id, entity_type = header[:2]

        # An optional third field carries the MSA id (proteins only).
        if len(header) == 3 and header[2] != "":
            if entity_type.lower() != "protein":
                msg = "MSA_ID is only allowed for proteins"
                raise ValueError(msg)
            msa_id = header[2]
        else:
            msa_id = None

        entity_type = entity_type.upper()
        seq = str(seq_record.seq)

        if entity_type == "PROTEIN":
            molecule = {
                "protein": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                    "msa": msa_id,
                },
            }
        elif entity_type == "RNA":
            molecule = {
                "rna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "DNA":
            molecule = {
                "dna": {
                    "id": chain_id,
                    "sequence": seq,
                    "modifications": [],
                },
            }
        elif entity_type == "CCD":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "ccd": seq,
                },
            }
        elif entity_type == "SMILES":
            molecule = {
                "ligand": {
                    "id": chain_id,
                    "smiles": seq,
                },
            }
        else:
            # Unreachable given the validation pass above, but guards
            # against `molecule` being unbound if the two ever drift.
            msg = f"Invalid entity type: {entity_type}"
            raise ValueError(msg)

        sequences.append(molecule)

    data = {
        "sequences": sequences,
        "bonds": [],
        "version": 1,
    }

    name = path.stem
    return parse_boltz_schema(name, data, ccd, mol_dir, boltz2)
|