Spaces:
Running on Zero
Running on Zero
initial commit from source repo
Browse files
- app.py +25 -0
- data_utils.py +988 -0
- inputs/1BC8.pdb +0 -0
- model_utils.py +1772 -0
- openfold/__init__.py +6 -0
- openfold/config.py +558 -0
- openfold/data/__init__.py +0 -0
- openfold/data/data_modules.py +721 -0
- openfold/data/data_pipeline.py +826 -0
- openfold/data/data_transforms.py +1212 -0
- openfold/data/errors.py +22 -0
- openfold/data/feature_pipeline.py +116 -0
- openfold/data/input_pipeline.py +208 -0
- openfold/data/mmcif_parsing.py +485 -0
- openfold/data/parsers.py +388 -0
- openfold/data/templates.py +1108 -0
- openfold/data/tools/__init__.py +0 -0
- openfold/data/tools/hhblits.py +175 -0
- openfold/data/tools/hhsearch.py +106 -0
- openfold/data/tools/jackhmmer.py +228 -0
- openfold/data/tools/kalign.py +115 -0
- openfold/data/tools/utils.py +48 -0
- openfold/np/__init__.py +16 -0
- openfold/np/protein.py +441 -0
- openfold/np/relax/__init__.py +16 -0
- openfold/np/relax/amber_minimize.py +625 -0
- openfold/np/relax/cleanup.py +131 -0
- openfold/np/relax/relax.py +93 -0
- openfold/np/relax/utils.py +86 -0
- openfold/np/residue_constants.py +1310 -0
- openfold/resources/__init__.py +0 -0
- openfold/utils/feats.py +274 -0
- openfold/utils/loss.py +1614 -0
- openfold/utils/rigid_utils.py +1367 -0
- openfold/utils/tensor_utils.py +121 -0
- requirements.txt +29 -0
- run.py +990 -0
- run_examples.sh +244 -0
- sc_examples.sh +55 -0
- sc_utils.py +1158 -0
- score.py +549 -0
- space_utils/download_weights.py +7 -0
app.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Gradio Space entry point.

Downloads the LigandMPNN weights, runs one demo design job on the bundled
1BC8.pdb input at start-up, and displays the resulting output-folder listing
in a simple Blocks UI.
"""
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from space_utils.download_weights import download_ligandmpnn_weights

# Fetch model weights before building the UI so the demo job below can run.
download_ligandmpnn_weights()

with gr.Blocks(title="RFD3 Test") as demo:
    out_dir = "./output/test"
    # NOTE: this runs once at app start-up (while the Blocks context is being
    # built), not on user interaction.
    # Pass an argument list instead of a shell=True string: no shell parsing,
    # paths are handed to the child process verbatim.
    subprocess.run(
        ["python", "run.py", "--pdb_path", "./inputs/1BC8.pdb", "--out_folder", out_dir],
        check=True,
        text=True,
    )

    # List the produced files; stdout is captured for display in the UI.
    res = subprocess.run(
        ["ls", out_dir], check=True, text=True, capture_output=True
    )

    gr.Markdown("### Command Output")
    gr.Textbox(value=res.stdout, lines=20)


if __name__ == "__main__":
    demo.launch()
data_utils.py
ADDED
|
@@ -0,0 +1,988 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils
|
| 6 |
+
from prody import *
|
| 7 |
+
|
| 8 |
+
confProDy(verbosity="none")
|
| 9 |
+
|
| 10 |
+
# One-letter amino acid code -> three-letter PDB residue name ("X" -> "UNK").
restype_1to3 = {
    "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
    "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
    "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
    "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    "X": "UNK",
}

# Integer encoding of the 20 amino acids in alphabetical order, plus "X" = 20
# for unknown residues.
restype_str_to_int = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWYX")}
# Inverse mapping, built from the forward one so the two can never disagree.
restype_int_to_str = {i: aa for aa, i in restype_str_to_int.items()}
alphabet = list(restype_str_to_int)

# Periodic-table element symbols in atomic-number order (1-based).
# NOTE(review): "Mb" at atomic number 42 looks like a typo for "Mo"
# (molybdenum) -- kept as-is because downstream element indices must stay
# consistent; confirm against the trained model's vocabulary before changing.
element_list = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo",
]
element_list = [item.upper() for item in element_list]
# element_dict = dict(zip(element_list, range(1,len(element_list))))
# Reverse lookup: atomic number -> upper-cased element symbol.
element_dict_rev = dict(zip(range(1, len(element_list)), element_list))
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
    """Masked sequence recovery.

    S : true sequence shape=[batch, length]
    S_pred : predicted sequence shape=[batch, length]
    mask : mask to compute average over the region shape=[batch, length]

    Returns the per-batch fraction of masked positions where the prediction
    matches the ground truth, shape=[batch].
    """
    hits = (S == S_pred) * mask
    return hits.sum(dim=-1) / mask.sum(dim=-1)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
    """Masked categorical cross entropy of predicted log-probabilities.

    S : true sequence shape=[batch, length]
    log_probs : predicted per-position log-probabilities over 21 classes,
        shape=[batch, length, 21]
    mask : mask to compute average over the region shape=[batch, length]

    Returns:
        average_loss : averaged categorical cross entropy (CCE) [batch]
        loss_per_residue : per position CCE [batch, length]
    """
    # Selecting the true-class log-probability is equivalent to summing
    # one_hot(S, 21) * log_probs over the class dimension.
    loss_per_residue = -torch.gather(log_probs, -1, S.long()[..., None]).squeeze(-1)
    denom = torch.sum(mask, dim=-1) + 1e-8  # guard against an all-zero mask
    average_loss = torch.sum(loss_per_residue * mask, dim=-1) / denom
    return average_loss, loss_per_residue
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def write_full_PDB(
    save_path: str,
    X: np.ndarray,
    X_m: np.ndarray,
    b_factors: np.ndarray,
    R_idx: np.ndarray,
    chain_letters: np.ndarray,
    S: np.ndarray,
    other_atoms=None,
    icodes=None,
    force_hetatm=False,
):
    """
    Write a protein (and optionally its non-protein atoms) to a PDB file.

    save_path : path where the PDB will be written to
    X : protein atom xyz coordinates shape=[length, 14, 3]
    X_m : protein atom mask shape=[length, 14]
    b_factors: shape=[length, 14]
    R_idx: protein residue indices shape=[length]
    chain_letters: protein chain letters shape=[length]
    S : protein amino acid sequence (integer-encoded) shape=[length]
    other_atoms: other atoms parsed by prody
    icodes: a list of insertion codes for the PDB; e.g. antibody loops
    force_hetatm: if True, copy the "hetatm" flags from other_atoms

    NOTE(review): icodes is indexed per residue below even though it defaults
    to None -- callers must pass a per-residue sequence.
    """

    # Local copies of the residue-code tables (shadow the module-level ones).
    restype_1to3 = {
        "A": "ALA",
        "R": "ARG",
        "N": "ASN",
        "D": "ASP",
        "C": "CYS",
        "Q": "GLN",
        "E": "GLU",
        "G": "GLY",
        "H": "HIS",
        "I": "ILE",
        "L": "LEU",
        "K": "LYS",
        "M": "MET",
        "F": "PHE",
        "P": "PRO",
        "S": "SER",
        "T": "THR",
        "W": "TRP",
        "Y": "TYR",
        "V": "VAL",
        "X": "UNK",
    }
    restype_INTtoSTR = {
        0: "A",
        1: "C",
        2: "D",
        3: "E",
        4: "F",
        5: "G",
        6: "H",
        7: "I",
        8: "K",
        9: "L",
        10: "M",
        11: "N",
        12: "P",
        13: "Q",
        14: "R",
        15: "S",
        16: "T",
        17: "V",
        18: "W",
        19: "Y",
        20: "X",
    }
    # Atom-14 convention: each residue type gets a fixed 14-slot atom-name
    # list; empty strings pad residues with fewer than 14 heavy atoms.
    restype_name_to_atom14_names = {
        "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
        "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
        "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
        "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
        "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
        "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
        "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
        "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
        "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
        "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"],
        "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
        "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
    }

    # Integer sequence -> one-letter codes -> three-letter residue names.
    S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]]

    # Flatten the [length, 14, ...] residue arrays into per-atom lists,
    # keeping only atoms whose mask bit is set.
    X_list = []
    b_factor_list = []
    atom_name_list = []
    element_name_list = []
    residue_name_list = []
    residue_number_list = []
    chain_id_list = []
    icodes_list = []
    for i, AA in enumerate(S_str):
        sel = X_m[i].astype(np.int32) == 1  # atoms actually present
        total = np.sum(sel)
        tmp = np.array(restype_name_to_atom14_names[AA])[sel]
        X_list.append(X[i][sel])
        b_factor_list.append(b_factors[i][sel])
        atom_name_list.append(tmp)
        # Element symbol taken as the first character of the atom name
        # (valid for protein heavy atoms: C, N, O, S). The comprehension's AA
        # is the atom name, scoped to the comprehension only.
        element_name_list += [AA[:1] for AA in list(tmp)]
        residue_name_list += total * [AA]
        residue_number_list += total * [R_idx[i]]
        chain_id_list += total * [chain_letters[i]]
        icodes_list += total * [icodes[i]]

    X_stack = np.concatenate(X_list, 0)
    b_factor_stack = np.concatenate(b_factor_list, 0)
    atom_name_stack = np.concatenate(atom_name_list, 0)

    # Assemble a fresh prody AtomGroup for the designed protein.
    # NOTE(review): relies on the name `prody` being in scope via the
    # module's `from prody import *` -- verify with the installed version.
    protein = prody.AtomGroup()
    protein.setCoords(X_stack)
    protein.setBetas(b_factor_stack)
    protein.setNames(atom_name_stack)
    protein.setResnames(residue_name_list)
    protein.setElements(element_name_list)
    protein.setOccupancies(np.ones([X_stack.shape[0]]))
    protein.setResnums(residue_number_list)
    protein.setChids(chain_id_list)
    protein.setIcodes(icodes_list)

    if other_atoms:
        # Copy the non-protein atoms into their own group and write both.
        other_atoms_g = prody.AtomGroup()
        other_atoms_g.setCoords(other_atoms.getCoords())
        other_atoms_g.setNames(other_atoms.getNames())
        other_atoms_g.setResnames(other_atoms.getResnames())
        other_atoms_g.setElements(other_atoms.getElements())
        other_atoms_g.setOccupancies(other_atoms.getOccupancies())
        other_atoms_g.setResnums(other_atoms.getResnums())
        other_atoms_g.setChids(other_atoms.getChids())
        if force_hetatm:
            # Preserve HETATM records instead of letting prody re-derive them.
            other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
        writePDB(save_path, protein + other_atoms_g)
    else:
        writePDB(save_path, protein)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
    """
    Collect one atom type's coordinates, aligned to the residue order of CA_dict.

    protein_atoms: prody atom group
    CA_dict: mapping between chain_residue_idx_icodes and integers
    atom_name: atom to be parsed; e.g. CA

    Returns (coords, mask): coords shape=[len(CA_dict), 3] (float32) and
    mask shape=[len(CA_dict)] (int32), 1 where the atom was found.
    """
    atom_coords_ = np.zeros([len(CA_dict), 3], np.float32)
    atom_coords_m = np.zeros([len(CA_dict)], np.int32)

    atom_atoms = protein_atoms.select(f"name {atom_name}")
    # prody returns None (not an empty selection) when nothing matches;
    # use `is not None` and check only once.
    if atom_atoms is not None:
        atom_coords = atom_atoms.getCoords()
        atom_resnums = atom_atoms.getResnums()
        atom_chain_ids = atom_atoms.getChids()
        atom_icodes = atom_atoms.getIcodes()
        for i in range(len(atom_resnums)):
            # Key format must match the one used when CA_dict was built.
            code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i]
            if code in CA_dict:  # O(1) dict membership; no list() copy per atom
                atom_coords_[CA_dict[code], :] = atom_coords[i]
                atom_coords_m[CA_dict[code]] = 1
    return atom_coords_, atom_coords_m
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def parse_PDB(
    input_path: str,
    device: str = "cpu",
    chains: list = [],
    parse_all_atoms: bool = False,
    parse_atoms_with_zero_occupancy: bool = False
):
    """
    Parse a PDB file into tensors for the model.

    input_path : path for the input PDB
    device: device for the torch.Tensor
    chains: a list specifying which chains need to be parsed; e.g. ["A", "B"]
    parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms
    parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed

    Returns (output_dict, backbone, other_atoms, CA_icodes, water_atoms).

    NOTE(review): `chains` is a mutable default argument; safe here because it
    is only read, never mutated.
    """
    # Periodic-table symbols in atomic-number order; used to encode ligand
    # atom element types as integers.
    element_list = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
        "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
        "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
        "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
        "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
        "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
        "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
        "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
        "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
        "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
        "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo",
    ]
    element_list = [item.upper() for item in element_list]
    # Element symbol -> 1-based atomic number (0 is reserved for "unknown").
    element_dict = dict(zip(element_list, range(1, len(element_list))))
    # Three-letter residue name -> one-letter code.
    restype_3to1 = {
        "ALA": "A",
        "ARG": "R",
        "ASN": "N",
        "ASP": "D",
        "CYS": "C",
        "GLN": "Q",
        "GLU": "E",
        "GLY": "G",
        "HIS": "H",
        "ILE": "I",
        "LEU": "L",
        "LYS": "K",
        "MET": "M",
        "PHE": "F",
        "PRO": "P",
        "SER": "S",
        "THR": "T",
        "TRP": "W",
        "TYR": "Y",
        "VAL": "V",
    }
    # One-letter code -> integer class (alphabetical; "X" = 20 = unknown).
    restype_STRtoINT = {
        "A": 0,
        "C": 1,
        "D": 2,
        "E": 3,
        "F": 4,
        "G": 5,
        "H": 6,
        "I": 7,
        "K": 8,
        "L": 9,
        "M": 10,
        "N": 11,
        "P": 12,
        "Q": 13,
        "R": 14,
        "S": 15,
        "T": 16,
        "V": 17,
        "W": 18,
        "Y": 19,
        "X": 20,
    }

    # Heavy-atom name -> slot in the 37-atom representation.
    atom_order = {
        "N": 0,
        "CA": 1,
        "C": 2,
        "CB": 3,
        "O": 4,
        "CG": 5,
        "CG1": 6,
        "CG2": 7,
        "OG": 8,
        "OG1": 9,
        "SG": 10,
        "CD": 11,
        "CD1": 12,
        "CD2": 13,
        "ND1": 14,
        "ND2": 15,
        "OD1": 16,
        "OD2": 17,
        "SD": 18,
        "CE": 19,
        "CE1": 20,
        "CE2": 21,
        "CE3": 22,
        "NE": 23,
        "NE1": 24,
        "NE2": 25,
        "OE1": 26,
        "OE2": 27,
        "CH2": 28,
        "NH1": 29,
        "NH2": 30,
        "OH": 31,
        "CZ": 32,
        "CZ2": 33,
        "CZ3": 34,
        "NZ": 35,
        "OXT": 36,
    }

    # Atom names to extract: backbone only, or all sidechain heavy atoms
    # (OXT is never extracted even in all-atom mode).
    if not parse_all_atoms:
        atom_types = ["N", "CA", "C", "O"]
    else:
        atom_types = [
            "N",
            "CA",
            "C",
            "CB",
            "O",
            "CG",
            "CG1",
            "CG2",
            "OG",
            "OG1",
            "SG",
            "CD",
            "CD1",
            "CD2",
            "ND1",
            "ND2",
            "OD1",
            "OD2",
            "SD",
            "CE",
            "CE1",
            "CE2",
            "CE3",
            "NE",
            "NE1",
            "NE2",
            "OE1",
            "OE2",
            "CH2",
            "NH1",
            "NH2",
            "OH",
            "CZ",
            "CZ2",
            "CZ3",
            "NZ",
        ]

    atoms = parsePDB(input_path)
    if not parse_atoms_with_zero_occupancy:
        atoms = atoms.select("occupancy > 0")
    if chains:
        # Build a prody selection like "chain A or chain B" (the slice strips
        # the leading space and trailing " or").
        str_out = ""
        for item in chains:
            str_out += " chain " + item + " or"
        atoms = atoms.select(str_out[1:-3])

    # Partition the structure: protein / ligand-like / water.
    protein_atoms = atoms.select("protein")
    backbone = protein_atoms.select("backbone")
    other_atoms = atoms.select("not protein and not water")
    water_atoms = atoms.select("water")

    CA_atoms = protein_atoms.select("name CA")
    CA_resnums = CA_atoms.getResnums()
    CA_chain_ids = CA_atoms.getChids()
    CA_icodes = CA_atoms.getIcodes()

    # Residue identity key "chain_resnum_icode" -> dense residue index.
    CA_dict = {}
    for i in range(len(CA_resnums)):
        code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i]
        CA_dict[code] = i

    # Fill the [L, 37, 3] coordinate tensor and its presence mask one atom
    # type at a time, aligned to the CA residue order.
    xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
    xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
    for atom_name in atom_types:
        xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
        xyz_37[:, atom_order[atom_name], :] = xyz
        xyz_37_m[:, atom_order[atom_name]] = xyz_m

    N = xyz_37[:, atom_order["N"], :]
    CA = xyz_37[:, atom_order["CA"], :]
    C = xyz_37[:, atom_order["C"], :]
    O = xyz_37[:, atom_order["O"], :]

    N_m = xyz_37_m[:, atom_order["N"]]
    CA_m = xyz_37_m[:, atom_order["CA"]]
    C_m = xyz_37_m[:, atom_order["C"]]
    O_m = xyz_37_m[:, atom_order["O"]]

    mask = N_m * CA_m * C_m * O_m  # must all 4 atoms exist

    # Virtual CB placed from the backbone frame (fixed-coefficient
    # construction; NOTE(review): coefficients match the ProteinMPNN-style
    # ideal-CB formula -- confirm against the training pipeline).
    b = CA - N
    c = C - CA
    a = np.cross(b, c, axis=-1)
    CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

    chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
    R_idx = np.array(CA_resnums, dtype=np.int32)
    # Residue names -> integer classes; anything non-standard becomes "X".
    S = CA_atoms.getResnames()
    S = [restype_3to1[AA] if AA in list(restype_3to1) else "X" for AA in list(S)]
    S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32)
    X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1)

    # Ligand context: coordinates (Y), element types (Y_t), mask (Y_m).
    # NOTE(review): bare except -- any failure (including other_atoms being
    # None) silently falls back to a single dummy zero atom.
    try:
        Y = np.array(other_atoms.getCoords(), dtype=np.float32)
        Y_t = list(other_atoms.getElements())
        Y_t = np.array(
            [
                element_dict[y_t.upper()] if y_t.upper() in element_list else 0
            for y_t in Y_t
            ],
            dtype=np.int32,
        )
        # Drop hydrogens (type 1) and unrecognized elements (type 0).
        Y_m = (Y_t != 1) * (Y_t != 0)

        Y = Y[Y_m, :]
        Y_t = Y_t[Y_m]
        Y_m = Y_m[Y_m]
    except:
        Y = np.zeros([1, 3], np.float32)
        Y_t = np.zeros([1], np.int32)
        Y_m = np.zeros([1], np.int32)

    output_dict = {}
    output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
    output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
    output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
    output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
    output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)

    output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
    output_dict["chain_labels"] = torch.tensor(
        chain_labels, device=device, dtype=torch.int32
    )

    output_dict["chain_letters"] = CA_chain_ids

    # Per-chain boolean residue masks, in sorted chain-letter order.
    mask_c = []
    chain_list = list(set(output_dict["chain_letters"]))
    chain_list.sort()
    for chain in chain_list:
        mask_c.append(
            torch.tensor(
                [chain == item for item in output_dict["chain_letters"]],
                device=device,
                dtype=bool,
            )
        )

    output_dict["mask_c"] = mask_c
    output_dict["chain_list"] = chain_list

    output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)

    output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
    output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)

    return output_dict, backbone, other_atoms, CA_icodes, water_atoms
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
    """Pick the closest ligand atoms to each residue's (virtual) CB atom.

    Args:
        CB: [A, 3] per-residue reference coordinates.
        mask: [A] residue validity mask (0 entries are pushed far away).
        Y: [B, 3] ligand atom coordinates.
        Y_t: [B] ligand atom element types.
        Y_m: [B] ligand atom validity mask.
        number_of_ligand_atoms: how many neighbours to keep per residue.

    Returns:
        (Y, Y_t, Y_m, D_AB_closest) where the first three are
        [A, number_of_ligand_atoms, ...] zero-padded neighbour tensors and
        D_AB_closest is the [A] distance to each residue's nearest atom.
    """
    num_res = CB.shape[0]
    dev = CB.device

    # Pairwise squared distances; invalid residue/atom pairs are pushed to a
    # large constant so argsort puts them last.
    pair_mask = mask[:, None] * Y_m[None, :]  # [A,B]
    sq_dist = ((CB[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    sq_dist = sq_dist * pair_mask + (1 - pair_mask) * 1000.0

    nn_idx = torch.argsort(sq_dist, -1)[:, :number_of_ligand_atoms]
    nearest_sq = torch.gather(sq_dist, 1, nn_idx)
    D_AB_closest = torch.sqrt(nearest_sq[:, 0])

    # Gather per-residue neighbour coordinates / types / masks.
    gathered_xyz = torch.gather(
        Y[None, :, :].repeat(num_res, 1, 1), 1, nn_idx[:, :, None].repeat(1, 1, 3)
    )
    gathered_t = torch.gather(Y_t[None, :].repeat(num_res, 1), 1, nn_idx)
    gathered_m = torch.gather(Y_m[None, :].repeat(num_res, 1), 1, nn_idx)

    # Zero-padded outputs in case there are fewer atoms than requested.
    Y_out = torch.zeros(
        [num_res, number_of_ligand_atoms, 3], dtype=torch.float32, device=dev
    )
    Y_t_out = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=dev
    )
    Y_m_out = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=dev
    )

    filled = gathered_xyz.shape[1]
    Y_out[:, :filled] = gathered_xyz
    Y_t_out[:, :filled] = gathered_t
    Y_m_out[:, :filled] = gathered_m

    return Y_out, Y_t_out, Y_m_out, D_AB_closest
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def featurize(
    input_dict,
    cutoff_for_score=8.0,
    use_atom_context=True,
    number_of_ligand_atoms=16,
    model_type="protein_mpnn",
):
    """Turn a parsed-structure feature dict into batched model inputs.

    Adds a leading batch dimension of size 1 to every tensor, renumbers the
    residue indices so duplicates (insertion codes) become unique, and - for
    the ligand model - gathers the nearest ligand atoms around each
    residue's virtual CB atom.

    Args:
        input_dict: per-structure tensors ("X", "S", "mask", "R_idx",
            "chain_labels", "chain_mask", plus ligand/membrane extras
            depending on model_type).
        cutoff_for_score: CB-to-ligand distance (Angstrom) below which a
            residue counts as ligand-contacting.
        use_atom_context: if False, the ligand-atom mask is zeroed so the
            model ignores ligand context.
        number_of_ligand_atoms: nearest ligand atoms kept per residue.
        model_type: which model variant the features are built for.

    Returns:
        dict of batched ([1, ...]) feature tensors.
    """
    feats = {}

    if model_type == "ligand_mpnn":
        mask = input_dict["mask"]
        Y = input_dict["Y"]
        Y_t = input_dict["Y_t"]
        Y_m = input_dict["Y_m"]

        # Virtual CB rebuilt from the backbone N/CA/C frame using ideal
        # geometry coefficients.
        N = input_dict["X"][:, 0, :]
        CA = input_dict["X"][:, 1, :]
        C = input_dict["X"][:, 2, :]
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, axis=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

        Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
            CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms
        )
        # Residues whose closest (valid) ligand atom lies within the cutoff.
        mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0]
        feats["mask_XY"] = mask_XY[None,]
        if "side_chain_mask" in input_dict:
            feats["side_chain_mask"] = input_dict["side_chain_mask"][None,]
        feats["Y"] = Y[None,]
        feats["Y_t"] = Y_t[None,]
        feats["Y_m"] = Y_m[None,]
        if not use_atom_context:
            feats["Y_m"] = 0.0 * feats["Y_m"]
    elif model_type in (
        "per_residue_label_membrane_mpnn",
        "global_label_membrane_mpnn",
    ):
        feats["membrane_per_residue_labels"] = input_dict[
            "membrane_per_residue_labels"
        ][None,]

    # Make residue numbers strictly unique: insertion codes repeat the PDB
    # residue number, so each repeat bumps a cumulative offset.
    shifted = []
    offset = 0
    previous = -100000
    for r in input_dict["R_idx"]:
        if previous == r:
            offset += 1
        shifted.append(r + offset)
        previous = r
    feats["R_idx"] = torch.tensor(shifted, device=input_dict["R_idx"].device)[None,]
    feats["R_idx_original"] = input_dict["R_idx"][None,]

    feats["chain_labels"] = input_dict["chain_labels"][None,]
    feats["S"] = input_dict["S"][None,]
    feats["chain_mask"] = input_dict["chain_mask"][None,]
    feats["mask"] = input_dict["mask"][None,]
    feats["X"] = input_dict["X"][None,]

    if "xyz_37" in input_dict:
        feats["xyz_37"] = input_dict["xyz_37"][None,]
        feats["xyz_37_m"] = input_dict["xyz_37_m"][None,]

    return feats
|
inputs/1BC8.pdb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model_utils.py
ADDED
|
@@ -0,0 +1,1772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import itertools
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ProteinMPNN(torch.nn.Module):
|
| 11 |
+
    def __init__(
        self,
        num_letters=21,
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        vocab=21,
        k_neighbors=48,
        augment_eps=0.0,
        dropout=0.0,
        device=None,
        atom_context_num=0,
        model_type="protein_mpnn",
        ligand_mpnn_use_side_chain_context=False,
    ):
        """Build a ProteinMPNN-family model.

        Args:
            num_letters: output alphabet size (20 amino acids + X).
            node_features: node feature dim produced by the featurizer.
            edge_features: edge feature dim produced by the featurizer.
            hidden_dim: width of all message-passing layers.
            num_encoder_layers: depth of the structure encoder.
            num_decoder_layers: depth of the autoregressive decoder.
            vocab: sequence-embedding vocabulary size.
            k_neighbors: neighbours per residue in the k-NN graph.
            augment_eps: stddev of Gaussian coordinate noise (featurizer).
            dropout: dropout rate used throughout.
            device: forwarded to the ligand featurizer only.
            atom_context_num: ligand atoms per residue (ligand_mpnn only).
            model_type: one of "protein_mpnn", "ligand_mpnn",
                "soluble_mpnn", "per_residue_label_membrane_mpnn",
                "global_label_membrane_mpnn".
            ligand_mpnn_use_side_chain_context: let the ligand featurizer
                also see side-chain atoms of fixed residues.
        """
        super(ProteinMPNN, self).__init__()

        self.model_type = model_type
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        # Per-variant featurizer plus (for ligand_mpnn) an extra context
        # stack that injects ligand-atom information into the node states.
        # ProteinFeaturesLigand / ProteinFeatures / ProteinFeaturesMembrane,
        # EncLayer / DecLayer / DecLayerJ are defined elsewhere in this file.
        if self.model_type == "ligand_mpnn":
            self.features = ProteinFeaturesLigand(
                node_features,
                edge_features,
                top_k=k_neighbors,
                augment_eps=augment_eps,
                device=device,
                atom_context_num=atom_context_num,
                use_side_chains=ligand_mpnn_use_side_chain_context,
            )
            self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True)
            self.W_c = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)

            self.W_nodes_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.W_edges_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)

            self.V_C = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.V_C_norm = torch.nn.LayerNorm(hidden_dim)

            self.context_encoder_layers = torch.nn.ModuleList(
                [
                    DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                    for _ in range(2)
                ]
            )

            self.y_context_encoder_layers = torch.nn.ModuleList(
                [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)]
            )
        elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn":
            self.features = ProteinFeatures(
                node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps
            )
        elif (
            self.model_type == "per_residue_label_membrane_mpnn"
            or self.model_type == "global_label_membrane_mpnn"
        ):
            self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True)
            self.features = ProteinFeaturesMembrane(
                node_features,
                edge_features,
                top_k=k_neighbors,
                augment_eps=augment_eps,
                num_classes=3,
            )
        else:
            # Unknown model type: hard exit rather than building a broken model.
            print("Choose --model_type flag from currently available models")
            sys.exit()

        self.W_e = torch.nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_s = torch.nn.Embedding(vocab, hidden_dim)

        self.dropout = torch.nn.Dropout(dropout)

        # Encoder layers
        self.encoder_layers = torch.nn.ModuleList(
            [
                EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                for _ in range(num_encoder_layers)
            ]
        )

        # Decoder layers
        self.decoder_layers = torch.nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        self.W_out = torch.nn.Linear(hidden_dim, num_letters, bias=True)

        # Xavier init for every weight matrix (biases are left at default).
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
| 110 |
+
|
| 111 |
+
    def encode(self, feature_dict):
        """Run the structure encoder for the configured model variant.

        Builds the k-NN graph features, then message-passes over backbone
        geometry (plus ligand-atom context for ligand_mpnn, or membrane
        labels for the membrane variants).

        Args:
            feature_dict: batched features as produced by featurize();
                only "S" and "mask" are read directly here, the rest is
                consumed by self.features.

        Returns:
            (h_V, h_E, E_idx): node embeddings [B,L,H], edge embeddings
            [B,L,K,H], and neighbour indices [B,L,K].
        """
        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
        # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O
        S_true = feature_dict[
            "S"
        ]  # [B,L] - integer protein sequence encoded using "restype_STRtoINT"
        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
        mask = feature_dict[
            "mask"
        ]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters

        B, L = S_true.shape
        device = S_true.device

        if self.model_type == "ligand_mpnn":
            V, E, E_idx, Y_nodes, Y_edges, Y_m = self.features(feature_dict)
            # Node states start at zero; all structural signal enters via edges.
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
            h_E = self.W_e(E)
            h_E_context = self.W_v(V)

            # mask_attend[b,i,k] = 1 only if residue i and its k-th neighbour
            # are both valid.
            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

            # Ligand context: message-pass among ligand atoms, then pass the
            # result into a copy of the node states and fold it back in via a
            # residual connection.
            h_V_C = self.W_c(h_V)
            Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :]
            Y_nodes = self.W_nodes_y(Y_nodes)
            Y_edges = self.W_edges_y(Y_edges)
            for i in range(len(self.context_encoder_layers)):
                Y_nodes = self.y_context_encoder_layers[i](
                    Y_nodes, Y_edges, Y_m, Y_m_edges
                )
                h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1)
                h_V_C = self.context_encoder_layers[i](
                    h_V_C, h_E_context_cat, mask, Y_m
                )

            h_V_C = self.V_C(h_V_C)
            h_V = h_V + self.V_C_norm(self.dropout(h_V_C))
        elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn":
            E, E_idx = self.features(feature_dict)
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
            h_E = self.W_e(E)

            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
        elif (
            self.model_type == "per_residue_label_membrane_mpnn"
            or self.model_type == "global_label_membrane_mpnn"
        ):
            # Membrane variants have real node features (the labels), so node
            # states start from W_v(V) instead of zeros.
            V, E, E_idx = self.features(feature_dict)
            h_V = self.W_v(V)
            h_E = self.W_e(E)

            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        return h_V, h_E, E_idx
| 179 |
+
|
| 180 |
+
    def sample(self, feature_dict):
        """Autoregressively sample sequences for the encoded structure.

        Encodes the structure once, then decodes residues one at a time in a
        random order (designable positions decode after fixed ones), sampling
        each amino acid from the temperature-scaled, bias-adjusted logits.
        When "symmetry_residues" is non-empty, tied positions are decoded
        together using a weighted sum of their logits so they all receive the
        same amino acid.

        Required feature_dict keys: "S", "mask", "chain_mask", "bias",
        "randn", "temperature", "batch_size", "symmetry_residues",
        "symmetry_weights", plus whatever encode() consumes.

        Returns:
            dict with "S" [B,L] sampled sequences, "sampling_probs" [B,L,20],
            "log_probs" [B,L,21], and "decoding_order" [B,L].
        """
        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
        # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O
        B_decoder = feature_dict["batch_size"]
        S_true = feature_dict[
            "S"
        ]  # [B,L] - integer protein sequence encoded using "restype_STRtoINT"
        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
        mask = feature_dict[
            "mask"
        ]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
        chain_mask = feature_dict[
            "chain_mask"
        ]  # [B,L] - mask for which residues need to be fixed; 0.0 - fixed; 1.0 - will be designed
        bias = feature_dict["bias"]  # [B,L,21] - amino acid bias per position
        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters
        randn = feature_dict[
            "randn"
        ]  # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry
        temperature = feature_dict[
            "temperature"
        ]  # float - sampling temperature; prob = softmax(logits/temperature)
        symmetry_list_of_lists = feature_dict[
            "symmetry_residues"
        ]  # [[0, 1, 14], [10,11,14,15], [20, 21]] #indices to select X over length - L
        symmetry_weights_list_of_lists = feature_dict[
            "symmetry_weights"
        ]  # [[1.0, 1.0, 1.0], [-2.0,1.1,0.2,1.1], [2.3, 1.1]]

        B, L = S_true.shape
        device = S_true.device

        h_V, h_E, E_idx = self.encode(feature_dict)

        chain_mask = mask * chain_mask  # update chain_M to include missing regions
        decoding_order = torch.argsort(
            (chain_mask + 0.0001) * (torch.abs(randn))
        )  # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
        if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
            # --- no symmetry: standard per-position autoregressive decode ---
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            # order_mask_backward[b,q,p] = 1 iff p was decoded before q.
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend  # neighbours already decoded
            mask_fw = mask_1D * (1.0 - mask_attend)  # neighbours not yet decoded

            # repeat for decoding
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E.repeat(B_decoder, 1, 1, 1)
            chain_mask = chain_mask.repeat(B_decoder, 1)
            mask = mask.repeat(B_decoder, 1)
            bias = bias.repeat(B_decoder, 1, 1)

            all_probs = torch.zeros(
                (B_decoder, L, 20), device=device, dtype=torch.float32
            )
            all_log_probs = torch.zeros(
                (B_decoder, L, 21), device=device, dtype=torch.float32
            )
            h_S = torch.zeros_like(h_V, device=device)
            # 20 == "X"/unknown; positions are overwritten as they decode.
            S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device)
            h_V_stack = [h_V] + [
                torch.zeros_like(h_V, device=device)
                for _ in range(len(self.decoder_layers))
            ]

            # Future (not-yet-decoded) neighbours contribute encoder-only
            # features, computed once up front.
            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
            h_EXV_encoder_fw = mask_fw * h_EXV_encoder

            for t_ in range(L):
                t = decoding_order[:, t_]  # [B]
                chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
                mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
                bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[
                    :, 0, :
                ]  # [B,21]

                E_idx_t = torch.gather(
                    E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1])
                )
                h_E_t = torch.gather(
                    h_E,
                    1,
                    t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]),
                )
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = torch.gather(
                    h_EXV_encoder_fw,
                    1,
                    t[:, None, None, None].repeat(
                        1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]
                    ),
                )

                mask_bw_t = torch.gather(
                    mask_bw,
                    1,
                    t[:, None, None, None].repeat(
                        1, 1, mask_bw.shape[-2], mask_bw.shape[-1]
                    ),
                )

                # Run the decoder stack only at position t, writing the
                # updated node state back into the per-layer stacks in place.
                for l, layer in enumerate(self.decoder_layers):
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = torch.gather(
                        h_V_stack[l],
                        1,
                        t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]),
                    )
                    h_ESV_t = mask_bw_t * h_ESV_decoder_t + h_EXV_encoder_t
                    h_V_stack[l + 1].scatter_(
                        1,
                        t[:, None, None].repeat(1, 1, h_V.shape[-1]),
                        layer(h_V_t, h_ESV_t, mask_V=mask_t),
                    )

                h_V_t = torch.gather(
                    h_V_stack[-1],
                    1,
                    t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]),
                )[:, 0]
                logits = self.W_out(h_V_t)  # [B,21]
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # [B,21]

                probs = torch.nn.functional.softmax(
                    (logits + bias_t) / temperature, dim=-1
                )  # [B,21]
                probs_sample = probs[:, :20] / torch.sum(
                    probs[:, :20], dim=-1, keepdim=True
                )  # hard omit X #[B,20]
                S_t = torch.multinomial(probs_sample, 1)[:, 0]  # [B]

                all_probs.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, 20),
                    (chain_mask_t[:, None, None] * probs_sample[:, None, :]).float(),
                )
                all_log_probs.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, 21),
                    (chain_mask_t[:, None, None] * log_probs[:, None, :]).float(),
                )
                # Fixed positions keep their true amino acid; designed
                # positions take the freshly sampled one.
                S_true_t = torch.gather(S_true, 1, t[:, None])[:, 0]
                S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long()
                h_S.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, h_S.shape[-1]),
                    self.W_s(S_t)[:, None, :],
                )
                S.scatter_(1, t[:, None], S_t[:, None])

            output_dict = {
                "S": S,
                "sampling_probs": all_probs,
                "log_probs": all_log_probs,
                "decoding_order": decoding_order,
            }
        else:
            # --- symmetric design: tied positions decode as one group ---
            # weights for symmetric design
            symmetry_weights = torch.ones([L], device=device, dtype=torch.float32)
            for i1, item_list in enumerate(symmetry_list_of_lists):
                for i2, item in enumerate(item_list):
                    symmetry_weights[item] = symmetry_weights_list_of_lists[i1][i2]

            # Regroup the decoding order so that the first occurrence of any
            # tied position pulls in its whole symmetry group.
            new_decoding_order = []
            for t_dec in list(decoding_order[0,].cpu().data.numpy()):
                if t_dec not in list(itertools.chain(*new_decoding_order)):
                    list_a = [item for item in symmetry_list_of_lists if t_dec in item]
                    if list_a:
                        new_decoding_order.append(list_a[0])
                    else:
                        new_decoding_order.append([t_dec])

            decoding_order = torch.tensor(
                list(itertools.chain(*new_decoding_order)), device=device
            )[None,].repeat(B, 1)

            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)

            # repeat for decoding
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E.repeat(B_decoder, 1, 1, 1)
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
            mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
            chain_mask = chain_mask.repeat(B_decoder, 1)
            mask = mask.repeat(B_decoder, 1)
            bias = bias.repeat(B_decoder, 1, 1)

            all_probs = torch.zeros(
                (B_decoder, L, 20), device=device, dtype=torch.float32
            )
            all_log_probs = torch.zeros(
                (B_decoder, L, 21), device=device, dtype=torch.float32
            )
            h_S = torch.zeros_like(h_V, device=device)
            S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device)
            h_V_stack = [h_V] + [
                torch.zeros_like(h_V, device=device)
                for _ in range(len(self.decoder_layers))
            ]

            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
            h_EXV_encoder_fw = mask_fw * h_EXV_encoder

            for t_list in new_decoding_order:
                # Accumulate weighted logits across the tied positions, then
                # draw ONE amino acid for the whole group.
                total_logits = 0.0
                for t in t_list:
                    chain_mask_t = chain_mask[:, t]  # [B]
                    mask_t = mask[:, t]  # [B]
                    bias_t = bias[:, t]  # [B, 21]

                    E_idx_t = E_idx[:, t : t + 1]
                    h_E_t = h_E[:, t : t + 1]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t : t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(
                            h_V_stack[l], h_ES_t, E_idx_t
                        )
                        h_V_t = h_V_stack[l][:, t : t + 1]
                        h_ESV_t = (
                            mask_bw[:, t : t + 1] * h_ESV_decoder_t + h_EXV_encoder_t
                        )
                        h_V_stack[l + 1][:, t : t + 1, :] = layer(
                            h_V_t, h_ESV_t, mask_V=mask_t[:, None]
                        )

                    h_V_t = h_V_stack[-1][:, t]
                    logits = self.W_out(h_V_t)  # [B,21]
                    log_probs = torch.nn.functional.log_softmax(
                        logits, dim=-1
                    )  # [B,21]
                    all_log_probs[:, t] = (
                        chain_mask_t[:, None] * log_probs
                    ).float()  # [B,21]
                    total_logits += symmetry_weights[t] * logits

                probs = torch.nn.functional.softmax(
                    (total_logits + bias_t) / temperature, dim=-1
                )  # [B,21]
                probs_sample = probs[:, :20] / torch.sum(
                    probs[:, :20], dim=-1, keepdim=True
                )  # hard omit X #[B,20]
                S_t = torch.multinomial(probs_sample, 1)[:, 0]  # [B]
                for t in t_list:
                    chain_mask_t = chain_mask[:, t]  # [B]
                    all_probs[:, t] = (
                        chain_mask_t[:, None] * probs_sample
                    ).float()  # [B,20]
                    S_true_t = S_true[:, t]  # [B]
                    S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long()
                    h_S[:, t] = self.W_s(S_t)
                    S[:, t] = S_t

            output_dict = {
                "S": S,
                "sampling_probs": all_probs,
                "log_probs": all_log_probs,
                "decoding_order": decoding_order.repeat(B_decoder, 1),
            }
        return output_dict
| 470 |
+
|
| 471 |
+
def single_aa_score(self, feature_dict, use_sequence: bool):
|
| 472 |
+
"""
|
| 473 |
+
feature_dict - input features
|
| 474 |
+
use_sequence - False using backbone info only
|
| 475 |
+
"""
|
| 476 |
+
B_decoder = feature_dict["batch_size"]
|
| 477 |
+
S_true_enc = feature_dict[
|
| 478 |
+
"S"
|
| 479 |
+
]
|
| 480 |
+
mask_enc = feature_dict[
|
| 481 |
+
"mask"
|
| 482 |
+
]
|
| 483 |
+
chain_mask_enc = feature_dict[
|
| 484 |
+
"chain_mask"
|
| 485 |
+
]
|
| 486 |
+
randn = feature_dict[
|
| 487 |
+
"randn"
|
| 488 |
+
]
|
| 489 |
+
B, L = S_true_enc.shape
|
| 490 |
+
device = S_true_enc.device
|
| 491 |
+
|
| 492 |
+
h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict)
|
| 493 |
+
log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float()
|
| 494 |
+
logits_out = torch.zeros([B_decoder, L, 21], device=device).float()
|
| 495 |
+
decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float()
|
| 496 |
+
|
| 497 |
+
for idx in range(L):
|
| 498 |
+
h_V = torch.clone(h_V_enc)
|
| 499 |
+
E_idx = torch.clone(E_idx_enc)
|
| 500 |
+
mask = torch.clone(mask_enc)
|
| 501 |
+
S_true = torch.clone(S_true_enc)
|
| 502 |
+
if not use_sequence:
|
| 503 |
+
order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float()
|
| 504 |
+
order_mask[idx] = 1.
|
| 505 |
+
else:
|
| 506 |
+
order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float()
|
| 507 |
+
order_mask[idx] = 0.
|
| 508 |
+
decoding_order = torch.argsort(
|
| 509 |
+
(order_mask + 0.0001) * (torch.abs(randn))
|
| 510 |
+
) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
|
| 511 |
+
E_idx = E_idx.repeat(B_decoder, 1, 1)
|
| 512 |
+
permutation_matrix_reverse = torch.nn.functional.one_hot(
|
| 513 |
+
decoding_order, num_classes=L
|
| 514 |
+
).float()
|
| 515 |
+
order_mask_backward = torch.einsum(
|
| 516 |
+
"ij, biq, bjp->bqp",
|
| 517 |
+
(1 - torch.triu(torch.ones(L, L, device=device))),
|
| 518 |
+
permutation_matrix_reverse,
|
| 519 |
+
permutation_matrix_reverse,
|
| 520 |
+
)
|
| 521 |
+
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
|
| 522 |
+
mask_1D = mask.view([B, L, 1, 1])
|
| 523 |
+
mask_bw = mask_1D * mask_attend
|
| 524 |
+
mask_fw = mask_1D * (1.0 - mask_attend)
|
| 525 |
+
S_true = S_true.repeat(B_decoder, 1)
|
| 526 |
+
h_V = h_V.repeat(B_decoder, 1, 1)
|
| 527 |
+
h_E = h_E_enc.repeat(B_decoder, 1, 1, 1)
|
| 528 |
+
mask = mask.repeat(B_decoder, 1)
|
| 529 |
+
|
| 530 |
+
h_S = self.W_s(S_true)
|
| 531 |
+
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
|
| 532 |
+
|
| 533 |
+
# Build encoder embeddings
|
| 534 |
+
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
|
| 535 |
+
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
|
| 536 |
+
|
| 537 |
+
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
|
| 538 |
+
for layer in self.decoder_layers:
|
| 539 |
+
# Masked positions attend to encoder information, unmasked see.
|
| 540 |
+
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
|
| 541 |
+
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
|
| 542 |
+
h_V = layer(h_V, h_ESV, mask)
|
| 543 |
+
|
| 544 |
+
logits = self.W_out(h_V)
|
| 545 |
+
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 546 |
+
|
| 547 |
+
log_probs_out[:,idx,:] = log_probs[:,idx,:]
|
| 548 |
+
logits_out[:,idx,:] = logits[:,idx,:]
|
| 549 |
+
decoding_order_out[:,idx,:] = decoding_order
|
| 550 |
+
|
| 551 |
+
output_dict = {
|
| 552 |
+
"S": S_true,
|
| 553 |
+
"log_probs": log_probs_out,
|
| 554 |
+
"logits": logits_out,
|
| 555 |
+
"decoding_order": decoding_order_out,
|
| 556 |
+
}
|
| 557 |
+
return output_dict
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def score(self, feature_dict, use_sequence: bool):
|
| 561 |
+
B_decoder = feature_dict["batch_size"]
|
| 562 |
+
S_true = feature_dict[
|
| 563 |
+
"S"
|
| 564 |
+
]
|
| 565 |
+
mask = feature_dict[
|
| 566 |
+
"mask"
|
| 567 |
+
]
|
| 568 |
+
chain_mask = feature_dict[
|
| 569 |
+
"chain_mask"
|
| 570 |
+
]
|
| 571 |
+
randn = feature_dict[
|
| 572 |
+
"randn"
|
| 573 |
+
]
|
| 574 |
+
symmetry_list_of_lists = feature_dict[
|
| 575 |
+
"symmetry_residues"
|
| 576 |
+
]
|
| 577 |
+
B, L = S_true.shape
|
| 578 |
+
device = S_true.device
|
| 579 |
+
|
| 580 |
+
h_V, h_E, E_idx = self.encode(feature_dict)
|
| 581 |
+
|
| 582 |
+
chain_mask = mask * chain_mask # update chain_M to include missing regions
|
| 583 |
+
decoding_order = torch.argsort(
|
| 584 |
+
(chain_mask + 0.0001) * (torch.abs(randn))
|
| 585 |
+
) # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
|
| 586 |
+
if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
|
| 587 |
+
E_idx = E_idx.repeat(B_decoder, 1, 1)
|
| 588 |
+
permutation_matrix_reverse = torch.nn.functional.one_hot(
|
| 589 |
+
decoding_order, num_classes=L
|
| 590 |
+
).float()
|
| 591 |
+
order_mask_backward = torch.einsum(
|
| 592 |
+
"ij, biq, bjp->bqp",
|
| 593 |
+
(1 - torch.triu(torch.ones(L, L, device=device))),
|
| 594 |
+
permutation_matrix_reverse,
|
| 595 |
+
permutation_matrix_reverse,
|
| 596 |
+
)
|
| 597 |
+
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
|
| 598 |
+
mask_1D = mask.view([B, L, 1, 1])
|
| 599 |
+
mask_bw = mask_1D * mask_attend
|
| 600 |
+
mask_fw = mask_1D * (1.0 - mask_attend)
|
| 601 |
+
else:
|
| 602 |
+
new_decoding_order = []
|
| 603 |
+
for t_dec in list(decoding_order[0,].cpu().data.numpy()):
|
| 604 |
+
if t_dec not in list(itertools.chain(*new_decoding_order)):
|
| 605 |
+
list_a = [item for item in symmetry_list_of_lists if t_dec in item]
|
| 606 |
+
if list_a:
|
| 607 |
+
new_decoding_order.append(list_a[0])
|
| 608 |
+
else:
|
| 609 |
+
new_decoding_order.append([t_dec])
|
| 610 |
+
|
| 611 |
+
decoding_order = torch.tensor(
|
| 612 |
+
list(itertools.chain(*new_decoding_order)), device=device
|
| 613 |
+
)[None,].repeat(B, 1)
|
| 614 |
+
|
| 615 |
+
permutation_matrix_reverse = torch.nn.functional.one_hot(
|
| 616 |
+
decoding_order, num_classes=L
|
| 617 |
+
).float()
|
| 618 |
+
order_mask_backward = torch.einsum(
|
| 619 |
+
"ij, biq, bjp->bqp",
|
| 620 |
+
(1 - torch.triu(torch.ones(L, L, device=device))),
|
| 621 |
+
permutation_matrix_reverse,
|
| 622 |
+
permutation_matrix_reverse,
|
| 623 |
+
)
|
| 624 |
+
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
|
| 625 |
+
mask_1D = mask.view([B, L, 1, 1])
|
| 626 |
+
mask_bw = mask_1D * mask_attend
|
| 627 |
+
mask_fw = mask_1D * (1.0 - mask_attend)
|
| 628 |
+
|
| 629 |
+
E_idx = E_idx.repeat(B_decoder, 1, 1)
|
| 630 |
+
mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
|
| 631 |
+
mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
|
| 632 |
+
decoding_order = decoding_order.repeat(B_decoder, 1)
|
| 633 |
+
|
| 634 |
+
S_true = S_true.repeat(B_decoder, 1)
|
| 635 |
+
h_V = h_V.repeat(B_decoder, 1, 1)
|
| 636 |
+
h_E = h_E.repeat(B_decoder, 1, 1, 1)
|
| 637 |
+
mask = mask.repeat(B_decoder, 1)
|
| 638 |
+
|
| 639 |
+
h_S = self.W_s(S_true)
|
| 640 |
+
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
|
| 641 |
+
|
| 642 |
+
# Build encoder embeddings
|
| 643 |
+
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
|
| 644 |
+
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
|
| 645 |
+
|
| 646 |
+
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
|
| 647 |
+
if not use_sequence:
|
| 648 |
+
for layer in self.decoder_layers:
|
| 649 |
+
h_V = layer(h_V, h_EXV_encoder_fw, mask)
|
| 650 |
+
else:
|
| 651 |
+
for layer in self.decoder_layers:
|
| 652 |
+
# Masked positions attend to encoder information, unmasked see.
|
| 653 |
+
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
|
| 654 |
+
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
|
| 655 |
+
h_V = layer(h_V, h_ESV, mask)
|
| 656 |
+
|
| 657 |
+
logits = self.W_out(h_V)
|
| 658 |
+
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
|
| 659 |
+
|
| 660 |
+
output_dict = {
|
| 661 |
+
"S": S_true,
|
| 662 |
+
"log_probs": log_probs,
|
| 663 |
+
"logits": logits,
|
| 664 |
+
"decoding_order": decoding_order,
|
| 665 |
+
}
|
| 666 |
+
return output_dict
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
class ProteinFeaturesLigand(torch.nn.Module):
    """Featurize protein backbones together with their ligand-atom context."""

    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=30,
        augment_eps=0.0,
        device=None,
        atom_context_num=16,
        use_side_chains=False,
    ):
        """Extract protein features"""
        super(ProteinFeaturesLigand, self).__init__()

        self.use_side_chains = use_side_chains

        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

        # Per-residue ligand-context node features: 5 backbone-atom RBF sets
        # + 64-dim atom-type embedding + 4 angle features.
        self.node_project_down = torch.nn.Linear(
            5 * num_rbf + 64 + 4, node_features, bias=True
        )
        self.norm_nodes = torch.nn.LayerNorm(node_features)

        self.type_linear = torch.nn.Linear(147, 64)

        self.y_nodes = torch.nn.Linear(147, node_features, bias=False)
        self.y_edges = torch.nn.Linear(num_rbf, node_features, bias=False)

        self.norm_y_edges = torch.nn.LayerNorm(node_features)
        self.norm_y_nodes = torch.nn.LayerNorm(node_features)

        self.atom_context_num = atom_context_num

        # Element numbers for the last 32 atoms in the 37-atom representation.
        self.side_chain_atom_types = torch.tensor(
            [6, 6, 6, 8, 8, 16, 6, 6, 6, 7, 7, 8, 8, 16, 6, 6,
             6, 6, 7, 7, 7, 8, 8, 6, 7, 7, 8, 6, 6, 6, 7, 8],
            device=device,
        )

        # Periodic-table lookup indexed by atomic number 0..118:
        # row 0 = atomic number, row 1 = group (19 categories incl. 0),
        # row 2 = period (8 categories incl. 0). Built programmatically;
        # values are identical to an explicit per-element table.
        short_row = [1, 2] + list(range(13, 19))            # periods 2-3
        long_row = list(range(1, 19))                        # periods 4-5
        f_row = [1, 2] + [3] * 15 + list(range(4, 19))       # periods 6-7
        groups = [0, 1, 18] + short_row + short_row + long_row + long_row + f_row + f_row
        periods = (
            [0] + [1] * 2 + [2] * 8 + [3] * 8
            + [4] * 18 + [5] * 18 + [6] * 32 + [7] * 32
        )
        self.periodic_table_features = torch.tensor(
            [list(range(119)), groups, periods],
            dtype=torch.long,
            device=device,
        )

    def _make_angle_features(self, A, B, C, Y):
        """Directional features of context atoms Y in the A/B/C local frame.

        A, B, C: [Bt, L, 3] backbone atoms defining the frame (e.g. N, Ca, C);
        Y: [Bt, L, M, 3] context atoms. Returns [Bt, L, M, 4].
        """
        v1 = A - B
        v2 = C - B
        e1 = torch.nn.functional.normalize(v1, dim=-1)
        # Gram-Schmidt: remove the e1 component of v2, then normalize.
        e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None]
        e2 = torch.nn.functional.normalize(v2 - e1 * e1_v2_dot, dim=-1)
        e3 = torch.cross(e1, e2, dim=-1)
        R_residue = torch.cat(
            (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1
        )

        # Express each context atom in the local residue frame.
        local_vectors = torch.einsum(
            "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :]
        )

        rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8)
        f1 = local_vectors[..., 0] / rxy
        f2 = local_vectors[..., 1] / rxy
        rxyz = torch.norm(local_vectors, dim=-1) + 1e-8
        f3 = rxy / rxyz
        f4 = local_vectors[..., 2] / rxyz

        return torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1)

    def _dist(self, X, mask, eps=1e-6):
        """k-NN distances over Ca coordinates; masked positions are pushed
        out to the row maximum so they are never selected first."""
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1.0 - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(
            D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
        )
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances into self.num_rbf Gaussian radial basis bins
        with centers spaced evenly on [2.0, 22.0] Angstroms."""
        D_min, D_max, n_bins = 2.0, 22.0, self.num_rbf
        centers = torch.linspace(D_min, D_max, n_bins, device=D.device)
        centers = centers.view([1, 1, 1, -1])
        width = (D_max - D_min) / n_bins
        return torch.exp(-(((torch.unsqueeze(D, -1) - centers) / width) ** 2))

    def _get_rbf(self, A, B, E_idx):
        """RBF features of pairwise A->B atom distances, gathered at the
        K neighbors given by E_idx."""
        D_A_B = torch.sqrt(
            torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]
        return self._rbf(D_A_B_neighbors)

    def forward(self, input_features):
        """Build node (V), edge (E) and ligand-context (Y_*) features.

        Returns (V, E, E_idx, Y_nodes, Y_edges, Y_m).
        """
        Y = input_features["Y"]
        Y_m = input_features["Y_m"]
        Y_t = input_features["Y_t"]
        X = input_features["X"]
        mask = input_features["mask"]
        R_idx = input_features["R_idx"]
        chain_labels = input_features["chain_labels"]

        if self.augment_eps > 0:
            # Gaussian coordinate noise (training-time augmentation).
            X = X + self.augment_eps * torch.randn_like(X)
            Y = Y + self.augment_eps * torch.randn_like(Y)

        B, L, _, _ = X.shape

        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]

        # Virtual Cb reconstructed from the backbone frame (shift from CA).
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        D_neighbors, E_idx = self._dist(Ca, mask)

        # 25 RBF feature sets: Ca-Ca from the k-NN distances plus 24 other
        # ordered backbone-atom pairs (order matters for the trained weights).
        atom_pairs = [
            (N, N), (C, C), (O, O), (Cb, Cb),
            (Ca, N), (Ca, C), (Ca, O), (Ca, Cb),
            (N, C), (N, O), (N, Cb), (Cb, C), (Cb, O), (O, C),
            (N, Ca), (C, Ca), (O, Ca), (Cb, Ca),
            (C, N), (O, N), (Cb, N), (C, Cb), (O, Cb), (C, O),
        ]
        rbf_feats = [self._rbf(D_neighbors)]  # Ca-Ca
        rbf_feats.extend(self._get_rbf(p, q, E_idx) for p, q in atom_pairs)
        RBF_all = torch.cat(tuple(rbf_feats), dim=-1)

        offset = R_idx[:, :, None] - R_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = (
            (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
        ).long()  # find self vs non-self interaction
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all), -1)
        E = self.norm_edges(self.edge_embedding(E))

        if self.use_side_chains:
            # Fold neighboring fixed residues' side-chain atoms into the
            # ligand-atom context.
            xyz_37 = input_features["xyz_37"]
            xyz_37_m = input_features["xyz_37_m"]
            E_idx_sub = E_idx[:, :, :16]  # nearest neighbors only
            mask_residues = input_features["chain_mask"]
            xyz_37_m = xyz_37_m * (1 - mask_residues[:, :, None])
            R_m = gather_nodes(xyz_37_m[:, :, 5:], E_idx_sub)

            X_sidechain = xyz_37[:, :, 5:, :].view(B, L, -1)
            R = gather_nodes(X_sidechain, E_idx_sub).view(
                B, L, E_idx_sub.shape[2], -1, 3
            )
            R_t = self.side_chain_atom_types[None, None, None, :].repeat(
                B, L, E_idx_sub.shape[2], 1
            )

            # Side chain atom context
            R = R.view(B, L, -1, 3)   # coordinates
            R_m = R_m.view(B, L, -1)  # mask
            R_t = R_t.view(B, L, -1)  # atom types

            # Ligand atom context
            Y = torch.cat((R, Y), 2)      # [B, L, atoms, 3]
            Y_m = torch.cat((R_m, Y_m), 2)  # [B, L, atoms]
            Y_t = torch.cat((R_t, Y_t), 2)  # [B, L, atoms]

        # Keep only the atom_context_num closest (Cb-to-atom) context atoms;
        # masked atoms are pushed far away before the top-k.
        Cb_Y_distances = torch.sum((Cb[:, :, None, :] - Y) ** 2, -1)
        mask_Y = mask[:, :, None] * Y_m
        Cb_Y_distances_adjusted = Cb_Y_distances * mask_Y + (1.0 - mask_Y) * 10000.0
        _, E_idx_Y = torch.topk(
            Cb_Y_distances_adjusted, self.atom_context_num, dim=-1, largest=False
        )

        Y = torch.gather(Y, 2, E_idx_Y[:, :, :, None].repeat(1, 1, 1, 3))
        Y_t = torch.gather(Y_t, 2, E_idx_Y)
        Y_m = torch.gather(Y_m, 2, E_idx_Y)

        Y_t = Y_t.long()
        Y_t_g = self.periodic_table_features[1][Y_t]  # group; 19 categories including 0
        Y_t_p = self.periodic_table_features[2][Y_t]  # period; 8 categories including 0

        Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19)  # [B, L, M, 19]
        Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8)   # [B, L, M, 8]
        Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120)     # [B, L, M, 120]

        Y_t_1hot_ = torch.cat(
            [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1
        )  # [B, L, M, 147]
        Y_t_1hot = self.type_linear(Y_t_1hot_.float())

        # RBF features from each backbone atom to every context atom.
        D_N_Y = self._rbf(
            torch.sqrt(torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6)
        )  # [B, L, M, num_bins]
        D_Ca_Y = self._rbf(
            torch.sqrt(torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6)
        )
        D_C_Y = self._rbf(torch.sqrt(torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6))
        D_O_Y = self._rbf(torch.sqrt(torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6))
        D_Cb_Y = self._rbf(
            torch.sqrt(torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6)
        )

        f_angles = self._make_angle_features(N, Ca, C, Y)  # [B, L, M, 4]

        D_all = torch.cat(
            (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1
        )  # [B, L, M, 5*num_bins+68]
        V = self.norm_nodes(self.node_project_down(D_all))  # [B, L, M, node_features]

        Y_edges = self._rbf(
            torch.sqrt(
                torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
            )
        )  # [B, L, M, M, num_bins]

        Y_edges = self.norm_y_edges(self.y_edges(Y_edges))
        Y_nodes = self.norm_y_nodes(self.y_nodes(Y_t_1hot_.float()))

        return V, E, E_idx, Y_nodes, Y_edges, Y_m
class ProteinFeatures(torch.nn.Module):
    """Featurize protein backbones into k-NN graph edge features."""

    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=48,
        augment_eps=0.0,
    ):
        """Extract protein features"""
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # 25 RBF feature sets (num_rbf bins each) + positional embeddings.
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

    def _dist(self, X, mask, eps=1e-6):
        """k-NN distances over Ca coordinates; masked positions are pushed
        out to the row maximum so they are never selected first."""
        mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
        dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
        D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
        D_max, _ = torch.max(D, -1, keepdim=True)
        D_adjust = D + (1.0 - mask_2D) * D_max
        D_neighbors, E_idx = torch.topk(
            D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
        )
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances into self.num_rbf Gaussian radial basis bins
        with centers spaced evenly on [2.0, 22.0] Angstroms."""
        D_min, D_max, n_bins = 2.0, 22.0, self.num_rbf
        centers = torch.linspace(D_min, D_max, n_bins, device=D.device)
        centers = centers.view([1, 1, 1, -1])
        width = (D_max - D_min) / n_bins
        return torch.exp(-(((torch.unsqueeze(D, -1) - centers) / width) ** 2))

    def _get_rbf(self, A, B, E_idx):
        """RBF features of pairwise A->B atom distances, gathered at the
        K neighbors given by E_idx."""
        D_A_B = torch.sqrt(
            torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]
        return self._rbf(D_A_B_neighbors)

    def forward(self, input_features):
        """Build edge features E and the neighbor index E_idx.

        Returns (E, E_idx).
        """
        X = input_features["X"]
        mask = input_features["mask"]
        R_idx = input_features["R_idx"]
        chain_labels = input_features["chain_labels"]

        if self.augment_eps > 0:
            # Gaussian coordinate noise (training-time augmentation).
            X = X + self.augment_eps * torch.randn_like(X)

        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]

        # Virtual Cb reconstructed from the backbone frame (shift from CA).
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        D_neighbors, E_idx = self._dist(Ca, mask)

        # 25 RBF feature sets: Ca-Ca from the k-NN distances plus 24 other
        # ordered backbone-atom pairs (order matters for the trained weights).
        atom_pairs = [
            (N, N), (C, C), (O, O), (Cb, Cb),
            (Ca, N), (Ca, C), (Ca, O), (Ca, Cb),
            (N, C), (N, O), (N, Cb), (Cb, C), (Cb, O), (O, C),
            (N, Ca), (C, Ca), (O, Ca), (Cb, Ca),
            (C, N), (O, N), (Cb, N), (C, Cb), (O, Cb), (C, O),
        ]
        rbf_feats = [self._rbf(D_neighbors)]  # Ca-Ca
        rbf_feats.extend(self._get_rbf(p, q, E_idx) for p, q in atom_pairs)
        RBF_all = torch.cat(tuple(rbf_feats), dim=-1)

        offset = R_idx[:, :, None] - R_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = (
            (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
        ).long()  # find self vs non-self interaction
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all), -1)
        E = self.norm_edges(self.edge_embedding(E))

        return E, E_idx
class ProteinFeaturesMembrane(torch.nn.Module):
    """Edge/node featurizer that adds a per-residue membrane-class
    embedding (one-hot -> linear -> LayerNorm) on top of the standard
    backbone-distance edge features."""

    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=48,
        augment_eps=0.0,
        num_classes=3,
    ):
        """Extract protein features"""
        super(ProteinFeaturesMembrane, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.num_classes = num_classes

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # 25 RBF channels: Ca-Ca plus all remaining ordered pairs of the
        # five backbone atoms (N, Ca, C, O, virtual Cb).
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

        self.node_embedding = torch.nn.Linear(
            self.num_classes, node_features, bias=False
        )
        self.norm_nodes = torch.nn.LayerNorm(node_features)

    def _dist(self, X, mask, eps=1e-6):
        """Distances and indices of the top_k nearest valid neighbors of X."""
        pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
        deltas = X.unsqueeze(1) - X.unsqueeze(2)
        dists = pair_mask * torch.sqrt(torch.sum(deltas**2, 3) + eps)
        row_max, _ = torch.max(dists, -1, keepdim=True)
        # Push masked pairs up to the row maximum so they sort last in topk.
        adjusted = dists + (1.0 - pair_mask) * row_max
        k = np.minimum(self.top_k, X.shape[1])
        D_neighbors, E_idx = torch.topk(adjusted, k, dim=-1, largest=False)
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Expand distances onto num_rbf Gaussian basis functions on [2, 22]."""
        lo, hi, count = 2.0, 22.0, self.num_rbf
        centers = torch.linspace(lo, hi, count, device=D.device).view([1, 1, 1, -1])
        width = (hi - lo) / count
        return torch.exp(-(((D.unsqueeze(-1) - centers) / width) ** 2))

    def _get_rbf(self, A, B, E_idx):
        """RBF-expanded A->B atom distances gathered at the neighbor indices."""
        all_pair = torch.sqrt(
            torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
        )  # [B, L, L]
        neighbor_dists = gather_edges(all_pair[:, :, :, None], E_idx)[
            :, :, :, 0
        ]  # [B, L, K]
        return self._rbf(neighbor_dists)

    def forward(self, input_features):
        X = input_features["X"]
        mask = input_features["mask"]
        R_idx = input_features["R_idx"]
        chain_labels = input_features["chain_labels"]
        membrane_per_residue_labels = input_features["membrane_per_residue_labels"]

        if self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        # Virtual Cb built from the N/Ca/C frame with fixed coefficients.
        b = X[:, :, 1, :] - X[:, :, 0, :]
        c = X[:, :, 2, :] - X[:, :, 1, :]
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :]
        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]

        D_neighbors, E_idx = self._dist(Ca, mask)

        # Ca-Ca comes from the already-computed neighbor distances; the
        # remaining 24 ordered atom pairs follow in this fixed order.
        atom_pairs = [
            (N, N), (C, C), (O, O), (Cb, Cb),
            (Ca, N), (Ca, C), (Ca, O), (Ca, Cb),
            (N, C), (N, O), (N, Cb),
            (Cb, C), (Cb, O), (O, C),
            (N, Ca), (C, Ca), (O, Ca), (Cb, Ca),
            (C, N), (O, N), (Cb, N),
            (C, Cb), (O, Cb), (C, O),
        ]
        rbf_feats = [self._rbf(D_neighbors)]  # Ca-Ca
        rbf_feats.extend(self._get_rbf(src, dst, E_idx) for src, dst in atom_pairs)
        RBF_all = torch.cat(tuple(rbf_feats), dim=-1)

        offset = R_idx[:, :, None] - R_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = (
            (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
        ).long()  # find self vs non-self interaction
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(offset.long(), E_chains)
        E = torch.cat((E_positional, RBF_all), -1)
        E = self.norm_edges(self.edge_embedding(E))

        # Node features: one-hot membrane class per residue.
        C_1hot = torch.nn.functional.one_hot(
            membrane_per_residue_labels, self.num_classes
        ).float()
        V = self.norm_nodes(self.node_embedding(C_1hot))

        return V, E, E_idx
|
| 1580 |
+
|
| 1581 |
+
|
| 1582 |
+
class DecLayerJ(torch.nn.Module):
    """Decoder message-passing layer; identical to DecLayer except that the
    node state is broadcast over one extra leading neighbor axis
    (expand(-1, -1, -1, K, -1) instead of expand(-1, -1, K, -1))."""

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayerJ, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)

        # Three-layer GELU MLP that turns concatenated (node, edge) features
        # into per-neighbor messages.
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer"""
        # Broadcast each node state onto its neighbor stack and concatenate
        # with the edge features (note the extra axis vs. DecLayer).
        expanded = h_V.unsqueeze(-2).expand(-1, -1, -1, h_E.size(-2), -1)
        combined = torch.cat([expanded, h_E], -1)

        messages = self.W3(self.act(self.W2(self.act(self.W1(combined)))))
        if mask_attend is not None:
            messages = mask_attend.unsqueeze(-1) * messages
        # Sum-aggregate over neighbors, scaled by a fixed constant.
        aggregated = torch.sum(messages, -2) / self.scale

        h_V = self.norm1(h_V + self.dropout1(aggregated))

        # Position-wise feedforward with residual connection.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V
        return h_V
|
| 1623 |
+
|
| 1624 |
+
|
| 1625 |
+
class PositionWiseFeedForward(torch.nn.Module):
|
| 1626 |
+
def __init__(self, num_hidden, num_ff):
|
| 1627 |
+
super(PositionWiseFeedForward, self).__init__()
|
| 1628 |
+
self.W_in = torch.nn.Linear(num_hidden, num_ff, bias=True)
|
| 1629 |
+
self.W_out = torch.nn.Linear(num_ff, num_hidden, bias=True)
|
| 1630 |
+
self.act = torch.nn.GELU()
|
| 1631 |
+
|
| 1632 |
+
def forward(self, h_V):
|
| 1633 |
+
h = self.act(self.W_in(h_V))
|
| 1634 |
+
h = self.W_out(h)
|
| 1635 |
+
return h
|
| 1636 |
+
|
| 1637 |
+
|
| 1638 |
+
class PositionalEncodings(torch.nn.Module):
|
| 1639 |
+
def __init__(self, num_embeddings, max_relative_feature=32):
|
| 1640 |
+
super(PositionalEncodings, self).__init__()
|
| 1641 |
+
self.num_embeddings = num_embeddings
|
| 1642 |
+
self.max_relative_feature = max_relative_feature
|
| 1643 |
+
self.linear = torch.nn.Linear(2 * max_relative_feature + 1 + 1, num_embeddings)
|
| 1644 |
+
|
| 1645 |
+
def forward(self, offset, mask):
|
| 1646 |
+
d = torch.clip(
|
| 1647 |
+
offset + self.max_relative_feature, 0, 2 * self.max_relative_feature
|
| 1648 |
+
) * mask + (1 - mask) * (2 * self.max_relative_feature + 1)
|
| 1649 |
+
d_onehot = torch.nn.functional.one_hot(d, 2 * self.max_relative_feature + 1 + 1)
|
| 1650 |
+
E = self.linear(d_onehot.float())
|
| 1651 |
+
return E
|
| 1652 |
+
|
| 1653 |
+
|
| 1654 |
+
class DecLayer(torch.nn.Module):
    """Decoder message-passing layer: neighbor messages update the node
    state, followed by a position-wise feedforward, both with residuals."""

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)

        # Three-layer GELU MLP that maps concatenated (node, edge) features
        # to per-neighbor messages.
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer"""
        # Broadcast each node state onto its K neighbors and concatenate
        # with the edge features.
        expanded = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
        combined = torch.cat([expanded, h_E], -1)

        messages = self.W3(self.act(self.W2(self.act(self.W1(combined)))))
        if mask_attend is not None:
            messages = mask_attend.unsqueeze(-1) * messages
        # Sum-aggregate over neighbors, scaled by a fixed constant.
        aggregated = torch.sum(messages, -2) / self.scale

        h_V = self.norm1(h_V + self.dropout1(aggregated))

        # Position-wise feedforward with residual connection.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V
        return h_V
|
| 1693 |
+
|
| 1694 |
+
|
| 1695 |
+
class EncLayer(torch.nn.Module):
    """Encoder layer: first updates node states from neighbor messages,
    then updates edge states using the refreshed nodes."""

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)
        self.norm3 = torch.nn.LayerNorm(num_hidden)

        # W1-W3: node-update message MLP; W11-W13: edge-update message MLP.
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W11 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer"""
        # --- Node update ---------------------------------------------------
        neighbor_feats = cat_neighbors_nodes(h_V, h_E, E_idx)
        expanded = h_V.unsqueeze(-2).expand(-1, -1, neighbor_feats.size(-2), -1)
        combined = torch.cat([expanded, neighbor_feats], -1)
        node_msg = self.W3(self.act(self.W2(self.act(self.W1(combined)))))
        if mask_attend is not None:
            node_msg = mask_attend.unsqueeze(-1) * node_msg
        aggregated = torch.sum(node_msg, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(aggregated))

        # Position-wise feedforward with residual connection.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))
        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V

        # --- Edge update (uses the already-updated nodes) ------------------
        neighbor_feats = cat_neighbors_nodes(h_V, h_E, E_idx)
        expanded = h_V.unsqueeze(-2).expand(-1, -1, neighbor_feats.size(-2), -1)
        combined = torch.cat([expanded, neighbor_feats], -1)
        edge_msg = self.W13(self.act(self.W12(self.act(self.W11(combined)))))
        h_E = self.norm3(h_E + self.dropout3(edge_msg))
        return h_V, h_E
|
| 1741 |
+
|
| 1742 |
+
|
| 1743 |
+
# The following gather functions
|
| 1744 |
+
def gather_edges(edges, neighbor_idx):
    """Select edge features [B,N,N,C] at neighbor indices [B,N,K] -> [B,N,K,C]."""
    feat_dim = edges.size(-1)
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, feat_dim)
    return torch.gather(edges, 2, idx)
|
| 1749 |
+
|
| 1750 |
+
|
| 1751 |
+
def gather_nodes(nodes, neighbor_idx):
    """Select node features [B,N,C] at neighbor indices [B,N,K] -> [B,N,K,C]."""
    batch = neighbor_idx.shape[0]
    # Flatten the (N, K) index axes so gather can run over dim 1,
    # then restore the neighbor axis afterwards.
    flat_idx = neighbor_idx.reshape((batch, -1))
    flat_idx = flat_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    gathered = torch.gather(nodes, 1, flat_idx)
    return gathered.view(list(neighbor_idx.shape)[:3] + [-1])
|
| 1760 |
+
|
| 1761 |
+
|
| 1762 |
+
def gather_nodes_t(nodes, neighbor_idx):
    """Select node features [B,N,C] at per-batch indices [B,K] -> [B,K,C]."""
    idx = neighbor_idx.unsqueeze(-1).expand(-1, -1, nodes.size(2))
    return torch.gather(nodes, 1, idx)
|
| 1767 |
+
|
| 1768 |
+
|
| 1769 |
+
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """Concatenate gathered neighbor node features onto the edge features."""
    gathered = gather_nodes(h_nodes, E_idx)
    return torch.cat([h_neighbors, gathered], -1)
|
openfold/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#from . import model
|
| 2 |
+
#from . import utils
|
| 3 |
+
#from . import np
|
| 4 |
+
#from . import resources
|
| 5 |
+
|
| 6 |
+
#__all__ = ["model", "utils", "np", "data", "resources"]
|
openfold/config.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import ml_collections as mlc
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def set_inf(c, inf):
    """Recursively overwrite every key named "inf" in a nested ConfigDict."""
    for key, value in c.items():
        if isinstance(value, mlc.ConfigDict):
            set_inf(value, inf)
        elif key == "inf":
            c[key] = inf
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def enforce_config_constraints(config):
    """Raise ValueError if mutually exclusive boolean settings are both enabled."""

    def resolve(dotted_path):
        # Walk the config by successive key lookups: "a.b.c" -> config["a"]["b"]["c"].
        node = config
        for key in dotted_path.split('.'):
            node = node[key]
        return node

    mutually_exclusive_bools = [
        (
            "model.template.average_templates",
            "model.template.offload_templates",
        ),
    ]

    for s1, s2 in mutually_exclusive_bools:
        if resolve(s1) and resolve(s2):
            raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def model_config(name, train=False, low_prec=False):
    """Return a deep-copied config preset selected by name.

    Presets follow AF2 Suppl. Table 4 ("initial_training"/"finetuning") and
    Suppl. Table 5 ("model_1".."model_5"); "*_ptm" variants additionally
    enable the TM head and its loss. Raises ValueError for unknown names.
    """
    c = copy.deepcopy(config)

    valid_names = {
        "initial_training", "finetuning", "finetuning_ptm",
        "model_1", "model_2", "model_3", "model_4", "model_5",
        "model_1_ptm", "model_2_ptm", "model_3_ptm", "model_4_ptm",
        "model_5_ptm",
    }
    if name not in valid_names:
        raise ValueError("Invalid model name")

    is_ptm = name.endswith("_ptm")
    base = name[: -len("_ptm")] if is_ptm else name

    if base == "finetuning":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.train.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
    elif base in ("model_1", "model_3", "model_4"):
        # These presets use the larger extra-MSA stack at train and predict time.
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120

    if base in ("model_1", "model_2"):
        # Template-enabled presets (AF2 Suppl. Table 5, Models 1.1.x)
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
    elif base in ("model_3", "model_4", "model_5"):
        # Template-free presets (AF2 Suppl. Table 5, Models 1.2.x)
        c.model.template.enabled = False

    if is_ptm:
        # pTM variants add the TM-score head and its loss term.
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
        c.globals.use_lma = False
        c.globals.offload_inference = False
        c.model.template.average_templates = False
        c.model.template.offload_templates = False
    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# Shared hyperparameters declared once as FieldReferences so every config
# entry that references them stays linked to a single value.
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# NOTE(review): channel meanings (pair/MSA/template/extra/single) are not
# stated here — presumably the standard AlphaFold2 dims; confirm against the
# config body below.
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)

# Placeholder strings used as symbolic dimension names in the feature-shape
# declarations of config.data.common.feat.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
|
| 154 |
+
|
| 155 |
+
config = mlc.ConfigDict(
|
| 156 |
+
{
|
| 157 |
+
"data": {
|
| 158 |
+
"common": {
|
| 159 |
+
"feat": {
|
| 160 |
+
"aatype": [NUM_RES],
|
| 161 |
+
"all_atom_mask": [NUM_RES, None],
|
| 162 |
+
"all_atom_positions": [NUM_RES, None, None],
|
| 163 |
+
"alt_chi_angles": [NUM_RES, None],
|
| 164 |
+
"atom14_alt_gt_exists": [NUM_RES, None],
|
| 165 |
+
"atom14_alt_gt_positions": [NUM_RES, None, None],
|
| 166 |
+
"atom14_atom_exists": [NUM_RES, None],
|
| 167 |
+
"atom14_atom_is_ambiguous": [NUM_RES, None],
|
| 168 |
+
"atom14_gt_exists": [NUM_RES, None],
|
| 169 |
+
"atom14_gt_positions": [NUM_RES, None, None],
|
| 170 |
+
"atom37_atom_exists": [NUM_RES, None],
|
| 171 |
+
"backbone_rigid_mask": [NUM_RES],
|
| 172 |
+
"backbone_rigid_tensor": [NUM_RES, None, None],
|
| 173 |
+
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
|
| 174 |
+
"chi_angles_sin_cos": [NUM_RES, None, None],
|
| 175 |
+
"chi_mask": [NUM_RES, None],
|
| 176 |
+
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
|
| 177 |
+
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
|
| 178 |
+
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
|
| 179 |
+
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
|
| 180 |
+
"extra_msa_row_mask": [NUM_EXTRA_SEQ],
|
| 181 |
+
"is_distillation": [],
|
| 182 |
+
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
|
| 183 |
+
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
|
| 184 |
+
"msa_row_mask": [NUM_MSA_SEQ],
|
| 185 |
+
"no_recycling_iters": [],
|
| 186 |
+
"pseudo_beta": [NUM_RES, None],
|
| 187 |
+
"pseudo_beta_mask": [NUM_RES],
|
| 188 |
+
"residue_index": [NUM_RES],
|
| 189 |
+
"residx_atom14_to_atom37": [NUM_RES, None],
|
| 190 |
+
"residx_atom37_to_atom14": [NUM_RES, None],
|
| 191 |
+
"resolution": [],
|
| 192 |
+
"rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
|
| 193 |
+
"rigidgroups_group_exists": [NUM_RES, None],
|
| 194 |
+
"rigidgroups_group_is_ambiguous": [NUM_RES, None],
|
| 195 |
+
"rigidgroups_gt_exists": [NUM_RES, None],
|
| 196 |
+
"rigidgroups_gt_frames": [NUM_RES, None, None, None],
|
| 197 |
+
"seq_length": [],
|
| 198 |
+
"seq_mask": [NUM_RES],
|
| 199 |
+
"target_feat": [NUM_RES, None],
|
| 200 |
+
"template_aatype": [NUM_TEMPLATES, NUM_RES],
|
| 201 |
+
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
|
| 202 |
+
"template_all_atom_positions": [
|
| 203 |
+
NUM_TEMPLATES, NUM_RES, None, None,
|
| 204 |
+
],
|
| 205 |
+
"template_alt_torsion_angles_sin_cos": [
|
| 206 |
+
NUM_TEMPLATES, NUM_RES, None, None,
|
| 207 |
+
],
|
| 208 |
+
"template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
|
| 209 |
+
"template_backbone_rigid_tensor": [
|
| 210 |
+
NUM_TEMPLATES, NUM_RES, None, None,
|
| 211 |
+
],
|
| 212 |
+
"template_mask": [NUM_TEMPLATES],
|
| 213 |
+
"template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
|
| 214 |
+
"template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
|
| 215 |
+
"template_sum_probs": [NUM_TEMPLATES, None],
|
| 216 |
+
"template_torsion_angles_mask": [
|
| 217 |
+
NUM_TEMPLATES, NUM_RES, None,
|
| 218 |
+
],
|
| 219 |
+
"template_torsion_angles_sin_cos": [
|
| 220 |
+
NUM_TEMPLATES, NUM_RES, None, None,
|
| 221 |
+
],
|
| 222 |
+
"true_msa": [NUM_MSA_SEQ, NUM_RES],
|
| 223 |
+
"use_clamped_fape": [],
|
| 224 |
+
},
|
| 225 |
+
"masked_msa": {
|
| 226 |
+
"profile_prob": 0.1,
|
| 227 |
+
"same_prob": 0.1,
|
| 228 |
+
"uniform_prob": 0.1,
|
| 229 |
+
},
|
| 230 |
+
"max_recycling_iters": 3,
|
| 231 |
+
"msa_cluster_features": True,
|
| 232 |
+
"reduce_msa_clusters_by_max_templates": False,
|
| 233 |
+
"resample_msa_in_recycling": True,
|
| 234 |
+
"template_features": [
|
| 235 |
+
"template_all_atom_positions",
|
| 236 |
+
"template_sum_probs",
|
| 237 |
+
"template_aatype",
|
| 238 |
+
"template_all_atom_mask",
|
| 239 |
+
],
|
| 240 |
+
"unsupervised_features": [
|
| 241 |
+
"aatype",
|
| 242 |
+
"residue_index",
|
| 243 |
+
"msa",
|
| 244 |
+
"num_alignments",
|
| 245 |
+
"seq_length",
|
| 246 |
+
"between_segment_residues",
|
| 247 |
+
"deletion_matrix",
|
| 248 |
+
"no_recycling_iters",
|
| 249 |
+
],
|
| 250 |
+
"use_templates": templates_enabled,
|
| 251 |
+
"use_template_torsion_angles": embed_template_torsion_angles,
|
| 252 |
+
},
|
| 253 |
+
"supervised": {
|
| 254 |
+
"clamp_prob": 0.9,
|
| 255 |
+
"supervised_features": [
|
| 256 |
+
"all_atom_mask",
|
| 257 |
+
"all_atom_positions",
|
| 258 |
+
"resolution",
|
| 259 |
+
"use_clamped_fape",
|
| 260 |
+
"is_distillation",
|
| 261 |
+
],
|
| 262 |
+
},
|
| 263 |
+
"predict": {
|
| 264 |
+
"fixed_size": True,
|
| 265 |
+
"subsample_templates": False, # We want top templates.
|
| 266 |
+
"masked_msa_replace_fraction": 0.15,
|
| 267 |
+
"max_msa_clusters": 512,
|
| 268 |
+
"max_extra_msa": 1024,
|
| 269 |
+
"max_template_hits": 4,
|
| 270 |
+
"max_templates": 4,
|
| 271 |
+
"crop": False,
|
| 272 |
+
"crop_size": None,
|
| 273 |
+
"supervised": False,
|
| 274 |
+
"uniform_recycling": False,
|
| 275 |
+
},
|
| 276 |
+
"eval": {
|
| 277 |
+
"fixed_size": True,
|
| 278 |
+
"subsample_templates": False, # We want top templates.
|
| 279 |
+
"masked_msa_replace_fraction": 0.15,
|
| 280 |
+
"max_msa_clusters": 128,
|
| 281 |
+
"max_extra_msa": 1024,
|
| 282 |
+
"max_template_hits": 4,
|
| 283 |
+
"max_templates": 4,
|
| 284 |
+
"crop": False,
|
| 285 |
+
"crop_size": None,
|
| 286 |
+
"supervised": True,
|
| 287 |
+
"uniform_recycling": False,
|
| 288 |
+
},
|
| 289 |
+
"train": {
|
| 290 |
+
"fixed_size": True,
|
| 291 |
+
"subsample_templates": True,
|
| 292 |
+
"masked_msa_replace_fraction": 0.15,
|
| 293 |
+
"max_msa_clusters": 128,
|
| 294 |
+
"max_extra_msa": 1024,
|
| 295 |
+
"max_template_hits": 4,
|
| 296 |
+
"max_templates": 4,
|
| 297 |
+
"shuffle_top_k_prefiltered": 20,
|
| 298 |
+
"crop": True,
|
| 299 |
+
"crop_size": 256,
|
| 300 |
+
"supervised": True,
|
| 301 |
+
"clamp_prob": 0.9,
|
| 302 |
+
"max_distillation_msa_clusters": 1000,
|
| 303 |
+
"uniform_recycling": True,
|
| 304 |
+
"distillation_prob": 0.75,
|
| 305 |
+
},
|
| 306 |
+
"data_module": {
|
| 307 |
+
"use_small_bfd": False,
|
| 308 |
+
"data_loaders": {
|
| 309 |
+
"batch_size": 1,
|
| 310 |
+
"num_workers": 16,
|
| 311 |
+
},
|
| 312 |
+
},
|
| 313 |
+
},
|
| 314 |
+
# Recurring FieldReferences that can be changed globally here
|
| 315 |
+
"globals": {
|
| 316 |
+
"blocks_per_ckpt": blocks_per_ckpt,
|
| 317 |
+
"chunk_size": chunk_size,
|
| 318 |
+
"use_lma": False,
|
| 319 |
+
"offload_inference": False,
|
| 320 |
+
"c_z": c_z,
|
| 321 |
+
"c_m": c_m,
|
| 322 |
+
"c_t": c_t,
|
| 323 |
+
"c_e": c_e,
|
| 324 |
+
"c_s": c_s,
|
| 325 |
+
"eps": eps,
|
| 326 |
+
},
|
| 327 |
+
"model": {
|
| 328 |
+
"_mask_trans": False,
|
| 329 |
+
"input_embedder": {
|
| 330 |
+
"tf_dim": 22,
|
| 331 |
+
"msa_dim": 49,
|
| 332 |
+
"c_z": c_z,
|
| 333 |
+
"c_m": c_m,
|
| 334 |
+
"relpos_k": 32,
|
| 335 |
+
},
|
| 336 |
+
"recycling_embedder": {
|
| 337 |
+
"c_z": c_z,
|
| 338 |
+
"c_m": c_m,
|
| 339 |
+
"min_bin": 3.25,
|
| 340 |
+
"max_bin": 20.75,
|
| 341 |
+
"no_bins": 15,
|
| 342 |
+
"inf": 1e8,
|
| 343 |
+
},
|
| 344 |
+
"template": {
|
| 345 |
+
"distogram": {
|
| 346 |
+
"min_bin": 3.25,
|
| 347 |
+
"max_bin": 50.75,
|
| 348 |
+
"no_bins": 39,
|
| 349 |
+
},
|
| 350 |
+
"template_angle_embedder": {
|
| 351 |
+
# DISCREPANCY: c_in is supposed to be 51.
|
| 352 |
+
"c_in": 57,
|
| 353 |
+
"c_out": c_m,
|
| 354 |
+
},
|
| 355 |
+
"template_pair_embedder": {
|
| 356 |
+
"c_in": 88,
|
| 357 |
+
"c_out": c_t,
|
| 358 |
+
},
|
| 359 |
+
"template_pair_stack": {
|
| 360 |
+
"c_t": c_t,
|
| 361 |
+
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
|
| 362 |
+
# as 64. In the code, it's 16.
|
| 363 |
+
"c_hidden_tri_att": 16,
|
| 364 |
+
"c_hidden_tri_mul": 64,
|
| 365 |
+
"no_blocks": 2,
|
| 366 |
+
"no_heads": 4,
|
| 367 |
+
"pair_transition_n": 2,
|
| 368 |
+
"dropout_rate": 0.25,
|
| 369 |
+
"blocks_per_ckpt": blocks_per_ckpt,
|
| 370 |
+
"tune_chunk_size": tune_chunk_size,
|
| 371 |
+
"inf": 1e9,
|
| 372 |
+
},
|
| 373 |
+
"template_pointwise_attention": {
|
| 374 |
+
"c_t": c_t,
|
| 375 |
+
"c_z": c_z,
|
| 376 |
+
# DISCREPANCY: c_hidden here is given in the supplement as 64.
|
| 377 |
+
# It's actually 16.
|
| 378 |
+
"c_hidden": 16,
|
| 379 |
+
"no_heads": 4,
|
| 380 |
+
"inf": 1e5, # 1e9,
|
| 381 |
+
},
|
| 382 |
+
"inf": 1e5, # 1e9,
|
| 383 |
+
"eps": eps, # 1e-6,
|
| 384 |
+
"enabled": templates_enabled,
|
| 385 |
+
"embed_angles": embed_template_torsion_angles,
|
| 386 |
+
"use_unit_vector": False,
|
| 387 |
+
# Approximate template computation, saving memory.
|
| 388 |
+
# In our experiments, results are equivalent to or better than
|
| 389 |
+
# the stock implementation. Should be enabled for all new
|
| 390 |
+
# training runs.
|
| 391 |
+
"average_templates": False,
|
| 392 |
+
# Offload template embeddings to CPU memory. Vastly reduced
|
| 393 |
+
# memory consumption at the cost of a modest increase in
|
| 394 |
+
# runtime. Useful for inference on very long sequences.
|
| 395 |
+
# Mutually exclusive with average_templates.
|
| 396 |
+
"offload_templates": False,
|
| 397 |
+
},
|
| 398 |
+
"extra_msa": {
|
| 399 |
+
"extra_msa_embedder": {
|
| 400 |
+
"c_in": 25,
|
| 401 |
+
"c_out": c_e,
|
| 402 |
+
},
|
| 403 |
+
"extra_msa_stack": {
|
| 404 |
+
"c_m": c_e,
|
| 405 |
+
"c_z": c_z,
|
| 406 |
+
"c_hidden_msa_att": 8,
|
| 407 |
+
"c_hidden_opm": 32,
|
| 408 |
+
"c_hidden_mul": 128,
|
| 409 |
+
"c_hidden_pair_att": 32,
|
| 410 |
+
"no_heads_msa": 8,
|
| 411 |
+
"no_heads_pair": 4,
|
| 412 |
+
"no_blocks": 4,
|
| 413 |
+
"transition_n": 4,
|
| 414 |
+
"msa_dropout": 0.15,
|
| 415 |
+
"pair_dropout": 0.25,
|
| 416 |
+
"clear_cache_between_blocks": False,
|
| 417 |
+
"tune_chunk_size": tune_chunk_size,
|
| 418 |
+
"inf": 1e9,
|
| 419 |
+
"eps": eps, # 1e-10,
|
| 420 |
+
"ckpt": blocks_per_ckpt is not None,
|
| 421 |
+
},
|
| 422 |
+
"enabled": True,
|
| 423 |
+
},
|
| 424 |
+
"evoformer_stack": {
|
| 425 |
+
"c_m": c_m,
|
| 426 |
+
"c_z": c_z,
|
| 427 |
+
"c_hidden_msa_att": 32,
|
| 428 |
+
"c_hidden_opm": 32,
|
| 429 |
+
"c_hidden_mul": 128,
|
| 430 |
+
"c_hidden_pair_att": 32,
|
| 431 |
+
"c_s": c_s,
|
| 432 |
+
"no_heads_msa": 8,
|
| 433 |
+
"no_heads_pair": 4,
|
| 434 |
+
"no_blocks": 48,
|
| 435 |
+
"transition_n": 4,
|
| 436 |
+
"msa_dropout": 0.15,
|
| 437 |
+
"pair_dropout": 0.25,
|
| 438 |
+
"blocks_per_ckpt": blocks_per_ckpt,
|
| 439 |
+
"clear_cache_between_blocks": False,
|
| 440 |
+
"tune_chunk_size": tune_chunk_size,
|
| 441 |
+
"inf": 1e9,
|
| 442 |
+
"eps": eps, # 1e-10,
|
| 443 |
+
},
|
| 444 |
+
"structure_module": {
|
| 445 |
+
"c_s": c_s,
|
| 446 |
+
"c_z": c_z,
|
| 447 |
+
"c_ipa": 16,
|
| 448 |
+
"c_resnet": 128,
|
| 449 |
+
"no_heads_ipa": 12,
|
| 450 |
+
"no_qk_points": 4,
|
| 451 |
+
"no_v_points": 8,
|
| 452 |
+
"dropout_rate": 0.1,
|
| 453 |
+
"no_blocks": 8,
|
| 454 |
+
"no_transition_layers": 1,
|
| 455 |
+
"no_resnet_blocks": 2,
|
| 456 |
+
"no_angles": 7,
|
| 457 |
+
"trans_scale_factor": 10,
|
| 458 |
+
"epsilon": eps, # 1e-12,
|
| 459 |
+
"inf": 1e5,
|
| 460 |
+
},
|
| 461 |
+
"heads": {
|
| 462 |
+
"lddt": {
|
| 463 |
+
"no_bins": 50,
|
| 464 |
+
"c_in": c_s,
|
| 465 |
+
"c_hidden": 128,
|
| 466 |
+
},
|
| 467 |
+
"distogram": {
|
| 468 |
+
"c_z": c_z,
|
| 469 |
+
"no_bins": aux_distogram_bins,
|
| 470 |
+
},
|
| 471 |
+
"tm": {
|
| 472 |
+
"c_z": c_z,
|
| 473 |
+
"no_bins": aux_distogram_bins,
|
| 474 |
+
"enabled": tm_enabled,
|
| 475 |
+
},
|
| 476 |
+
"masked_msa": {
|
| 477 |
+
"c_m": c_m,
|
| 478 |
+
"c_out": 23,
|
| 479 |
+
},
|
| 480 |
+
"experimentally_resolved": {
|
| 481 |
+
"c_s": c_s,
|
| 482 |
+
"c_out": 37,
|
| 483 |
+
},
|
| 484 |
+
},
|
| 485 |
+
},
|
| 486 |
+
"relax": {
|
| 487 |
+
"max_iterations": 0, # no max
|
| 488 |
+
"tolerance": 2.39,
|
| 489 |
+
"stiffness": 10.0,
|
| 490 |
+
"max_outer_iterations": 20,
|
| 491 |
+
"exclude_residues": [],
|
| 492 |
+
},
|
| 493 |
+
"loss": {
|
| 494 |
+
"distogram": {
|
| 495 |
+
"min_bin": 2.3125,
|
| 496 |
+
"max_bin": 21.6875,
|
| 497 |
+
"no_bins": 64,
|
| 498 |
+
"eps": eps, # 1e-6,
|
| 499 |
+
"weight": 0.3,
|
| 500 |
+
},
|
| 501 |
+
"experimentally_resolved": {
|
| 502 |
+
"eps": eps, # 1e-8,
|
| 503 |
+
"min_resolution": 0.1,
|
| 504 |
+
"max_resolution": 3.0,
|
| 505 |
+
"weight": 0.0,
|
| 506 |
+
},
|
| 507 |
+
"fape": {
|
| 508 |
+
"backbone": {
|
| 509 |
+
"clamp_distance": 10.0,
|
| 510 |
+
"loss_unit_distance": 10.0,
|
| 511 |
+
"weight": 0.5,
|
| 512 |
+
},
|
| 513 |
+
"sidechain": {
|
| 514 |
+
"clamp_distance": 10.0,
|
| 515 |
+
"length_scale": 10.0,
|
| 516 |
+
"weight": 0.5,
|
| 517 |
+
},
|
| 518 |
+
"eps": 1e-4,
|
| 519 |
+
"weight": 1.0,
|
| 520 |
+
},
|
| 521 |
+
"lddt": {
|
| 522 |
+
"min_resolution": 0.1,
|
| 523 |
+
"max_resolution": 3.0,
|
| 524 |
+
"cutoff": 15.0,
|
| 525 |
+
"no_bins": 50,
|
| 526 |
+
"eps": eps, # 1e-10,
|
| 527 |
+
"weight": 0.01,
|
| 528 |
+
},
|
| 529 |
+
"masked_msa": {
|
| 530 |
+
"eps": eps, # 1e-8,
|
| 531 |
+
"weight": 2.0,
|
| 532 |
+
},
|
| 533 |
+
"supervised_chi": {
|
| 534 |
+
"chi_weight": 0.5,
|
| 535 |
+
"angle_norm_weight": 0.01,
|
| 536 |
+
"eps": eps, # 1e-6,
|
| 537 |
+
"weight": 1.0,
|
| 538 |
+
},
|
| 539 |
+
"violation": {
|
| 540 |
+
"violation_tolerance_factor": 12.0,
|
| 541 |
+
"clash_overlap_tolerance": 1.5,
|
| 542 |
+
"eps": eps, # 1e-6,
|
| 543 |
+
"weight": 0.0,
|
| 544 |
+
},
|
| 545 |
+
"tm": {
|
| 546 |
+
"max_bin": 31,
|
| 547 |
+
"no_bins": 64,
|
| 548 |
+
"min_resolution": 0.1,
|
| 549 |
+
"max_resolution": 3.0,
|
| 550 |
+
"eps": eps, # 1e-8,
|
| 551 |
+
"weight": 0.,
|
| 552 |
+
"enabled": tm_enabled,
|
| 553 |
+
},
|
| 554 |
+
"eps": eps,
|
| 555 |
+
},
|
| 556 |
+
"ema": {"decay": 0.999},
|
| 557 |
+
}
|
| 558 |
+
)
|
openfold/data/__init__.py
ADDED
|
File without changes
|
openfold/data/data_modules.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from functools import partial
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import pickle
|
| 7 |
+
from typing import Optional, Sequence, List, Any
|
| 8 |
+
|
| 9 |
+
import ml_collections as mlc
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pytorch_lightning as pl
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import RandomSampler
|
| 14 |
+
|
| 15 |
+
from openfold.data import (
|
| 16 |
+
data_pipeline,
|
| 17 |
+
feature_pipeline,
|
| 18 |
+
mmcif_parsing,
|
| 19 |
+
templates,
|
| 20 |
+
)
|
| 21 |
+
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class OpenFoldSingleDataset(torch.utils.data.Dataset):
|
| 25 |
+
def __init__(self,
|
| 26 |
+
data_dir: str,
|
| 27 |
+
alignment_dir: str,
|
| 28 |
+
template_mmcif_dir: str,
|
| 29 |
+
max_template_date: str,
|
| 30 |
+
config: mlc.ConfigDict,
|
| 31 |
+
kalign_binary_path: str = '/usr/bin/kalign',
|
| 32 |
+
max_template_hits: int = 4,
|
| 33 |
+
obsolete_pdbs_file_path: Optional[str] = None,
|
| 34 |
+
template_release_dates_cache_path: Optional[str] = None,
|
| 35 |
+
shuffle_top_k_prefiltered: Optional[int] = None,
|
| 36 |
+
treat_pdb_as_distillation: bool = True,
|
| 37 |
+
mapping_path: Optional[str] = None,
|
| 38 |
+
mode: str = "train",
|
| 39 |
+
alignment_index: Optional[Any] = None,
|
| 40 |
+
_output_raw: bool = False,
|
| 41 |
+
_structure_index: Optional[Any] = None,
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
data_dir:
|
| 46 |
+
A path to a directory containing mmCIF files (in train
|
| 47 |
+
mode) or FASTA files (in inference mode).
|
| 48 |
+
alignment_dir:
|
| 49 |
+
A path to a directory containing only data in the format
|
| 50 |
+
output by an AlignmentRunner
|
| 51 |
+
(defined in openfold.features.alignment_runner).
|
| 52 |
+
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
|
| 53 |
+
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
|
| 54 |
+
files.
|
| 55 |
+
template_mmcif_dir:
|
| 56 |
+
Path to a directory containing template mmCIF files.
|
| 57 |
+
config:
|
| 58 |
+
A dataset config object. See openfold.config
|
| 59 |
+
kalign_binary_path:
|
| 60 |
+
Path to kalign binary.
|
| 61 |
+
max_template_hits:
|
| 62 |
+
An upper bound on how many templates are considered. During
|
| 63 |
+
training, the templates ultimately used are subsampled
|
| 64 |
+
from this total quantity.
|
| 65 |
+
template_release_dates_cache_path:
|
| 66 |
+
Path to the output of scripts/generate_mmcif_cache.
|
| 67 |
+
obsolete_pdbs_file_path:
|
| 68 |
+
Path to the file containing replacements for obsolete PDBs.
|
| 69 |
+
shuffle_top_k_prefiltered:
|
| 70 |
+
Whether to uniformly shuffle the top k template hits before
|
| 71 |
+
parsing max_template_hits of them. Can be used to
|
| 72 |
+
approximate DeepMind's training-time template subsampling
|
| 73 |
+
scheme much more performantly.
|
| 74 |
+
treat_pdb_as_distillation:
|
| 75 |
+
Whether to assume that .pdb files in the data_dir are from
|
| 76 |
+
the self-distillation set (and should be subjected to
|
| 77 |
+
special distillation set preprocessing steps).
|
| 78 |
+
mode:
|
| 79 |
+
"train", "val", or "predict"
|
| 80 |
+
"""
|
| 81 |
+
super(OpenFoldSingleDataset, self).__init__()
|
| 82 |
+
self.data_dir = data_dir
|
| 83 |
+
self.alignment_dir = alignment_dir
|
| 84 |
+
self.config = config
|
| 85 |
+
self.treat_pdb_as_distillation = treat_pdb_as_distillation
|
| 86 |
+
self.mode = mode
|
| 87 |
+
self.alignment_index = alignment_index
|
| 88 |
+
self._output_raw = _output_raw
|
| 89 |
+
self._structure_index = _structure_index
|
| 90 |
+
|
| 91 |
+
self.supported_exts = [".cif", ".core", ".pdb"]
|
| 92 |
+
|
| 93 |
+
valid_modes = ["train", "eval", "predict"]
|
| 94 |
+
if(mode not in valid_modes):
|
| 95 |
+
raise ValueError(f'mode must be one of {valid_modes}')
|
| 96 |
+
|
| 97 |
+
if(template_release_dates_cache_path is None):
|
| 98 |
+
logging.warning(
|
| 99 |
+
"Template release dates cache does not exist. Remember to run "
|
| 100 |
+
"scripts/generate_mmcif_cache.py before running OpenFold"
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if(alignment_index is not None):
|
| 104 |
+
self._chain_ids = list(alignment_index.keys())
|
| 105 |
+
elif(mapping_path is None):
|
| 106 |
+
self._chain_ids = list(os.listdir(alignment_dir))
|
| 107 |
+
else:
|
| 108 |
+
with open(mapping_path, "r") as f:
|
| 109 |
+
self._chain_ids = [l.strip() for l in f.readlines()]
|
| 110 |
+
|
| 111 |
+
self._chain_id_to_idx_dict = {
|
| 112 |
+
chain: i for i, chain in enumerate(self._chain_ids)
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
template_featurizer = templates.TemplateHitFeaturizer(
|
| 116 |
+
mmcif_dir=template_mmcif_dir,
|
| 117 |
+
max_template_date=max_template_date,
|
| 118 |
+
max_hits=max_template_hits,
|
| 119 |
+
kalign_binary_path=kalign_binary_path,
|
| 120 |
+
release_dates_path=template_release_dates_cache_path,
|
| 121 |
+
obsolete_pdbs_path=obsolete_pdbs_file_path,
|
| 122 |
+
_shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
self.data_pipeline = data_pipeline.DataPipeline(
|
| 126 |
+
template_featurizer=template_featurizer,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if(not self._output_raw):
|
| 130 |
+
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
|
| 131 |
+
|
| 132 |
+
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
|
| 133 |
+
with open(path, 'r') as f:
|
| 134 |
+
mmcif_string = f.read()
|
| 135 |
+
|
| 136 |
+
mmcif_object = mmcif_parsing.parse(
|
| 137 |
+
file_id=file_id, mmcif_string=mmcif_string
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Crash if an error is encountered. Any parsing errors should have
|
| 141 |
+
# been dealt with at the alignment stage.
|
| 142 |
+
if(mmcif_object.mmcif_object is None):
|
| 143 |
+
raise list(mmcif_object.errors.values())[0]
|
| 144 |
+
|
| 145 |
+
mmcif_object = mmcif_object.mmcif_object
|
| 146 |
+
|
| 147 |
+
data = self.data_pipeline.process_mmcif(
|
| 148 |
+
mmcif=mmcif_object,
|
| 149 |
+
alignment_dir=alignment_dir,
|
| 150 |
+
chain_id=chain_id,
|
| 151 |
+
alignment_index=alignment_index
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return data
|
| 155 |
+
|
| 156 |
+
def chain_id_to_idx(self, chain_id):
|
| 157 |
+
return self._chain_id_to_idx_dict[chain_id]
|
| 158 |
+
|
| 159 |
+
def idx_to_chain_id(self, idx):
|
| 160 |
+
return self._chain_ids[idx]
|
| 161 |
+
|
| 162 |
+
def __getitem__(self, idx):
|
| 163 |
+
name = self.idx_to_chain_id(idx)
|
| 164 |
+
alignment_dir = os.path.join(self.alignment_dir, name)
|
| 165 |
+
|
| 166 |
+
alignment_index = None
|
| 167 |
+
if(self.alignment_index is not None):
|
| 168 |
+
alignment_dir = self.alignment_dir
|
| 169 |
+
alignment_index = self.alignment_index[name]
|
| 170 |
+
|
| 171 |
+
if(self.mode == 'train' or self.mode == 'eval'):
|
| 172 |
+
spl = name.rsplit('_', 1)
|
| 173 |
+
if(len(spl) == 2):
|
| 174 |
+
file_id, chain_id = spl
|
| 175 |
+
else:
|
| 176 |
+
file_id, = spl
|
| 177 |
+
chain_id = None
|
| 178 |
+
|
| 179 |
+
path = os.path.join(self.data_dir, file_id)
|
| 180 |
+
structure_index_entry = None
|
| 181 |
+
if(self._structure_index is not None):
|
| 182 |
+
structure_index_entry = self._structure_index[name]
|
| 183 |
+
assert(len(structure_index_entry["files"]) == 1)
|
| 184 |
+
filename, _, _ = structure_index_entry["files"][0]
|
| 185 |
+
ext = os.path.splitext(filename)[1]
|
| 186 |
+
else:
|
| 187 |
+
ext = None
|
| 188 |
+
for e in self.supported_exts:
|
| 189 |
+
if(os.path.exists(path + e)):
|
| 190 |
+
ext = e
|
| 191 |
+
break
|
| 192 |
+
|
| 193 |
+
if(ext is None):
|
| 194 |
+
raise ValueError("Invalid file type")
|
| 195 |
+
|
| 196 |
+
path += ext
|
| 197 |
+
if(ext == ".cif"):
|
| 198 |
+
data = self._parse_mmcif(
|
| 199 |
+
path, file_id, chain_id, alignment_dir, alignment_index,
|
| 200 |
+
)
|
| 201 |
+
elif(ext == ".core"):
|
| 202 |
+
data = self.data_pipeline.process_core(
|
| 203 |
+
path, alignment_dir, alignment_index,
|
| 204 |
+
)
|
| 205 |
+
elif(ext == ".pdb"):
|
| 206 |
+
data = self.data_pipeline.process_pdb(
|
| 207 |
+
pdb_path=path,
|
| 208 |
+
alignment_dir=alignment_dir,
|
| 209 |
+
is_distillation=self.treat_pdb_as_distillation,
|
| 210 |
+
chain_id=chain_id,
|
| 211 |
+
alignment_index=alignment_index,
|
| 212 |
+
_structure_index=self._structure_index[name],
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
raise ValueError("Extension branch missing")
|
| 216 |
+
else:
|
| 217 |
+
path = os.path.join(name, name + ".fasta")
|
| 218 |
+
data = self.data_pipeline.process_fasta(
|
| 219 |
+
fasta_path=path,
|
| 220 |
+
alignment_dir=alignment_dir,
|
| 221 |
+
alignment_index=alignment_index,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
if(self._output_raw):
|
| 225 |
+
return data
|
| 226 |
+
|
| 227 |
+
feats = self.feature_pipeline.process_features(
|
| 228 |
+
data, self.mode
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)
|
| 232 |
+
|
| 233 |
+
return feats
|
| 234 |
+
|
| 235 |
+
def __len__(self):
|
| 236 |
+
return len(self._chain_ids)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def deterministic_train_filter(
|
| 240 |
+
chain_data_cache_entry: Any,
|
| 241 |
+
max_resolution: float = 9.,
|
| 242 |
+
max_single_aa_prop: float = 0.8,
|
| 243 |
+
) -> bool:
|
| 244 |
+
# Hard filters
|
| 245 |
+
resolution = chain_data_cache_entry.get("resolution", None)
|
| 246 |
+
if(resolution is not None and resolution > max_resolution):
|
| 247 |
+
return False
|
| 248 |
+
|
| 249 |
+
seq = chain_data_cache_entry["seq"]
|
| 250 |
+
counts = {}
|
| 251 |
+
for aa in seq:
|
| 252 |
+
counts.setdefault(aa, 0)
|
| 253 |
+
counts[aa] += 1
|
| 254 |
+
largest_aa_count = max(counts.values())
|
| 255 |
+
largest_single_aa_prop = largest_aa_count / len(seq)
|
| 256 |
+
if(largest_single_aa_prop > max_single_aa_prop):
|
| 257 |
+
return False
|
| 258 |
+
|
| 259 |
+
return True
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    """Compute the probability with which a chain passes the stochastic
    training filters.

    Combines (by multiplication) an inverse-cluster-size factor, which
    downweights over-represented sequence clusters, and a length factor
    that grows linearly with chain length between 256 and 512 residues.

    BUG FIX: the return annotation was ``List[float]``, but the function
    returns the scalar product of the probabilities; annotated as ``float``.
    """
    # Stochastic filters
    probabilities = []

    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if cluster_size is not None and cluster_size > 0:
        probabilities.append(1 / cluster_size)

    # Length factor: clamp length to [256, 512], then normalize by 512.
    chain_length = len(chain_data_cache_entry["seq"])
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

    # Risk of underflow here?
    out = 1
    for p in probabilities:
        out *= p

    return out
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class OpenFoldDataset(torch.utils.data.Dataset):
    """
    Implements the stochastic filters applied during AlphaFold's training.
    Because samples are selected from constituent datasets randomly, the
    length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
    and filtered once at initialization.
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        chain_data_cache_paths: List[str],
        generator: torch.Generator = None,
        _roll_at_init: bool = True,
    ):
        """
        Args:
            datasets: Constituent datasets to sample from.
            probabilities: Per-dataset sampling probabilities (parallel to
                ``datasets``; also used to size each dataset's sample cache).
            epoch_len: Number of samples drawn per epoch (the reported len).
            chain_data_cache_paths: JSON caches (one per dataset) mapping
                chain IDs to metadata (resolution, seq, cluster_size, ...)
                consumed by the train filters.
            generator: Optional torch RNG for reproducible sampling.
            _roll_at_init: If True, draw the first epoch's datapoints now.
        """
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        # One chain-metadata dict per constituent dataset.
        self.chain_data_caches = []
        for path in chain_data_cache_paths:
            with open(path, "r") as fp:
                self.chain_data_caches.append(json.load(fp))

        def looped_shuffled_dataset_idx(dataset_len):
            # Infinite stream of dataset indices, reshuffled each pass.
            while True:
                # Uniformly shuffle each dataset's indices
                weights = [1. for _ in range(dataset_len)]
                shuf = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                for idx in shuf:
                    yield idx

        def looped_samples(dataset_idx):
            # Infinite stream of filtered datapoint indices for one dataset.
            # Candidates are drawn in shuffled order, hard-filtered
            # deterministically, then accepted/rejected by a Bernoulli draw
            # with the chain's stochastic filter probability.
            max_cache_len = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
            chain_data_cache = self.chain_data_caches[dataset_idx]
            while True:
                weights = []
                idx = []
                for _ in range(max_cache_len):
                    candidate_idx = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate_idx)
                    chain_data_cache_entry = chain_data_cache[chain_id]
                    if(not deterministic_train_filter(chain_data_cache_entry)):
                        continue

                    p = get_stochastic_train_filter_prob(
                        chain_data_cache_entry,
                    )
                    # Two-category weights for a per-chain Bernoulli draw.
                    weights.append([1. - p, p])
                    idx.append(candidate_idx)

                # One multinomial draw per candidate: 1 -> accepted.
                samples = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=1,
                    generator=self.generator,
                )
                samples = samples.squeeze()

                cache = [i for i, s in zip(idx, samples) if s]

                for datapoint_idx in cache:
                    yield datapoint_idx

        # One lazy sample generator per constituent dataset.
        self._samples = [looped_samples(i) for i in range(len(self.datasets))]

        if(_roll_at_init):
            self.reroll()

    def __getitem__(self, idx):
        # Resolve the pre-rolled (dataset, datapoint) pair for this slot.
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
        """Draw a fresh epoch's worth of (dataset, datapoint) pairs."""
        # Choose which dataset supplies each of the epoch_len slots.
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices:
            samples = self._samples[dataset_idx]
            datapoint_idx = next(samples)
            self.datapoints.append((dataset_idx, datapoint_idx))
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class OpenFoldBatchCollator:
    """Collates a list of per-sample feature dicts into one batched dict.

    Every key shared by the samples is stacked along a new leading batch
    dimension via the project's ``dict_multimap`` helper.
    """
    def __call__(self, prots):
        def stack_fn(tensors):
            return torch.stack(tensors, dim=0)
        return dict_multimap(stack_fn, prots)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    """DataLoader that stamps stochastic per-batch properties onto each batch.

    After standard collation, each batch receives sampled values for
    "use_clamped_fape" (supervised stages only) and "no_recycling_iters",
    and the batch's trailing recycling dimension is truncated to the
    sampled iteration count.
    """
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        # config: data-pipeline ConfigDict; stage selects the sub-config
        # ("train" / "eval" / "predict").
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage

        # Fall back to an unseeded generator when none is provided.
        if(generator is None):
            generator = torch.Generator()

        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        # Precompute a (num_props, max_len) probability tensor: one row per
        # batch property, zero-padded on the right to a common length so a
        # single torch.multinomial call can sample all properties at once.
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters
        if(stage_cfg.supervised):
            # Bernoulli choice of whether FAPE is clamped for this batch.
            clamp_prob = self.config.supervised.clamp_prob
            keyed_probs.append(
                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
            )

        if(stage_cfg.uniform_recycling):
            # Recycling count sampled uniformly from {0, ..., max_iters}.
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            # Deterministic: always run the maximum number of recycling
            # iterations.
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.

        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        keys, probs = zip(*keyed_probs)
        max_len = max([len(p) for p in probs])
        padding = [[0.] * (max_len - len(p)) for p in probs]

        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        # Sample one categorical value per property row, broadcast it over
        # the batch and recycling dimensions, then trim every tensor in the
        # batch to the sampled recycling count.
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1, # 1 per row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        # assumes aatype is [*batch_dims, n_res, n_recycling] — TODO confirm
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample,
                device=aatype.device,
                requires_grad=False
            )
            orig_shape = sample_tensor.shape
            # Reshape the scalar so it can be expanded across the batch dims
            # and a trailing recycling dim.
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if(key == "no_recycling_iters"):
                no_recycling = sample

        # Drop recycling slices beyond the sampled iteration count.
        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        # Wrap the parent iterator so every batch passes through
        # _add_batch_properties before it reaches the consumer.
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class OpenFoldDataModule(pl.LightningDataModule):
    """LightningDataModule wiring OpenFold datasets into dataloaders.

    In training mode (``train_data_dir`` given) it builds an
    ``OpenFoldDataset`` (optionally mixing in a distillation dataset) plus an
    optional eval dataset; otherwise it builds a prediction dataset. The
    ``*_dir`` / ``*_path`` arguments describe where structures, alignments,
    and cache files live on disk.
    """
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        # Sanity-check the directory combinations up front so misconfiguration
        # fails at construction time, not deep inside setup().
        if(self.train_data_dir is None and self.predict_data_dir is None):
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if(self.training_mode and train_alignment_dir is None):
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif(not self.training_mode and predict_alignment_dir is None):
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )
        elif(val_data_dir is not None and val_alignment_dir is None):
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = None
        if(_distillation_structure_index_path is not None):
            with open(_distillation_structure_index_path, "r") as fp:
                self._distillation_structure_index = json.load(fp)

        self.alignment_index = None
        if(alignment_index_path is not None):
            with open(alignment_index_path, "r") as fp:
                self.alignment_index = json.load(fp)

        self.distillation_alignment_index = None
        if(distillation_alignment_index_path is not None):
            with open(distillation_alignment_index_path, "r") as fp:
                self.distillation_alignment_index = json.load(fp)

    def setup(self):
        """Instantiate the datasets for the current mode (train or predict)."""
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )

            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                # Sample distillation chains with probability d_prob.
                # (A redundant duplicate assignment of d_prob was removed.)
                d_prob = self.config.train.distillation_prob
                probabilities = [1. - d_prob, d_prob]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                    self.distillation_chain_data_cache_path,
                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                ]

            # BUG FIX: generator was previously only bound inside the
            # batch_seed branch, raising NameError below when batch_seed
            # was None. Default to None (torch.multinomial accepts it).
            generator = None
            if(self.batch_seed is not None):
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
                generator=generator,
                _roll_at_init=False,
            )

            if(self.val_data_dir is not None):
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        """Build an OpenFoldDataLoader for the given stage.

        Raises:
            ValueError: if stage is not "train", "eval", or "predict".
        """
        generator = torch.Generator()
        if(self.batch_seed is not None):
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
            # Resample the epoch's datapoints before each training epoch.
            dataset.reroll()
        elif(stage == "eval"):
            dataset = self.eval_dataset
        elif(stage == "predict"):
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train")

    def val_dataloader(self):
        # Only available in training mode with a validation set configured.
        if(self.eval_dataset is not None):
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict")
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
class DummyDataset(torch.utils.data.Dataset):
    """Serves deep copies of one pre-pickled batch, for debugging/profiling.

    NOTE: ``pickle.load`` is only safe on trusted files.
    """
    def __init__(self, batch_path):
        with open(batch_path, "rb") as source:
            self.batch = pickle.load(source)

    def __getitem__(self, idx):
        # Deep-copy so callers may freely mutate their sample.
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Arbitrary fixed epoch length.
        return 1000
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
class DummyDataLoader(pl.LightningDataModule):
    """Lightning data module that endlessly serves a single pickled batch."""
    def __init__(self, batch_path):
        # batch_path: pickle file containing one pre-built feature batch.
        super(DummyDataLoader, self).__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        """Wrap the dummy dataset in a default-configured DataLoader."""
        return torch.utils.data.DataLoader(self.dataset)
|
openfold/data/data_pipeline.py
ADDED
|
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import datetime
|
| 18 |
+
from multiprocessing import cpu_count
|
| 19 |
+
from typing import Mapping, Optional, Sequence, Any
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from openfold.data import templates, parsers, mmcif_parsing
|
| 24 |
+
from openfold.data.tools import jackhmmer, hhblits, hhsearch
|
| 25 |
+
from openfold.data.tools.utils import to_date
|
| 26 |
+
from openfold.np import residue_constants, protein
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
FeatureDict = Mapping[str, np.ndarray]
|
| 30 |
+
|
| 31 |
+
def empty_template_feats(n_res) -> "FeatureDict":
    """Build a zero-template feature dict for a query of ``n_res`` residues.

    Every array carries a leading template axis of size 0, so downstream
    code that concatenates or iterates over templates degrades gracefully.
    """
    shapes_and_dtypes = {
        "template_aatype": ((0, n_res), np.int64),
        "template_all_atom_positions": ((0, n_res, 37, 3), np.float32),
        "template_sum_probs": ((0, 1), np.float32),
        "template_all_atom_mask": ((0, n_res, 37), np.float32),
    }
    return {
        name: np.zeros(shape, dtype=dtype)
        for name, (shape, dtype) in shapes_and_dtypes.items()
    }
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def make_template_features(
    input_sequence: str,
    hits: Mapping[str, Any],
    template_featurizer: Any,
    query_pdb_code: Optional[str] = None,
    query_release_date: Optional[str] = None,
) -> FeatureDict:
    """Build template features for a query sequence from alignment hits.

    Args:
        input_sequence: Query amino-acid sequence.
        hits: Mapping from tool/database name to a list of template hits.
            (Fixed: the previous ``Sequence[Any]`` annotation was wrong —
            ``.values()`` is called on this argument, so it is a mapping.)
        template_featurizer: Object exposing ``get_templates``; may be None.
        query_pdb_code: Optional PDB code of the query.
        query_release_date: Optional release date of the query structure.

    Returns:
        A template FeatureDict; a 0-template dict when there are no hits,
        no featurizer, or the featurizer produces no templates.
    """
    hits_cat = sum(hits.values(), [])
    if(len(hits_cat) == 0 or template_featurizer is None):
        return empty_template_feats(len(input_sequence))

    templates_result = template_featurizer.get_templates(
        query_sequence=input_sequence,
        query_pdb_code=query_pdb_code,
        query_release_date=query_release_date,
        hits=hits_cat,
    )
    template_features = templates_result.features

    # The template featurizer doesn't format empty template features
    # properly. This is a quick fix.
    if(template_features["template_aatype"].shape[0] == 0):
        template_features = empty_template_feats(len(input_sequence))

    return template_features
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def unify_template_features(
    template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
    """Merge per-chain template feature dicts into one multi-chain dict.

    Per-residue template arrays are zero-padded to the concatenated sequence
    length and written at each chain's offset; per-template arrays are kept
    as-is. A "template_chain_index" array is added mapping each template row
    back to its source chain. Chains contributing zero templates are dropped.
    """
    out_dicts = []
    # Residue count per chain; used to compute each chain's offset into the
    # merged residue axis.
    seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
    for i, fd in enumerate(template_feature_list):
        out_dict = {}
        n_templates, n_res = fd["template_aatype"].shape[:2]
        for k,v in fd.items():
            # Keys whose axis 1 is the residue dimension and must be
            # re-embedded into the full concatenated length.
            seq_keys = [
                "template_aatype",
                "template_all_atom_positions",
                "template_all_atom_mask",
            ]
            if(k in seq_keys):
                new_shape = list(v.shape)
                assert(new_shape[1] == n_res)
                new_shape[1] = sum(seq_lens)
                new_array = np.zeros(new_shape, dtype=v.dtype)

                if(k == "template_aatype"):
                    # NOTE(review): writes 1 into the gap ('-') channel of the
                    # last axis — assumes aatype is one-hot encoded here, so
                    # padded positions read as gaps. Confirm against callers.
                    new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1

                # Copy this chain's block into its offset range.
                offset = sum(seq_lens[:i])
                new_array[:, offset:offset + seq_lens[i]] = v
                out_dict[k] = new_array
            else:
                out_dict[k] = v

        # One chain index per template row of this chain.
        chain_indices = np.array(n_templates * [i])
        out_dict["template_chain_index"] = chain_indices

        # Skip chains with no templates so concatenation below stays valid.
        if(n_templates != 0):
            out_dicts.append(out_dict)

    # NOTE(review): raises IndexError when no chain has any templates —
    # callers presumably guarantee at least one; verify.
    out_dict = {
        k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
    }

    return out_dict
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Construct a feature dict of sequence features.

    Encodes the sequence as a one-hot aatype array (unknown residues mapped
    to X) alongside per-residue index/length arrays and the raw
    description/sequence as byte-string object arrays.
    """
    one_hot = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    return {
        "aatype": one_hot,
        "between_segment_residues": np.zeros((num_res,), dtype=np.int32),
        "domain_name": np.array(
            [description.encode("utf-8")], dtype=np.object_
        ),
        "residue_index": np.array(range(num_res), dtype=np.int32),
        "seq_length": np.array([num_res] * num_res, dtype=np.int32),
        "sequence": np.array(
            [sequence.encode("utf-8")], dtype=np.object_
        ),
    }
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Featurize one chain of a parsed mmCIF structure.

    Combines sequence features for the chain's SEQRES with atom coordinates,
    atom mask, resolution, and release date from the mmCIF header.
    """
    seqres = mmcif_object.chain_to_seqres[chain_id]
    description = "_".join([mmcif_object.file_id, chain_id])

    mmcif_feats = {}
    mmcif_feats.update(
        make_sequence_features(
            sequence=seqres,
            description=description,
            num_res=len(seqres),
        )
    )

    # Atom-level coordinates and presence mask, aligned to the SEQRES.
    coords, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    mmcif_feats["all_atom_positions"] = coords
    mmcif_feats["all_atom_mask"] = mask

    mmcif_feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    mmcif_feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode("utf-8")],
        dtype=np.object_,
    )
    # Experimental structures are never distillation data.
    mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)

    return mmcif_feats
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _aatype_to_str_sequence(aatype):
    """Decode an integer aatype array into a one-letter amino-acid string.

    Iterates the array directly (the original indexed via range(len(...)),
    a non-idiomatic pattern) and maps each index through
    ``residue_constants.restypes_with_x``.
    """
    return ''.join(residue_constants.restypes_with_x[idx] for idx in aatype)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def make_protein_features(
    protein_object: protein.Protein,
    description: str,
    _is_distillation: bool = False,
) -> FeatureDict:
    """Featurize a Protein object: sequence features plus atom arrays.

    Resolution is fixed at 0.0 (unknown); ``_is_distillation`` sets the
    is_distillation flag feature.
    """
    aatype = protein_object.aatype

    pdb_feats = {}
    pdb_feats.update(
        make_sequence_features(
            sequence=_aatype_to_str_sequence(aatype),
            description=description,
            num_res=len(aatype),
        )
    )

    pdb_feats["all_atom_positions"] = (
        protein_object.atom_positions.astype(np.float32)
    )
    pdb_feats["all_atom_mask"] = protein_object.atom_mask.astype(np.float32)

    # No meaningful resolution for PDB/distillation inputs here.
    pdb_feats["resolution"] = np.array([0.]).astype(np.float32)
    pdb_feats["is_distillation"] = np.array(
        1. if _is_distillation else 0.
    ).astype(np.float32)

    return pdb_feats
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def make_pdb_features(
    protein_object: protein.Protein,
    description: str,
    is_distillation: bool = True,
    confidence_threshold: float = 50.,
) -> FeatureDict:
    """Featurize a PDB-derived protein, masking low-confidence atoms.

    When ``is_distillation`` is True, residues whose B-factor column
    (presumably pLDDT in distillation sets — confirm) never exceeds
    ``confidence_threshold`` are zeroed out of the atom mask.
    """
    # NOTE(review): _is_distillation is hard-coded to True, so the resulting
    # features are flagged as distillation data even when
    # is_distillation=False; the parameter only gates the confidence mask
    # below. Confirm this is intentional.
    pdb_feats = make_protein_features(
        protein_object, description, _is_distillation=True
    )

    if(is_distillation):
        # Keep a residue if any of its atoms clears the confidence threshold.
        high_confidence = protein_object.b_factors > confidence_threshold
        high_confidence = np.any(high_confidence, axis=-1)
        pdb_feats["all_atom_mask"] *= high_confidence[..., None]

    return pdb_feats
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
    """Constructs a feature dict of MSA features.

    Sequences duplicated across MSAs are kept once (first occurrence wins);
    each retained row is encoded with HHblits integer residue IDs, and its
    deletion-matrix row is carried along.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_index} must contain at least one sequence."
            )
        for sequence_index, sequence in enumerate(msa):
            if sequence in seen_sequences:
                continue
            seen_sequences.add(sequence)
            encoded = [
                residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence
            ]
            int_msa.append(encoded)
            deletion_matrix.append(
                deletion_matrices[msa_index][sequence_index]
            )

    # All rows share the query's length; take it from the first sequence.
    num_res = len(msas[0][0])
    num_alignments = len(int_msa)
    return {
        "deletion_matrix_int": np.array(deletion_matrix, dtype=np.int32),
        "msa": np.array(int_msa, dtype=np.int32),
        "num_alignments": np.array(
            [num_alignments] * num_res, dtype=np.int32
        ),
    }
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class AlignmentRunner:
|
| 260 |
+
"""Runs alignment tools and saves the results"""
|
| 261 |
+
def __init__(
|
| 262 |
+
self,
|
| 263 |
+
jackhmmer_binary_path: Optional[str] = None,
|
| 264 |
+
hhblits_binary_path: Optional[str] = None,
|
| 265 |
+
hhsearch_binary_path: Optional[str] = None,
|
| 266 |
+
uniref90_database_path: Optional[str] = None,
|
| 267 |
+
mgnify_database_path: Optional[str] = None,
|
| 268 |
+
bfd_database_path: Optional[str] = None,
|
| 269 |
+
uniclust30_database_path: Optional[str] = None,
|
| 270 |
+
pdb70_database_path: Optional[str] = None,
|
| 271 |
+
use_small_bfd: Optional[bool] = None,
|
| 272 |
+
no_cpus: Optional[int] = None,
|
| 273 |
+
uniref_max_hits: int = 10000,
|
| 274 |
+
mgnify_max_hits: int = 5000,
|
| 275 |
+
):
|
| 276 |
+
"""
|
| 277 |
+
Args:
|
| 278 |
+
jackhmmer_binary_path:
|
| 279 |
+
Path to jackhmmer binary
|
| 280 |
+
hhblits_binary_path:
|
| 281 |
+
Path to hhblits binary
|
| 282 |
+
hhsearch_binary_path:
|
| 283 |
+
Path to hhsearch binary
|
| 284 |
+
uniref90_database_path:
|
| 285 |
+
Path to uniref90 database. If provided, jackhmmer_binary_path
|
| 286 |
+
must also be provided
|
| 287 |
+
mgnify_database_path:
|
| 288 |
+
Path to mgnify database. If provided, jackhmmer_binary_path
|
| 289 |
+
must also be provided
|
| 290 |
+
bfd_database_path:
|
| 291 |
+
Path to BFD database. Depending on the value of use_small_bfd,
|
| 292 |
+
one of hhblits_binary_path or jackhmmer_binary_path must be
|
| 293 |
+
provided.
|
| 294 |
+
uniclust30_database_path:
|
| 295 |
+
Path to uniclust30. Searched alongside BFD if use_small_bfd is
|
| 296 |
+
false.
|
| 297 |
+
pdb70_database_path:
|
| 298 |
+
Path to pdb70 database.
|
| 299 |
+
use_small_bfd:
|
| 300 |
+
Whether to search the BFD database alone with jackhmmer or
|
| 301 |
+
in conjunction with uniclust30 with hhblits.
|
| 302 |
+
no_cpus:
|
| 303 |
+
The number of CPUs available for alignment. By default, all
|
| 304 |
+
CPUs are used.
|
| 305 |
+
uniref_max_hits:
|
| 306 |
+
Max number of uniref hits
|
| 307 |
+
mgnify_max_hits:
|
| 308 |
+
Max number of mgnify hits
|
| 309 |
+
"""
|
| 310 |
+
db_map = {
|
| 311 |
+
"jackhmmer": {
|
| 312 |
+
"binary": jackhmmer_binary_path,
|
| 313 |
+
"dbs": [
|
| 314 |
+
uniref90_database_path,
|
| 315 |
+
mgnify_database_path,
|
| 316 |
+
bfd_database_path if use_small_bfd else None,
|
| 317 |
+
],
|
| 318 |
+
},
|
| 319 |
+
"hhblits": {
|
| 320 |
+
"binary": hhblits_binary_path,
|
| 321 |
+
"dbs": [
|
| 322 |
+
bfd_database_path if not use_small_bfd else None,
|
| 323 |
+
],
|
| 324 |
+
},
|
| 325 |
+
"hhsearch": {
|
| 326 |
+
"binary": hhsearch_binary_path,
|
| 327 |
+
"dbs": [
|
| 328 |
+
pdb70_database_path,
|
| 329 |
+
],
|
| 330 |
+
},
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
for name, dic in db_map.items():
|
| 334 |
+
binary, dbs = dic["binary"], dic["dbs"]
|
| 335 |
+
if(binary is None and not all([x is None for x in dbs])):
|
| 336 |
+
raise ValueError(
|
| 337 |
+
f"{name} DBs provided but {name} binary is None"
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
|
| 341 |
+
and uniref90_database_path is None):
|
| 342 |
+
raise ValueError(
|
| 343 |
+
"""uniref90_database_path must be specified in order to perform
|
| 344 |
+
template search"""
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
self.uniref_max_hits = uniref_max_hits
|
| 348 |
+
self.mgnify_max_hits = mgnify_max_hits
|
| 349 |
+
self.use_small_bfd = use_small_bfd
|
| 350 |
+
|
| 351 |
+
if(no_cpus is None):
|
| 352 |
+
no_cpus = cpu_count()
|
| 353 |
+
|
| 354 |
+
self.jackhmmer_uniref90_runner = None
|
| 355 |
+
if(jackhmmer_binary_path is not None and
|
| 356 |
+
uniref90_database_path is not None
|
| 357 |
+
):
|
| 358 |
+
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
|
| 359 |
+
binary_path=jackhmmer_binary_path,
|
| 360 |
+
database_path=uniref90_database_path,
|
| 361 |
+
n_cpu=no_cpus,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self.jackhmmer_small_bfd_runner = None
|
| 365 |
+
self.hhblits_bfd_uniclust_runner = None
|
| 366 |
+
if(bfd_database_path is not None):
|
| 367 |
+
if use_small_bfd:
|
| 368 |
+
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
|
| 369 |
+
binary_path=jackhmmer_binary_path,
|
| 370 |
+
database_path=bfd_database_path,
|
| 371 |
+
n_cpu=no_cpus,
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
dbs = [bfd_database_path]
|
| 375 |
+
if(uniclust30_database_path is not None):
|
| 376 |
+
dbs.append(uniclust30_database_path)
|
| 377 |
+
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
|
| 378 |
+
binary_path=hhblits_binary_path,
|
| 379 |
+
databases=dbs,
|
| 380 |
+
n_cpu=no_cpus,
|
| 381 |
+
)
|
| 382 |
+
|
| 383 |
+
self.jackhmmer_mgnify_runner = None
|
| 384 |
+
if(mgnify_database_path is not None):
|
| 385 |
+
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
|
| 386 |
+
binary_path=jackhmmer_binary_path,
|
| 387 |
+
database_path=mgnify_database_path,
|
| 388 |
+
n_cpu=no_cpus,
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
self.hhsearch_pdb70_runner = None
|
| 392 |
+
if(pdb70_database_path is not None):
|
| 393 |
+
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
|
| 394 |
+
binary_path=hhsearch_binary_path,
|
| 395 |
+
databases=[pdb70_database_path],
|
| 396 |
+
n_cpu=no_cpus,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
def run(
|
| 400 |
+
self,
|
| 401 |
+
fasta_path: str,
|
| 402 |
+
output_dir: str,
|
| 403 |
+
):
|
| 404 |
+
"""Runs alignment tools on a sequence"""
|
| 405 |
+
if(self.jackhmmer_uniref90_runner is not None):
|
| 406 |
+
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
|
| 407 |
+
fasta_path
|
| 408 |
+
)[0]
|
| 409 |
+
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
|
| 410 |
+
jackhmmer_uniref90_result["sto"],
|
| 411 |
+
max_sequences=self.uniref_max_hits
|
| 412 |
+
)
|
| 413 |
+
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
|
| 414 |
+
with open(uniref90_out_path, "w") as f:
|
| 415 |
+
f.write(uniref90_msa_as_a3m)
|
| 416 |
+
|
| 417 |
+
if(self.hhsearch_pdb70_runner is not None):
|
| 418 |
+
hhsearch_result = self.hhsearch_pdb70_runner.query(
|
| 419 |
+
uniref90_msa_as_a3m
|
| 420 |
+
)
|
| 421 |
+
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
|
| 422 |
+
with open(pdb70_out_path, "w") as f:
|
| 423 |
+
f.write(hhsearch_result)
|
| 424 |
+
|
| 425 |
+
if(self.jackhmmer_mgnify_runner is not None):
|
| 426 |
+
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
|
| 427 |
+
fasta_path
|
| 428 |
+
)[0]
|
| 429 |
+
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
|
| 430 |
+
jackhmmer_mgnify_result["sto"],
|
| 431 |
+
max_sequences=self.mgnify_max_hits
|
| 432 |
+
)
|
| 433 |
+
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
|
| 434 |
+
with open(mgnify_out_path, "w") as f:
|
| 435 |
+
f.write(mgnify_msa_as_a3m)
|
| 436 |
+
|
| 437 |
+
if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
|
| 438 |
+
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
|
| 439 |
+
fasta_path
|
| 440 |
+
)[0]
|
| 441 |
+
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
|
| 442 |
+
with open(bfd_out_path, "w") as f:
|
| 443 |
+
f.write(jackhmmer_small_bfd_result["sto"])
|
| 444 |
+
elif(self.hhblits_bfd_uniclust_runner is not None):
|
| 445 |
+
hhblits_bfd_uniclust_result = (
|
| 446 |
+
self.hhblits_bfd_uniclust_runner.query(fasta_path)
|
| 447 |
+
)
|
| 448 |
+
if output_dir is not None:
|
| 449 |
+
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
|
| 450 |
+
with open(bfd_out_path, "w") as f:
|
| 451 |
+
f.write(hhblits_bfd_uniclust_result["a3m"])
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class DataPipeline:
    """Assembles input features.

    Given precomputed alignments and an input structure/sequence, builds the
    merged feature dict (sequence + MSA + template features) consumed by the
    model. Template featurization is delegated to the provided featurizer.
    """
    def __init__(
        self,
        template_featurizer: Optional[templates.TemplateHitFeaturizer],
    ):
        # May be None, in which case make_template_features is expected to
        # fall back to empty/dummy template features downstream.
        self.template_featurizer = template_featurizer
|
| 461 |
+
|
| 462 |
+
def _parse_msa_data(
|
| 463 |
+
self,
|
| 464 |
+
alignment_dir: str,
|
| 465 |
+
alignment_index: Optional[Any] = None,
|
| 466 |
+
) -> Mapping[str, Any]:
|
| 467 |
+
msa_data = {}
|
| 468 |
+
if(alignment_index is not None):
|
| 469 |
+
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
|
| 470 |
+
|
| 471 |
+
def read_msa(start, size):
|
| 472 |
+
fp.seek(start)
|
| 473 |
+
msa = fp.read(size).decode("utf-8")
|
| 474 |
+
return msa
|
| 475 |
+
|
| 476 |
+
for (name, start, size) in alignment_index["files"]:
|
| 477 |
+
ext = os.path.splitext(name)[-1]
|
| 478 |
+
|
| 479 |
+
if(ext == ".a3m"):
|
| 480 |
+
msa, deletion_matrix = parsers.parse_a3m(
|
| 481 |
+
read_msa(start, size)
|
| 482 |
+
)
|
| 483 |
+
data = {"msa": msa, "deletion_matrix": deletion_matrix}
|
| 484 |
+
elif(ext == ".sto"):
|
| 485 |
+
msa, deletion_matrix, _ = parsers.parse_stockholm(
|
| 486 |
+
read_msa(start, size)
|
| 487 |
+
)
|
| 488 |
+
data = {"msa": msa, "deletion_matrix": deletion_matrix}
|
| 489 |
+
else:
|
| 490 |
+
continue
|
| 491 |
+
|
| 492 |
+
msa_data[name] = data
|
| 493 |
+
|
| 494 |
+
fp.close()
|
| 495 |
+
else:
|
| 496 |
+
for f in os.listdir(alignment_dir):
|
| 497 |
+
path = os.path.join(alignment_dir, f)
|
| 498 |
+
ext = os.path.splitext(f)[-1]
|
| 499 |
+
|
| 500 |
+
if(ext == ".a3m"):
|
| 501 |
+
with open(path, "r") as fp:
|
| 502 |
+
msa, deletion_matrix = parsers.parse_a3m(fp.read())
|
| 503 |
+
data = {"msa": msa, "deletion_matrix": deletion_matrix}
|
| 504 |
+
elif(ext == ".sto"):
|
| 505 |
+
with open(path, "r") as fp:
|
| 506 |
+
msa, deletion_matrix, _ = parsers.parse_stockholm(
|
| 507 |
+
fp.read()
|
| 508 |
+
)
|
| 509 |
+
data = {"msa": msa, "deletion_matrix": deletion_matrix}
|
| 510 |
+
else:
|
| 511 |
+
continue
|
| 512 |
+
|
| 513 |
+
msa_data[f] = data
|
| 514 |
+
|
| 515 |
+
return msa_data
|
| 516 |
+
|
| 517 |
+
def _parse_template_hits(
|
| 518 |
+
self,
|
| 519 |
+
alignment_dir: str,
|
| 520 |
+
alignment_index: Optional[Any] = None
|
| 521 |
+
) -> Mapping[str, Any]:
|
| 522 |
+
all_hits = {}
|
| 523 |
+
if(alignment_index is not None):
|
| 524 |
+
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
|
| 525 |
+
|
| 526 |
+
def read_template(start, size):
|
| 527 |
+
fp.seek(start)
|
| 528 |
+
return fp.read(size).decode("utf-8")
|
| 529 |
+
|
| 530 |
+
for (name, start, size) in alignment_index["files"]:
|
| 531 |
+
ext = os.path.splitext(name)[-1]
|
| 532 |
+
|
| 533 |
+
if(ext == ".hhr"):
|
| 534 |
+
hits = parsers.parse_hhr(read_template(start, size))
|
| 535 |
+
all_hits[name] = hits
|
| 536 |
+
|
| 537 |
+
fp.close()
|
| 538 |
+
else:
|
| 539 |
+
for f in os.listdir(alignment_dir):
|
| 540 |
+
path = os.path.join(alignment_dir, f)
|
| 541 |
+
ext = os.path.splitext(f)[-1]
|
| 542 |
+
|
| 543 |
+
if(ext == ".hhr"):
|
| 544 |
+
with open(path, "r") as fp:
|
| 545 |
+
hits = parsers.parse_hhr(fp.read())
|
| 546 |
+
all_hits[f] = hits
|
| 547 |
+
|
| 548 |
+
return all_hits
|
| 549 |
+
|
| 550 |
+
def _get_msas(self,
|
| 551 |
+
alignment_dir: str,
|
| 552 |
+
input_sequence: Optional[str] = None,
|
| 553 |
+
alignment_index: Optional[str] = None,
|
| 554 |
+
):
|
| 555 |
+
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
|
| 556 |
+
if(len(msa_data) == 0):
|
| 557 |
+
if(input_sequence is None):
|
| 558 |
+
raise ValueError(
|
| 559 |
+
"""
|
| 560 |
+
If the alignment dir contains no MSAs, an input sequence
|
| 561 |
+
must be provided.
|
| 562 |
+
"""
|
| 563 |
+
)
|
| 564 |
+
msa_data["dummy"] = {
|
| 565 |
+
"msa": [input_sequence],
|
| 566 |
+
"deletion_matrix": [[0 for _ in input_sequence]],
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
msas, deletion_matrices = zip(*[
|
| 570 |
+
(v["msa"], v["deletion_matrix"]) for v in msa_data.values()
|
| 571 |
+
])
|
| 572 |
+
|
| 573 |
+
return msas, deletion_matrices
|
| 574 |
+
|
| 575 |
+
def _process_msa_feats(
|
| 576 |
+
self,
|
| 577 |
+
alignment_dir: str,
|
| 578 |
+
input_sequence: Optional[str] = None,
|
| 579 |
+
alignment_index: Optional[str] = None
|
| 580 |
+
) -> Mapping[str, Any]:
|
| 581 |
+
msas, deletion_matrices = self._get_msas(
|
| 582 |
+
alignment_dir, input_sequence, alignment_index
|
| 583 |
+
)
|
| 584 |
+
msa_features = make_msa_features(
|
| 585 |
+
msas=msas,
|
| 586 |
+
deletion_matrices=deletion_matrices,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
return msa_features
|
| 590 |
+
|
| 591 |
+
    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file.

        Args:
            fasta_path: Path to a FASTA file containing exactly one sequence.
            alignment_dir: Directory with precomputed alignments/template hits.
            alignment_index: Optional index record for a packed alignment DB.

        Returns:
            Merged dict of sequence, MSA and template features.

        Raises:
            ValueError: If the FASTA file holds more than one sequence.
        """
        with open(fasta_path) as f:
            fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f"More than one input sequence found in {fasta_path}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        # Template features come from previously computed hhsearch hits.
        hits = self._parse_template_hits(alignment_dir, alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

        return {
            **sequence_features,
            **msa_features,
            **template_features
        }
|
| 629 |
+
|
| 630 |
+
    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """
        Assembles features for a specific chain in an mmCIF object.

        If chain_id is None, it is assumed that there is only one chain
        in the object. Otherwise, a ValueError is thrown.

        NOTE(review): the code below does not actually raise when multiple
        chains exist and chain_id is None — it featurizes the first chain;
        only an *empty* structure raises. Confirm the intended contract.
        """
        if chain_id is None:
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        hits = self._parse_template_hits(alignment_dir, alignment_index)
        # query_release_date lets the featurizer exclude templates released
        # after the query structure itself (avoids temporal leakage).
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
            query_release_date=to_date(mmcif.header["release_date"])
        )

        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

        return {**mmcif_feats, **template_features, **msa_features}
|
| 664 |
+
|
| 665 |
+
def process_pdb(
|
| 666 |
+
self,
|
| 667 |
+
pdb_path: str,
|
| 668 |
+
alignment_dir: str,
|
| 669 |
+
is_distillation: bool = True,
|
| 670 |
+
chain_id: Optional[str] = None,
|
| 671 |
+
_structure_index: Optional[str] = None,
|
| 672 |
+
alignment_index: Optional[str] = None,
|
| 673 |
+
) -> FeatureDict:
|
| 674 |
+
"""
|
| 675 |
+
Assembles features for a protein in a PDB file.
|
| 676 |
+
"""
|
| 677 |
+
if(_structure_index is not None):
|
| 678 |
+
db_dir = os.path.dirname(pdb_path)
|
| 679 |
+
db = _structure_index["db"]
|
| 680 |
+
db_path = os.path.join(db_dir, db)
|
| 681 |
+
fp = open(db_path, "rb")
|
| 682 |
+
_, offset, length = _structure_index["files"][0]
|
| 683 |
+
fp.seek(offset)
|
| 684 |
+
pdb_str = fp.read(length).decode("utf-8")
|
| 685 |
+
fp.close()
|
| 686 |
+
else:
|
| 687 |
+
with open(pdb_path, 'r') as f:
|
| 688 |
+
pdb_str = f.read()
|
| 689 |
+
|
| 690 |
+
protein_object = protein.from_pdb_string(pdb_str, chain_id)
|
| 691 |
+
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
|
| 692 |
+
description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
|
| 693 |
+
pdb_feats = make_pdb_features(
|
| 694 |
+
protein_object,
|
| 695 |
+
description,
|
| 696 |
+
is_distillation=is_distillation
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
hits = self._parse_template_hits(alignment_dir, alignment_index)
|
| 700 |
+
template_features = make_template_features(
|
| 701 |
+
input_sequence,
|
| 702 |
+
hits,
|
| 703 |
+
self.template_featurizer,
|
| 704 |
+
)
|
| 705 |
+
|
| 706 |
+
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
|
| 707 |
+
|
| 708 |
+
return {**pdb_feats, **template_features, **msa_features}
|
| 709 |
+
|
| 710 |
+
def process_core(
|
| 711 |
+
self,
|
| 712 |
+
core_path: str,
|
| 713 |
+
alignment_dir: str,
|
| 714 |
+
alignment_index: Optional[str] = None,
|
| 715 |
+
) -> FeatureDict:
|
| 716 |
+
"""
|
| 717 |
+
Assembles features for a protein in a ProteinNet .core file.
|
| 718 |
+
"""
|
| 719 |
+
with open(core_path, 'r') as f:
|
| 720 |
+
core_str = f.read()
|
| 721 |
+
|
| 722 |
+
protein_object = protein.from_proteinnet_string(core_str)
|
| 723 |
+
input_sequence = _aatype_to_str_sequence(protein_object.aatype)
|
| 724 |
+
description = os.path.splitext(os.path.basename(core_path))[0].upper()
|
| 725 |
+
core_feats = make_protein_features(protein_object, description)
|
| 726 |
+
|
| 727 |
+
hits = self._parse_template_hits(alignment_dir, alignment_index)
|
| 728 |
+
template_features = make_template_features(
|
| 729 |
+
input_sequence,
|
| 730 |
+
hits,
|
| 731 |
+
self.template_featurizer,
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
|
| 735 |
+
|
| 736 |
+
return {**core_feats, **template_features, **msa_features}
|
| 737 |
+
|
| 738 |
+
    def process_multiseq_fasta(self,
        fasta_path: str,
        super_alignment_dir: str,
        ri_gap: int = 200,
    ) -> FeatureDict:
        """
        Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
        hack from Twitter (a.k.a. AlphaFold-Gap).

        All chains are concatenated into one pseudo-sequence; a residue_index
        jump of ri_gap at each chain boundary tells the model the chains are
        not covalently linked. Per-chain alignments are expected under
        super_alignment_dir/<description>/.
        """
        with open(fasta_path, 'r') as f:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)

        # No whitespace allowed
        input_descs = [i.split()[0] for i in input_descs]

        # Stitch all of the sequences together
        input_sequence = ''.join(input_seqs)
        input_description = '-'.join(input_descs)
        num_res = len(input_sequence)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        # Insert the ri_gap offset after every chain; offsets accumulate, so
        # chain k is shifted by k * ri_gap relative to naive numbering.
        seq_lens = [len(s) for s in input_seqs]
        total_offset = 0
        for sl in seq_lens:
            total_offset += sl
            sequence_features["residue_index"][total_offset:] += ri_gap

        # Gather each chain's MSAs from its own alignment directory.
        msa_list = []
        deletion_mat_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            msas, deletion_mats = self._get_msas(
                alignment_dir, seq, None
            )
            msa_list.append(msas)
            deletion_mat_list.append(deletion_mats)

        # Pad each chain's MSA rows with gaps so every row spans the full
        # stitched sequence, then pool all chains' alignments together.
        final_msa = []
        final_deletion_mat = []
        msa_it = enumerate(zip(msa_list, deletion_mat_list))
        for i, (msas, deletion_mats) in msa_it:
            # Number of residues preceding / following chain i.
            prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
            msas = [
                [prec * '-' + seq + post * '-' for seq in msa] for msa in msas
            ]
            deletion_mats = [
                [prec * [0] + dml + post * [0] for dml in deletion_mat]
                for deletion_mat in deletion_mats
            ]

            # Sanity check: padded rows must match the stitched length.
            assert(len(msas[0][-1]) == len(input_sequence))

            final_msa.extend(msas)
            final_deletion_mat.extend(deletion_mats)

        msa_features = make_msa_features(
            msas=final_msa,
            deletion_matrices=final_deletion_mat,
        )

        # Template hits are searched per chain and then merged into one
        # template feature set covering the stitched sequence.
        template_feature_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            hits = self._parse_template_hits(alignment_dir, alignment_index=None)
            template_features = make_template_features(
                seq,
                hits,
                self.template_featurizer,
            )
            template_feature_list.append(template_features)

        template_features = unify_template_features(template_feature_list)

        return {
            **sequence_features,
            **msa_features,
            **template_features,
        }
|
openfold/data/data_transforms.py
ADDED
|
@@ -0,0 +1,1212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import itertools
|
| 17 |
+
from functools import reduce, wraps
|
| 18 |
+
from operator import add
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
|
| 24 |
+
from openfold.np import residue_constants as rc
|
| 25 |
+
from openfold.utils.rigid_utils import Rotation, Rigid
|
| 26 |
+
from openfold.utils.tensor_utils import (
|
| 27 |
+
tree_map,
|
| 28 |
+
tensor_tree_map,
|
| 29 |
+
batched_gather,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Feature keys that are indexed per MSA row. Sampling/cropping transforms
# below keep these in sync; rows sampled out get an "extra_"-prefixed copy.
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def cast_to_64bit_ints(protein):
    """Promote every int32 feature tensor in-place to int64.

    Non-int32 tensors are left untouched. Returns the same dict.
    """
    for key, value in protein.items():
        protein[key] = (
            value.type(torch.int64) if value.dtype == torch.int32 else value
        )
    return protein
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def make_one_hot(x, num_classes):
    """Return a float one-hot encoding of `x` with a new trailing class dim."""
    shape = tuple(x.shape) + (num_classes,)
    encoding = torch.zeros(shape, device=x.device)
    # scatter_ writes 1.0 at the index given by x along the last axis.
    encoding.scatter_(-1, x.unsqueeze(-1), 1)
    return encoding
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def make_seq_mask(protein):
    """Attach an all-ones float mask over every residue position."""
    mask = torch.ones(protein["aatype"].shape, dtype=torch.float32)
    protein["seq_mask"] = mask
    return protein
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def make_template_mask(protein):
    """Mark every template as valid with an all-ones float mask."""
    num_templates = protein["template_aatype"].shape[0]
    protein["template_mask"] = torch.ones(num_templates, dtype=torch.float32)
    return protein
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def curry1(f):
    """Supply all arguments but the first.

    curry1(f)(a, b)(x) == f(x, a, b) — lets transforms be configured once
    and applied later to the feature dict.
    """
    @wraps(f)
    def configure(*args, **kwargs):
        def apply(first):
            return f(first, *args, **kwargs)
        return apply

    return configure
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def make_all_atom_aatype(protein):
    """Expose the residue types under the key used by the all-atom code path."""
    protein["all_atom_aatype"] = protein["aatype"]
    return protein
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def fix_templates_aatype(protein):
    """Convert template aatypes from one-hot hhblits ordering to our indices.

    No-op when there are zero templates.
    """
    # Map one-hot to indices
    num_templates = protein["template_aatype"].shape[0]
    if(num_templates > 0):
        protein["template_aatype"] = torch.argmax(
            protein["template_aatype"], dim=-1
        )
        # Map hhsearch-aatype to our aatype.
        new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
        # Broadcast the remapping table across templates and gather through
        # it: result[t, i] = new_order_list[template_aatype[t, i]].
        new_order = torch.tensor(
            new_order_list, dtype=torch.int64, device=protein["aatype"].device,
        ).expand(num_templates, -1)
        protein["template_aatype"] = torch.gather(
            new_order, 1, index=protein["template_aatype"]
        )

    return protein
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Remaps every MSA token through rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE and
    applies the corresponding permutation to any "*profile*" features.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    # One remapping column per residue position, so gather along dim 0
    # performs msa[s, i] -> new_order_list[msa[s, i]].
    new_order = torch.tensor(
        [new_order_list] * protein["msa"].shape[1],
        device=protein["msa"].device,
    ).transpose(0, 1)
    protein["msa"] = torch.gather(new_order, 0, protein["msa"])

    # Permutation matrix from hhblits amino-acid ordering to ours.
    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            # BUGFIX: the original used `protein[k].shape.as_list()[-1]`
            # (a TensorFlow API that raises AttributeError on torch tensors)
            # and `torch.dot`, which only accepts 1-D tensors and was fed a
            # NumPy matrix. Use plain shape indexing and a matmul instead.
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            protein[k] = torch.matmul(
                protein[k],
                torch.from_numpy(perm_matrix[:num_dim, :num_dim]),
            )

    return protein
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    Collapses the one-hot aatype to integer indices, squeezes a trailing
    dim of size 1 from the listed features, and keeps only the first copy
    of per-row-replicated scalars.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable_keys = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for key in squeezable_keys:
        if key not in protein:
            continue
        value = protein[key]
        last_dim = value.shape[-1]
        if isinstance(last_dim, int) and last_dim == 1:
            if torch.is_tensor(value):
                protein[key] = torch.squeeze(value, dim=-1)
            else:
                protein[key] = np.squeeze(value, axis=-1)

    # These are replicated once per alignment row; keep a single copy.
    for key in ("seq_length", "num_alignments"):
        if key in protein:
            protein[key] = protein[key][0]

    return protein
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'.

    Each MSA cell and each query residue is independently replaced with the
    unknown token with probability replace_proportion; gap positions in the
    MSA are never replaced.
    """
    msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
    x_idx = 20   # token index of the unknown residue 'X'
    gap_idx = 21  # token index of the gap '-'
    # Never overwrite gaps with 'X'.
    msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
    protein["msa"] = torch.where(
        msa_mask,
        torch.ones_like(protein["msa"]) * x_idx,
        protein["msa"]
    )
    aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion

    protein["aatype"] = torch.where(
        aatype_mask,
        torch.ones_like(protein["aatype"]) * x_idx,
        protein["aatype"],
    )
    return protein
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample the MSA randomly; remaining sequences are stored as `extra_*`.

    Row 0 (the query) is always kept. The other rows are shuffled and the
    first max_seq rows (query included) are retained across all
    MSA_FEATURE_NAMES features; with keep_extra, the unselected rows are
    preserved under "extra_<feature>" keys.
    """
    num_seq = protein["msa"].shape[0]
    g = torch.Generator(device=protein["msa"].device)
    if seed is not None:
        g.manual_seed(seed)
    # Shuffle rows 1..N-1, keeping the query (row 0) first.
    shuffled = torch.randperm(num_seq - 1, generator=g) + 1
    index_order = torch.cat(
        (torch.tensor([0], device=shuffled.device), shuffled),
        dim=0
    )
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@curry1
def add_distillation_flag(protein, distillation):
    # Tag the example so later transforms can special-case
    # self-distillation data (e.g. sample_msa_distillation).
    protein['is_distillation'] = distillation
    return protein
|
| 215 |
+
|
| 216 |
+
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA to `max_seq` rows, but only for distillation examples."""
    if protein["is_distillation"] != 1:
        return protein
    # Distillation MSAs are subsampled without keeping extra_* features.
    return sample_msa(max_seq, keep_extra=False)(protein)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keep at most `max_extra_msa` rows of every `extra_*` MSA feature."""
    total = protein["extra_msa"].shape[0]
    keep = torch.randperm(total)[: min(max_extra_msa, total)]
    for name in MSA_FEATURE_NAMES:
        extra_name = "extra_" + name
        if extra_name in protein:
            protein[extra_name] = torch.index_select(
                protein[extra_name], 0, keep
            )

    return protein
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def delete_extra_msa(protein):
    """Remove all `extra_*` MSA features from the example."""
    for name in MSA_FEATURE_NAMES:
        extra_name = "extra_" + name
        if extra_name in protein:
            protein.pop(extra_name)
    return protein
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Randomly delete contiguous blocks of MSA rows (training augmentation).

    Draws `config.num_blocks` block start positions (or a random count in
    [0, num_blocks] when `config.randomize_num_blocks`), each block spanning
    `msa_fraction_per_block * num_seq` consecutive rows, and removes the
    union of those rows from every MSA feature. Row 0 (the query sequence)
    is always kept.

    Fixes vs. the previous version: `torch.range` misuse, `Uniform.sample`
    called with an int instead of a shape, `torch.unique` applied to the
    tuple returned by `torch.sort`, float indices passed to gather, and
    the query row not being guaranteed to survive deletion.
    """
    device = protein["msa"].device
    num_seq = protein["msa"].shape[0]
    # Nothing to delete from a single-row MSA.
    if num_seq <= 1:
        return protein

    block_num_seq = int(num_seq * config.msa_fraction_per_block)

    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)).item())
    else:
        nb = config.num_blocks

    # Block starts are uniform over all rows; blocks are clipped at the end.
    del_block_starts = torch.randint(0, num_seq, (nb,), device=device)
    del_blocks = del_block_starts[:, None] + torch.arange(
        block_num_seq, device=device
    )
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
    del_indices = torch.unique(del_blocks.reshape(-1))

    # Boolean keep-mask; make sure we keep the original (query) sequence.
    keep_mask = torch.ones(num_seq, dtype=torch.bool, device=device)
    keep_mask[del_indices] = False
    keep_mask[0] = True
    keep_indices = keep_mask.nonzero(as_tuple=True)[0]

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra MSA row to the nearest sampled MSA row (cluster center).

    Similarity is a weighted inner product over one-hot encodings: the 21
    amino-acid classes weigh 1, the gap class weighs `gap_agreement_weight`,
    and the mask class weighs 0. Writes `extra_cluster_assignment`, an
    int64 tensor of shape [N_extra] indexing into the sampled MSA.
    """
    # Per-class weights: 21 residue types, gap, mask.
    weights = torch.cat(
        [
            torch.ones(21, device=protein["msa"].device),
            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
            torch.zeros(1, device=protein["msa"].device)
        ],
        0,
    )

    # Make agreement score as weighted Hamming distance
    msa_one_hot = make_one_hot(protein["msa"], 23)
    sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
    extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
    extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
    # in an optimized fashion to avoid possible memory or computation blowup.
    agreement = torch.matmul(
        torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
        torch.reshape(
            sample_one_hot * weights, [num_seq, num_res * 23]
        ).transpose(0, 1),
    )

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
        torch.int64
    )

    return protein
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def unsorted_segment_sum(data, segment_ids, num_segments):
    """Sum rows of `data` into `num_segments` buckets selected by `segment_ids`.

    Torch analogue of tf.unsorted_segment_sum, restricted to 1-D indices.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor (one per data row).
    :param num_segments: The number of output segments.
    :return: A [num_segments, ...] tensor of the same dtype as `data`.
    """
    assert segment_ids.dim() == 1 and segment_ids.shape[0] == data.shape[0]

    # Broadcast the 1-D ids over every trailing dimension of `data` so
    # scatter_add_ can consume them elementwise.
    trailing = data.shape[1:]
    index = segment_ids.view(-1, *((1,) * len(trailing))).expand(data.shape)

    out = torch.zeros(num_segments, *trailing, device=segment_ids.device)
    out.scatter_add_(0, index, data.float())
    return out.type(data.dtype)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster."""
    num_seq = protein["msa"].shape[0]

    def csum(x):
        # Sum extra-MSA rows into the sampled row (cluster) they were
        # assigned to by nearest_neighbor_clusters.
        return unsorted_segment_sum(
            x, protein["extra_cluster_assignment"], num_seq
        )

    mask = protein["extra_msa_mask"]
    mask_counts = 1e-6 + protein["msa_mask"] + csum(mask)  # Include center

    # Per-cluster residue-type distribution, normalized by member count.
    msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
    msa_sum += make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
    del msa_sum

    # Mean deletion count per cluster (center row included).
    del_sum = csum(mask * protein["extra_deletion_matrix"])
    del_sum += protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = del_sum / mask_counts
    del del_sum

    return protein
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def make_msa_mask(protein):
    """Attach all-ones MSA masks; zero-padding is applied later in the pipeline."""
    msa_shape = protein["msa"].shape
    protein["msa_mask"] = torch.ones(msa_shape, dtype=torch.float32)
    # One flag per MSA row.
    protein["msa_row_mask"] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Uses the CB atom position, falling back to CA for glycine (which has
    no CB). If `all_atom_mask` is provided, the matching per-residue mask
    is returned as well.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    # Broadcast the glycine flag over the trailing xyz dimension.
    pseudo_beta = torch.where(
        torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is not None:
        pseudo_beta_mask = torch.where(
            is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (CB, or CA for glycine) position and mask features."""
    assert prefix in ["", "template_"]
    aatype_key = "template_aatype" if prefix else "aatype"
    mask_key = "template_all_atom_mask" if prefix else "all_atom_mask"
    beta, beta_mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = beta
    protein[prefix + "pseudo_beta_mask"] = beta_mask
    return protein
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
@curry1
def add_constant_field(protein, key, value):
    # Store `value` as a tensor on the same device as the MSA features.
    protein[key] = torch.tensor(value, device=protein["msa"].device)
    return protein
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def shaped_categorical(probs, epsilon=1e-10):
    """Sample class indices from `probs` of shape [..., num_classes].

    `epsilon` keeps every class strictly positive so the categorical
    distribution is well-defined even with exact zeros in `probs`.
    Returns an integer tensor of shape `probs.shape[:-1]`.
    """
    *batch_shape, num_classes = probs.shape
    flat_probs = torch.reshape(probs + epsilon, [-1, num_classes])
    samples = torch.distributions.categorical.Categorical(flat_probs).sample()
    return samples.reshape(batch_shape)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def make_hhblits_profile(protein):
    """Compute the per-position HHblits MSA profile if not already present."""
    if "hhblits_profile" not in protein:
        # Average the one-hot MSA over all sequences: a per-residue
        # distribution over the 22 HHblits classes.
        one_hot_msa = make_one_hot(protein["msa"], 22)
        protein["hhblits_profile"] = torch.mean(one_hot_msa, dim=0)
    return protein
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Masks a `replace_fraction` of MSA positions; each masked position is
    replaced by a sample from a mixture of (a) a uniform amino acid,
    (b) the HHblits profile, (c) the original residue, with the leftover
    probability mass going to a dedicated [MASK] class. Writes
    `bert_mask`, `true_msa`, and overwrites `msa`.
    """
    # Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0],
        dtype=torch.float32,
        device=protein["aatype"].device
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(protein["msa"], 22)
    )

    # Put all remaining probability on [MASK] which is a new column
    pad_shapes = list(
        reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
    )
    # F.pad consumes dims last-first; element 1 pads the end of the last
    # (class) dimension by one, appending the [MASK] column.
    pad_shapes[1] = 1
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0

    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_shapes, value=mask_prob
    )

    sh = protein["msa"].shape
    mask_position = torch.rand(sh) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein["msa"])

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = protein["msa"]
    protein["msa"] = bert_msa

    return protein
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size.

    Zero-pads every feature so each symbolic dimension in `shape_schema`
    (NUM_RES, NUM_MSA_SEQ, NUM_EXTRA_SEQ, NUM_TEMPLATES) reaches its
    target size; non-symbolic dimensions keep their current size.
    """
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == "extra_cluster_assignment":
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
        # Symbolic dims map through pad_size_map; others keep their size.
        pad_size = [
            pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
        ]

        # F.pad expects (before, after) pairs for the LAST dim first,
        # hence the reverse before flattening.
        padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
        padding.reverse()
        padding = list(itertools.chain(*padding))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)

    return protein
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    Builds `target_feat` (domain-break flag + one-hot aatype) and
    `msa_feat` (one-hot MSA, deletion indicator, squashed deletion value,
    plus cluster profile/deletion-mean when present). Deletion counts are
    squashed to (0, 1) via atan(d/3) * 2/pi.
    """
    # Whether there is a domain break. Always zero for chains, but keeping for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    aatype_1hot = make_one_hot(protein["aatype"], 21)

    target_feat = [
        torch.unsqueeze(has_break, dim=-1),
        aatype_1hot,  # Everyone gets the original sequence.
    ]

    msa_1hot = make_one_hot(protein["msa"], 23)
    has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
    deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
        2.0 / np.pi
    )

    msa_feat = [
        msa_1hot,
        torch.unsqueeze(has_deletion, dim=-1),
        torch.unsqueeze(deletion_value, dim=-1),
    ]

    if "cluster_profile" in protein:
        deletion_mean_value = torch.atan(
            protein["cluster_deletion_mean"] / 3.0
        ) * (2.0 / np.pi)
        msa_feat.extend(
            [
                protein["cluster_profile"],
                torch.unsqueeze(deletion_mean_value, dim=-1),
            ]
        )

    if "extra_deletion_matrix" in protein:
        protein["extra_has_deletion"] = torch.clip(
            protein["extra_deletion_matrix"], 0.0, 1.0
        )
        protein["extra_deletion_value"] = torch.atan(
            protein["extra_deletion_matrix"] / 3.0
        ) * (2.0 / np.pi)

    protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
    protein["target_feat"] = torch.cat(target_feat, dim=-1)
    return protein
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the features named in `feature_list`."""
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
@curry1
def crop_templates(protein, max_templates):
    """Truncate every `template_*` feature to at most `max_templates` templates."""
    for key in protein:
        if key.startswith("template_"):
            protein[key] = protein[key][:max_templates]
    return protein
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-restype index maps between the compact atom14 layout and
    the full atom37 layout, then gathers them by this protein's `aatype`.
    Writes: `atom14_atom_exists`, `residx_atom14_to_atom37`,
    `residx_atom37_to_atom14`, `atom37_atom_exists`.
    """
    restype_atom14_to_atom37 = []  # per restype: atom14 slot -> atom37 index
    restype_atom37_to_atom14 = []  # per restype: atom37 index -> atom14 slot
    restype_atom14_mask = []  # per restype: which of the 14 slots exist

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
        # Empty atom names (unused slots) map to index 0.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in rc.atom_types
            ]
        )

        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask,
        dtype=torch.float32,
        device=protein["aatype"].device,
    )
    protein_aatype = protein['aatype'].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    # create the corresponding mask
    restype_atom37_mask = torch.zeros(
        [21, 37], dtype=torch.float32, device=protein["aatype"].device
    )
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def make_atom14_masks_np(batch):
    """NumPy wrapper for make_atom14_masks: arrays in, arrays out.

    NOTE(review): `batch["aatype"].device` is read before the ndarray
    conversion, so "aatype" is presumably already a torch tensor here —
    confirm against callers.
    """
    batch = tree_map(
        lambda n: torch.tensor(n, device=batch["aatype"].device),
        batch,
        np.ndarray
    )
    out = make_atom14_masks(batch)
    out = tensor_tree_map(lambda t: np.array(t), out)
    return out
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Gathers ground-truth positions/masks into the atom14 layout and also
    produces the "alternative" ground truth where chemically symmetric
    atom pairs are name-swapped. Requires the fields written by
    make_atom14_masks. Writes: `atom14_gt_exists`, `atom14_gt_positions`,
    `atom14_alt_gt_positions`, `atom14_alt_gt_exists`,
    `atom14_atom_is_ambiguous`.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms. Default: identity (no swap).
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Turn the index correspondence into a 14x14 permutation matrix.
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein
|
| 774 |
+
|
| 775 |
+
|
| 776 |
+
def atom37_to_frames(protein, eps=1e-8):
    """Compute the 8 rigid-group frames per residue from atom37 coordinates.

    Each residue gets a backbone frame, a psi frame, and up to four chi
    frames, each defined by three base atoms. Writes the
    `rigidgroups_*` features (frames as 4x4 tensors, existence masks,
    ambiguity mask, and alternative frames for symmetric side chains).
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Base atom names per (restype, rigid group): group 0 is the backbone,
    # group 3 is psi; groups 1-2 (pre-omega/phi) stay empty here.
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Map atom names to atom37 indices; empty names map to 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    # A frame exists only if all three base atoms are resolved.
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Flip the x and z axes of the backbone frame (group 0).
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    # Residues with name-swappable atoms get a 180-degree rotation on
    # their last chi frame to generate the alternative ground truth.
    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
def get_chi_atom_indices():
    """Return atom indices needed to compute chi angles for all residue types.

    Returns a nested list of shape [residue_types=21, chis=4, atoms=4] in
    the order of rc.restypes, with the unknown residue type last. Chi
    angles not defined for a residue (and all of UNK's) use index 0 for
    every atom slot.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        resname = rc.restype_1to3[one_letter]
        # One [4]-list of atom37 indices per defined chi angle.
        indices = [
            [rc.atom_order[atom] for atom in chi]
            for chi in rc.chi_angles_atoms[resname]
        ]
        # Pad undefined chi angles with zeros.
        indices += [[0, 0, 0, 0]] * (4 - len(indices))
        chi_atom_indices.append(indices)

    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return chi_atom_indices
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Clamp out-of-range residue types to index 20 (the UNKNOWN slot at the
    # end of rc.restypes, per get_chi_atom_indices).
    aatype = torch.clamp(aatype, max=20)

    # Shift positions by one residue (zero-padded at the front) so each
    # residue can see its predecessor's atoms.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # Four-atom tuples defining each backbone torsion. The atom37 slices
    # (1:3, :2, etc.) select fixed backbone atoms — presumably N/CA/C/O in
    # atom37 order; confirm against rc.atom_types.
    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    # A torsion is valid only if all four contributing atoms are present.
    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1
    ) * torch.prod(all_atom_mask[..., :2], dim=-1)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    # [21, 4, 4] lookup table of chi-angle atom indices per residue type.
    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    # Gather the chi atom positions for each residue: [*, N_res, 4, 4, 3].
    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Per-residue-type chi validity mask, with an all-zero UNKNOWN row.
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    # Also require all four atoms of each chi to be resolved.
    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Stack the 7 torsions (pre-omega, phi, psi, chi1..chi4): [*, N_res, 7, 4, 3].
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Build a frame from the 2nd atom (origin), 3rd atom, and 1st atom.
    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    # The 4th atom expressed in that frame determines the dihedral angle.
    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    # (sin, cos) from the z and y components of the relative position.
    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalize to unit length; eps guards against division by zero for
    # degenerate (masked) torsions.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Sign flip on the 3rd torsion (psi) only — matches the [1,1,-1,1,1,1,1]
    # convention; the broadcast indexing inserts leading None axes plus a
    # trailing one for the sin/cos pair.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    # Chis that are pi-periodic (180-degree symmetric side chains) get a
    # mirrored alternative below.
    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # Backbone torsions are never mirrored (factor 1); ambiguous chis get -1.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
def get_backbone_frames(protein):
    """Extract the backbone rigid group (group index 0) from the full frame set."""
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    gt_frames = protein["rigidgroups_gt_frames"]
    gt_exists = protein["rigidgroups_gt_exists"]
    protein["backbone_rigid_tensor"] = gt_frames[..., 0, :, :]
    protein["backbone_rigid_mask"] = gt_exists[..., 0]
    return protein
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
def get_chi_angles(protein):
    """Slice the 4 side-chain (chi) torsions out of the 7 per-residue torsions."""
    # Cast to the atom mask's dtype so downstream features agree.
    target_dtype = protein["all_atom_mask"].dtype
    sin_cos = protein["torsion_angles_sin_cos"]
    mask = protein["torsion_angles_mask"]
    # Torsions 0..2 are pre-omega/phi/psi; the chis occupy indices 3..6.
    protein["chi_angles_sin_cos"] = sin_cos[..., 3:, :].to(target_dtype)
    protein["chi_mask"] = mask[..., 3:].to(target_dtype)
    return protein
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: feature dict of tensors.
        crop_size: maximum number of residues to keep.
        max_templates: maximum number of templates to keep.
        shape_schema: maps feature name -> symbolic per-dim sizes; dims equal
            to NUM_RES are cropped along the residue axis.
        subsample_templates: if True, randomly permute and sub-select templates.
        seed: optional seed for the crop generator.

    Returns:
        The same dict with cropped tensors and an updated "seq_length".
    """
    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=protein["seq_length"].device)
    if seed is not None:
        g.manual_seed(seed)

    seq_length = protein["seq_length"]

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(int(seq_length), crop_size)

    def _randint(lower, upper):
        # Inclusive on both ends, drawn from the shared generator `g`.
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=protein["seq_length"].device,
            generator=g,
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=protein["seq_length"].device, generator=g
        )
    else:
        templates_crop_start = 0

    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            # Use a distinct local name so the `crop_size` parameter is not
            # clobbered while iterating over features.
            if i == 0 and k.startswith("template"):
                dim_crop_size = num_templates_crop_size
                crop_start = templates_crop_start
            else:
                crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(crop_start, crop_start + dim_crop_size))
        # Index with a tuple: indexing a tensor with a *list* of slices is
        # deprecated/unsupported advanced indexing in newer PyTorch/NumPy.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein
|
openfold/data/errors.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""General-purpose errors used throughout the data pipeline"""
|
| 17 |
+
class Error(Exception):
    """Base class for exceptions raised by the data pipeline."""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Raised when an operation that expects a single chain encounters several.
class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID."""
|
openfold/data/feature_pipeline.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
from typing import Mapping, Tuple, List, Optional, Dict, Sequence
|
| 18 |
+
|
| 19 |
+
import ml_collections
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from openfold.data import input_pipeline
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
FeatureDict = Mapping[str, np.ndarray]
|
| 27 |
+
TensorDict = Dict[str, torch.Tensor]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: A list of strings of feature names to be returned in the
            dataset.

    Returns:
        A dictionary of features mapping feature names to features. Only the
        given features are returned, all other ones are filtered out.
    """
    wanted = set(features)
    converted = {}
    for name, array in np_example.items():
        if name in wanted:
            # torch.tensor copies, so the result does not alias the input array.
            converted[name] = torch.tensor(array)
    return converted
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Resolve the data config for `mode` and collect the active feature names.

    Returns a deep copy of `config` (with an unset crop_size defaulted to the
    full sequence length) together with the list of feature names to load.
    """
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    # Default an unset crop size to the full sequence length.
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    if mode_cfg.supervised:
        feature_names += cfg.supervised.supervised_features

    return cfg, feature_names
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
):
    """Run a raw NumPy feature dict through the configured input pipeline."""
    np_example = dict(np_example)
    num_res = int(np_example["seq_length"][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    # Legacy key: convert the integer deletion matrix to float up front.
    if "deletion_matrix_int" in np_example:
        np_example["deletion_matrix"] = np_example.pop(
            "deletion_matrix_int"
        ).astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    # Feature preprocessing needs no gradients.
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
        )

    return dict(features)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class FeaturePipeline:
    """Thin wrapper binding a data config to the feature-processing entry point."""

    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        # Stored as-is; np_example_to_features deep-copies it per call via
        # make_data_config.
        self.config = config

    def process_features(
        self,
        raw_features: FeatureDict,
        mode: str = "train",
    ) -> FeatureDict:
        """Process `raw_features` under the stored config for the given mode."""
        return np_example_to_features(
            np_example=raw_features, config=self.config, mode=mode
        )
|
openfold/data/input_pipeline.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from openfold.data import data_transforms
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
        data_transforms.squeeze_features,
        data_transforms.randomly_replace_msa_with_unknown(0.0),
        data_transforms.make_seq_mask,
        data_transforms.make_msa_mask,
        data_transforms.make_hhblits_profile,
    ]
    if common_cfg.use_templates:
        transforms.append(data_transforms.fix_templates_aatype)
        transforms.append(data_transforms.make_template_mask)
        transforms.append(data_transforms.make_pseudo_beta("template_"))
        if common_cfg.use_template_torsion_angles:
            transforms.append(
                data_transforms.atom37_to_torsion_angles("template_")
            )

    transforms.append(data_transforms.make_atom14_masks)

    # Supervised targets (ground-truth frames / torsions) only when labels
    # are available for this mode.
    if mode_cfg.supervised:
        transforms += [
            data_transforms.make_atom14_positions,
            data_transforms.atom37_to_frames,
            data_transforms.atom37_to_torsion_angles(""),
            data_transforms.make_pseudo_beta(""),
            data_transforms.get_backbone_frames,
            data_transforms.get_chi_angles,
        ]

    return transforms
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    if "max_distillation_msa_clusters" in mode_cfg:
        transforms.append(
            data_transforms.sample_msa_distillation(
                mode_cfg.max_distillation_msa_clusters
            )
        )

    # Optionally reserve room for templates among the MSA clusters.
    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    # Fix the MSA sampling seed across recycling iterations unless
    # resampling is explicitly enabled.
    msa_seed = None if common_cfg.resample_msa_in_recycling else ensemble_seed

    transforms.append(
        data_transforms.sample_msa(
            max_msa_clusters, keep_extra=True, seed=msa_seed
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms.make_masked_msa(
                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
            )
        )

    if common_cfg.msa_cluster_features:
        transforms.append(data_transforms.nearest_neighbor_clusters())
        transforms.append(data_transforms.summarize_clusters())

    # Crop after creating the cluster profiles.
    if max_extra_msa:
        transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
    else:
        transforms.append(data_transforms.delete_extra_msa)

    transforms.append(data_transforms.make_msa_feat())

    crop_feats = dict(common_cfg.feat)

    if mode_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(
            data_transforms.random_crop_to_size(
                mode_cfg.crop_size,
                mode_cfg.max_templates,
                crop_feats,
                mode_cfg.subsample_templates,
                seed=ensemble_seed + 1,
            )
        )
        transforms.append(
            data_transforms.make_fixed_size(
                crop_feats,
                pad_msa_clusters,
                mode_cfg.max_extra_msa,
                mode_cfg.crop_size,
                mode_cfg.max_templates,
            )
        )
    else:
        transforms.append(
            data_transforms.crop_templates(mode_cfg.max_templates)
        )

    return transforms
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""

    # Shared seed so every ensemble/recycling iteration crops the same way.
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # NOTE(review): `no_templates` is computed but never used below.
    no_templates = True
    if("template_aatype" in tensors):
        no_templates = tensors["template_aatype"].shape[0] == 0

    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    # Apply the one-shot (non-ensembled) transforms first.
    tensors = compose(nonensembled)(tensors)

    if("no_recycling_iters" in tensors):
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    # One ensembled copy per recycling iteration (plus the initial pass);
    # map_fn stacks the resulting dicts along a new trailing dimension.
    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@data_transforms.curry1
def compose(x, fs):
    """Apply each transform in `fs` to `x` in order, threading the result."""
    for transform in fs:
        x = transform(x)
    return x
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def map_fn(fun, x):
    """Apply `fun` to each element of `x` and stack the resulting dicts.

    Every call to `fun` must return a dict with identical keys; each key's
    tensors are stacked along a new trailing (ensemble) dimension.
    """
    outputs = [fun(elem) for elem in x]
    stacked = {}
    for key in outputs[0].keys():
        stacked[key] = torch.stack([out[key] for out in outputs], dim=-1)
    return stacked
|
openfold/data/mmcif_parsing.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Parses the mmCIF file format."""
|
| 17 |
+
import collections
|
| 18 |
+
import dataclasses
|
| 19 |
+
import io
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
from typing import Any, Mapping, Optional, Sequence, Tuple
|
| 24 |
+
|
| 25 |
+
from Bio import PDB
|
| 26 |
+
from Bio.Data import SCOPData
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
from openfold.data.errors import MultipleChainsError
|
| 30 |
+
import openfold.np.residue_constants as residue_constants
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Type aliases:
|
| 34 |
+
ChainId = str
|
| 35 |
+
PdbHeader = Mapping[str, Any]
|
| 36 |
+
PdbStructure = PDB.Structure.Structure
|
| 37 |
+
SeqRes = str
|
| 38 |
+
MmCIFDict = Mapping[str, Sequence[str]]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclasses.dataclass(frozen=True)
class Monomer:
    """One entry of an mmCIF sequence loop (a single monomer/residue)."""
    # NOTE(review): presumably `id` is the chemical component id and `num`
    # its position in the entity sequence — confirm against the parser.
    id: str
    num: int
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Note - mmCIF format provides no guarantees on the type of author-assigned
|
| 48 |
+
# sequence numbers. They need not be integers.
|
| 49 |
+
@dataclasses.dataclass(frozen=True)
class AtomSite:
    """One row of the mmCIF _atom_site loop, held as raw parsed fields."""
    residue_name: str
    author_chain_id: str
    mmcif_chain_id: str
    author_seq_num: str
    # NOTE(review): annotated int, but mmCIF parsing yields strings — confirm
    # the value is actually coerced where AtomSite instances are built.
    mmcif_seq_num: int
    insertion_code: str
    hetatm_atom: str
    model_num: int
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Used to map SEQRES index to a residue in the structure.
|
| 62 |
+
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    """Location of a residue in the structure: chain, number, insertion code."""
    chain_id: str
    residue_number: int
    insertion_code: str
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    """A SEQRES residue together with its resolved structural position.

    `position` is Optional — presumably None when `is_missing` is True;
    confirm against the parsing code that builds these.
    """
    position: Optional[ResiduePosition]
    name: str
    is_missing: bool
    hetflag: str
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure.
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
        {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                          1: ResidueAtPosition,
                                                          ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    # Typed Any: kept as whatever raw form the parser received.
    raw_string: Any
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be successfully
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""
    # NOTE(review): not raised anywhere in this module's visible code;
    # presumably part of the public error API for callers — confirm.
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def mmcif_loop_to_list(
    prefix: str, parsed_info: "MmCIFDict"
) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop, including the
        trailing period, e.g. '_entity_poly_seq.' for the items
        _entity_poly_seq.num and _entity_poly_seq.mon_id.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      A list of dicts; each dict represents one entry (row) from the mmCIF loop.
    """
    matching_keys = [key for key in parsed_info if key.startswith(prefix)]
    columns = [parsed_info[key] for key in matching_keys]

    # Every column of a well-formed loop must have the same number of rows.
    assert all(len(column) == len(columns[0]) for column in columns), (
        "mmCIF error: Not all loops are the same length: %s" % matching_keys
    )

    return [dict(zip(matching_keys, row)) for row in zip(*columns)]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: "MmCIFDict",
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
      prefix: Prefix shared by each of the data items in the loop, including the
        trailing period (e.g. '_entity_poly_seq.').
      index: The data-item name whose value becomes the key of each entry.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      A dict of dicts; one inner dict per mmCIF loop row, keyed by the index
      column. Rows sharing an index value overwrite earlier ones.
    """
    rows = mmcif_loop_to_list(prefix, parsed_info)
    keyed_rows = {}
    for row in rows:
        keyed_rows[row[index]] = row
    return keyed_rows
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        # the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            # atom.model_num is the raw mmCIF string, hence the string compare.
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name
                # Normalize mmCIF "unset" markers ('.'/'?') to Biopython's ' '.
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                # Zero-based index into the chain's SEQRES sequence.
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        # Build the one-letter SEQRES string per author chain; unknown or
        # multi-letter codes become 'X'.
        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _get_first_model(structure: PdbStructure) -> PdbStructure:
|
| 307 |
+
"""Returns the first model in a Biopython structure."""
|
| 308 |
+
return next(structure.get_models())
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# NOTE(review): this threshold is not referenced in the code visible here;
# presumably used elsewhere to filter out short peptide chains — confirm.
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def get_release_date(parsed_info: "MmCIFDict") -> str:
    """Returns the oldest revision date recorded in the revision history."""
    dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
    return min(dates)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def _get_header(parsed_info: "MmCIFDict") -> "PdbHeader":
    """Builds a minimal header: method, release date and resolution."""
    header = {}

    # Concatenate every experimental method recorded in the _exptl loop.
    methods = [
        experiment["_exptl.method"].lower()
        for experiment in mmcif_loop_to_list("_exptl.", parsed_info)
    ]
    header["structure_method"] = ",".join(methods)

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    # Scan the candidate resolution fields; the last present, parseable one wins.
    header["resolution"] = 0.00
    resolution_keys = (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    )
    for res_key in resolution_keys:
        if res_key not in parsed_info:
            continue
        try:
            header["resolution"] = float(parsed_info[res_key][0])
        except ValueError:
            logging.info(
                "Invalid resolution format: %s", parsed_info[res_key]
            )

    return header
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def _get_atom_site_list(parsed_info: "MmCIFDict") -> Sequence["AtomSite"]:
    """Zips the per-atom mmCIF columns into AtomSite records.

    These carry data (e.g. internal chain/sequence ids) not present in the
    Biopython structure.
    """
    columns = [
        parsed_info["_atom_site.label_comp_id"],
        parsed_info["_atom_site.auth_asym_id"],
        parsed_info["_atom_site.label_asym_id"],
        parsed_info["_atom_site.auth_seq_id"],
        parsed_info["_atom_site.label_seq_id"],
        parsed_info["_atom_site.pdbx_PDB_ins_code"],
        parsed_info["_atom_site.group_PDB"],
        parsed_info["_atom_site.pdbx_PDB_model_num"],
    ]
    return [AtomSite(*row) for row in zip(*columns)]
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

    # Group the SEQRES monomers by entity id.
    polymers = collections.defaultdict(list)
    for entity_poly_seq in entity_poly_seqs:
        polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
            Monomer(
                id=entity_poly_seq["_entity_poly_seq.mon_id"],
                num=int(entity_poly_seq["_entity_poly_seq.num"]),
            )
        )

    # Get chemical compositions. Will allow us to identify which of these polymers
    # are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Get chains information for each entity. Necessary so that we can return a
    # dict keyed on chain id rather than entity.
    struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)

    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in struct_asyms:
        chain_id = struct_asym["_struct_asym.id"]
        entity_id = struct_asym["_struct_asym.entity_id"]
        entity_to_mmcif_chains[entity_id].append(chain_id)

    # Identify and return the valid protein chains.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        chain_ids = entity_to_mmcif_chains[entity_id]

        # Reject polymers without any peptide-like components, such as DNA/RNA.
        # A single monomer whose _chem_comp.type contains 'peptide' is enough to
        # keep every chain of the entity.
        if any(
            [
                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
                for monomer in seq_info
            ]
        ):
            for chain_id in chain_ids:
                valid_chains[chain_id] = seq_info
    return valid_chains
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def _is_set(data: str) -> bool:
|
| 428 |
+
"""Returns False if data is a special mmCIF character indicating 'unset'."""
|
| 429 |
+
return data not in (".", "?")
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts per-residue atom coordinates and an atom mask for one chain.

    Args:
      mmcif_object: A parsed mmCIF file.
      chain_id: Author chain id; must match exactly one chain in the structure.
      _zero_center_positions: If True, subtract the mean of the observed atom
        positions so the structure is centered at the origin.

    Returns:
      A tuple of:
        * all_atom_positions: [num_res, residue_constants.atom_type_num, 3]
          float32 array of coordinates (zeros where no atom was observed).
        * all_atom_mask: [num_res, residue_constants.atom_type_num] float32
          array; 1.0 where an atom was observed.

    Raises:
      MultipleChainsError: If chain_id does not identify exactly one chain.
    """
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            # Biopython residues are keyed by (hetflag, resseq, icode).
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                # Atoms not in residue_constants.atom_order are silently dropped.
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
                    mask[residue_constants.atom_order["SD"]] = 1.0

        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    if _zero_center_positions:
        # Center using only the observed atoms; unobserved entries stay zero.
        binary_mask = all_atom_mask.astype(bool)
        translation_vec = all_atom_positions[binary_mask].mean(axis=0)
        all_atom_positions[binary_mask] -= translation_vec

    return all_atom_positions, all_atom_mask
|
openfold/data/parsers.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Functions for parsing various file formats."""
|
| 17 |
+
import collections
|
| 18 |
+
import dataclasses
|
| 19 |
+
import re
|
| 20 |
+
import string
|
| 21 |
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# deletion_matrix[i][j] is the number of residues deleted from aligned
# sequence i at residue position j (see parse_stockholm / parse_a3m).
DeletionMatrix = Sequence[Sequence[int]]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit."""

    index: int                 # Hit number from the 'No <n>' paragraph header.
    name: str                  # Hit name line with the leading '>' stripped.
    aligned_cols: int          # 'Aligned_cols' value from the hhr summary line.
    sum_probs: float           # 'Sum_probs' value from the hhr summary line.
    query: str                 # Aligned query sequence.
    hit_sequence: str          # Aligned hit sequence.
    indices_query: List[int]   # Original-sequence index per column; -1 at gaps.
    indices_hit: List[int]     # Original-sequence index per column; -1 at gaps.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses FASTA string and returns list of strings with amino-acid sequences.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple of two lists:
      * A list of sequences.
      * A list of sequence descriptions taken from the comment lines. In the
        same order as the sequences.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    for raw_line in fasta_string.splitlines():
        line = raw_line.strip()
        if not line:
            # Skip blank lines.
            continue
        if line.startswith(">"):
            # Header line: start a new record (drop the leading '>').
            descriptions.append(line[1:])
            sequences.append("")
        else:
            # Continuation of the current sequence.
            sequences[-1] += line

    return sequences, descriptions
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def parse_stockholm(
    stockholm_string: str,
) -> Tuple[Sequence[str], "DeletionMatrix", Sequence[str]]:
    """Parses sequences and deletion matrix from stockholm format alignment.

    Args:
      stockholm_string: The string contents of a stockholm file. The first
        sequence in the file should be the query sequence.

    Returns:
      A tuple of:
      * A list of sequences that have been aligned to the query. These
        might contain duplicates.
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
      * The names of the targets matched, including the jackhmmer subsequence
        suffix.
    """
    # Accumulate possibly multi-line alignments per name, preserving file order.
    name_to_sequence = collections.OrderedDict()
    for raw_line in stockholm_string.splitlines():
        line = raw_line.strip()
        # Skip blanks, markup ('#...') and the alignment terminator ('//').
        if not line or line.startswith(("#", "//")):
            continue
        name, chunk = line.split()
        name_to_sequence[name] = name_to_sequence.get(name, "") + chunk

    msa = []
    deletion_matrix = []
    query = ""
    keep_columns = []

    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # The first sequence is the query; record its non-gap columns.
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != "-"]

        # Drop the columns that are gaps in the query.
        msa.append("".join(sequence[c] for c in keep_columns))

        # Count deletions (residues aligned to query gaps) w.r.t. the query.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res == "-" and query_res == "-":
                continue
            if query_res == "-":
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    return msa, deletion_matrix, list(name_to_sequence.keys())
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], "DeletionMatrix"]:
    """Parses sequences and deletion matrix from a3m format alignment.

    Args:
      a3m_string: The string contents of a a3m file. The first sequence in the
        file should be the query sequence.

    Returns:
      A tuple of:
      * A list of sequences that have been aligned to the query. These
        might contain duplicates.
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
    """
    sequences, _ = parse_fasta(a3m_string)

    # In a3m, lowercase letters are insertions relative to the query; count a
    # run of them as deletions attributed to the next aligned column.
    deletion_matrix = []
    for msa_sequence in sequences:
        deletion_vec = []
        deletion_count = 0
        for residue in msa_sequence:
            if residue.islower():
                deletion_count += 1
                continue
            deletion_vec.append(deletion_count)
            deletion_count = 0
        deletion_matrix.append(deletion_vec)

    # Make the MSA matrix out of aligned (deletion-free) sequences.
    strip_lowercase = str.maketrans("", "", string.ascii_lowercase)
    aligned_sequences = [seq.translate(strip_lowercase) for seq in sequences]
    return aligned_sequences, deletion_matrix
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _convert_sto_seq_to_a3m(
|
| 165 |
+
query_non_gaps: Sequence[bool], sto_seq: str
|
| 166 |
+
) -> Iterable[str]:
|
| 167 |
+
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
|
| 168 |
+
if is_query_res_non_gap:
|
| 169 |
+
yield sequence_res
|
| 170 |
+
elif sequence_res != "-":
|
| 171 |
+
yield sequence_res.lower()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def convert_stockholm_to_a3m(
    stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
    """Converts MSA in Stockholm format to the A3M format.

    Args:
      stockholm_format: String contents of a Stockholm-format alignment; the
        first sequence is treated as the query.
      max_sequences: If given, keep at most this many distinct sequences.

    Returns:
      The alignment as a single A3M-format string (with terminating newline).
    """
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    # First pass: accumulate (possibly multi-line) sequences, capped at
    # max_sequences distinct names when a cap was given.
    for line in stockholm_format.splitlines():
        reached_max_sequences = (
            max_sequences and len(sequences) >= max_sequences
        )
        if line.strip() and not line.startswith(("#", "//")):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ""
            sequences[seqname] += aligned_seq

    # Second pass: collect '#=GS ... DE ...' description rows for kept names.
    # NOTE: reached_max_sequences deliberately keeps its final value from the
    # loop above.
    for line in stockholm_format.splitlines():
        if line[:4] == "#=GS":
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ""
            if feature != "DE":
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            if len(descriptions) == len(sequences):
                break

    # Convert sto format to a3m line by line
    a3m_sequences = {}
    # query_sequence is assumed to be the first sequence
    query_sequence = next(iter(sequences.values()))
    query_non_gaps = [res != "-" for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        a3m_sequences[seqname] = "".join(
            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
        )

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
        for k in a3m_sequences
    )
    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _get_hhr_line_regex_groups(
|
| 229 |
+
regex_pattern: str, line: str
|
| 230 |
+
) -> Sequence[Optional[str]]:
|
| 231 |
+
match = re.match(regex_pattern, line)
|
| 232 |
+
if match is None:
|
| 233 |
+
raise RuntimeError(f"Could not parse query line {line}")
|
| 234 |
+
return match.groups()
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _update_hhr_residue_indices_list(
|
| 238 |
+
sequence: str, start_index: int, indices_list: List[int]
|
| 239 |
+
):
|
| 240 |
+
"""Computes the relative indices for each residue with respect to the original sequence."""
|
| 241 |
+
counter = start_index
|
| 242 |
+
for symbol in sequence:
|
| 243 |
+
if symbol == "-":
|
| 244 |
+
indices_list.append(-1)
|
| 245 |
+
else:
|
| 246 |
+
indices_list.append(counter)
|
| 247 |
+
counter += 1
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses the detailed HMM HMM comparison section for a single Hit.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
      detailed_lines: A list of lines from a single comparison section between 2
        sequences (which each have their own HMM's)

    Returns:
      A TemplateHit with the information from that detailed comparison section.

    Raises:
      RuntimeError: If a certain line cannot be processed
    """
    # Parse first 2 lines: 'No <n>' and the '>name' header.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # Parse the summary line.
    pattern = (
        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
        "]*Template_Neff=(.*)"
    )
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            "Could not parse section: %s. Expected this: \n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )
    # Only aligned_cols and sum_probs are kept; the rest are discarded.
    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
        float(x) for x in match.groups()
    ]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ""
    hit_sequence = ""
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line
        if (
            line.startswith("Q ")
            and not line.startswith("Q ss_dssp")
            and not line.startswith("Q ss_pred")
            and not line.startswith("Q Consensus")
        ):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
            # everything after that.
            #              start      sequence       end       total_sequence_length
            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Get the length of the parsed block using the start and finish indices,
            # and ensure it is the same as the actual block length.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            num_insertions = len([x for x in delta_query if x == "-"])
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith("T "):
            # Parse the hit sequence.
            if (
                not line.startswith("T ss_dssp")
                and not line.startswith("T ss_pred")
                and not line.startswith("T Consensus")
            ):
                # Thus the first 17 characters must be 'T <hit_name> ', and we can
                # parse everything after that.
                #              start      sequence       end     total_sequence_length
                patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                # Hit blocks must be the same width as the preceding query block.
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(
                    delta_hit_sequence, start, indices_hit
                )

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file.

    Args:
        hhr_string: Full text of an HHSearch .hhr results file.

    Returns:
        A list of TemplateHit objects, one per hit paragraph in the file
        (empty if the file contains no hits).
    """
    lines = hhr_string.splitlines()

    # An .hhr file starts with a results table, followed by one "paragraph"
    # per hit, each introduced by a line of the form 'No <hit number>'.
    starts = [idx for idx, ln in enumerate(lines) if ln.startswith("No ")]
    if not starts:
        return []

    # Append a sentinel index so every paragraph has an explicit end.
    starts.append(len(lines))
    return [
        _parse_hhr_hit(lines[begin:end])
        for begin, end in zip(starts[:-1], starts[1:])
    ]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parse target to e-value mapping parsed from Jackhmmer tblout string.

    Args:
        tblout: Contents of a Jackhmmer ``--tblout`` output file.

    Returns:
        A dict mapping each target name to its full-sequence e-value. The
        query itself is always present with an e-value of 0.
    """
    e_values = {"query": 0}
    # BUG FIX: the original used `line[0] != "#"`, which raises IndexError on
    # blank lines; guard with truthiness and use startswith for clarity.
    lines = [
        line
        for line in tblout.splitlines()
        if line and not line.startswith("#")
    ]
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name: and
    # (5) E-value (full sequence) (numbering from 1).
    for line in lines:
        fields = line.split()
        e_value = fields[4]
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values
|
openfold/data/templates.py
ADDED
|
@@ -0,0 +1,1108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Functions for getting templates and calculating template features."""
|
| 17 |
+
import dataclasses
|
| 18 |
+
import datetime
|
| 19 |
+
import glob
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
import re
|
| 24 |
+
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
from openfold.data import parsers, mmcif_parsing
|
| 29 |
+
from openfold.data.errors import Error
|
| 30 |
+
from openfold.data.tools import kalign
|
| 31 |
+
from openfold.data.tools.utils import to_date
|
| 32 |
+
from openfold.np import residue_constants
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Errors raised while extracting template features from mmCIF data. All of
# these derive from the shared openfold `Error` base class.
class NoChainsError(Error):
    """An error indicating that template mmCIF didn't have any chains."""


class SequenceNotInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain the sequence."""


class NoAtomDataInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain atom positions."""


class TemplateAtomMaskAllZerosError(Error):
    """An error indicating that template mmCIF had all atom positions masked."""


class QueryToTemplateAlignError(Error):
    """An error indicating that the query can't be aligned to the template."""


class CaDistanceError(Error):
    """An error indicating that a CA atom distance exceeds a threshold."""


# Prefilter exceptions.
# NOTE: unlike the classes above, these deliberately derive from plain
# `Exception` rather than the package `Error` base.
class PrefilterError(Exception):
    """A base class for template prefilter exceptions."""


class DateError(PrefilterError):
    """An error indicating that the hit date was after the max allowed date."""


class PdbIdError(PrefilterError):
    """An error indicating that the hit PDB ID was identical to the query."""


class AlignRatioError(PrefilterError):
    """An error indicating that the hit align ratio to the query was too small."""


class DuplicateError(PrefilterError):
    """An error indicating that the hit was an exact subsequence of the query."""


class LengthError(PrefilterError):
    """An error indicating that the hit was too short."""
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Mapping from template feature name to the dtype used when the feature
# arrays are assembled downstream.
TEMPLATE_FEATURES = {
    "template_aatype": np.int64,
    "template_all_atom_mask": np.float32,
    "template_all_atom_positions": np.float32,
    # BUG FIX: `np.object` was a deprecated alias for the builtin `object`
    # and was removed in NumPy 1.24; using it raises AttributeError at import
    # time on modern NumPy. The builtin is the documented replacement.
    "template_domain_names": object,
    "template_sequence": object,
    "template_sum_probs": np.float32,
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
|
| 95 |
+
"""Returns PDB id and chain id for an HHSearch Hit."""
|
| 96 |
+
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
|
| 97 |
+
id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
|
| 98 |
+
if not id_match:
|
| 99 |
+
raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
|
| 100 |
+
pdb_id, chain_id = id_match.group(0).split("_")
|
| 101 |
+
return pdb_id.lower(), chain_id
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _is_after_cutoff(
|
| 105 |
+
pdb_id: str,
|
| 106 |
+
release_dates: Mapping[str, datetime.datetime],
|
| 107 |
+
release_date_cutoff: Optional[datetime.datetime],
|
| 108 |
+
) -> bool:
|
| 109 |
+
"""Checks if the template date is after the release date cutoff.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
pdb_id: 4 letter pdb code.
|
| 113 |
+
release_dates: Dictionary mapping PDB ids to their structure release dates.
|
| 114 |
+
release_date_cutoff: Max release date that is valid for this query.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
True if the template release date is after the cutoff, False otherwise.
|
| 118 |
+
"""
|
| 119 |
+
pdb_id_upper = pdb_id.upper()
|
| 120 |
+
if release_date_cutoff is None:
|
| 121 |
+
raise ValueError("The release_date_cutoff must not be None.")
|
| 122 |
+
if pdb_id_upper in release_dates:
|
| 123 |
+
return release_dates[pdb_id_upper] > release_date_cutoff
|
| 124 |
+
else:
|
| 125 |
+
# Since this is just a quick prefilter to reduce the number of mmCIF files
|
| 126 |
+
# we need to parse, we don't have to worry about returning True here.
|
| 127 |
+
logging.info(
|
| 128 |
+
"Template structure not in release dates dict: %s", pdb_id
|
| 129 |
+
)
|
| 130 |
+
return False
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
|
| 134 |
+
"""Parses the data file from PDB that lists which PDB ids are obsolete."""
|
| 135 |
+
with open(obsolete_file_path) as f:
|
| 136 |
+
result = {}
|
| 137 |
+
for line in f:
|
| 138 |
+
line = line.strip()
|
| 139 |
+
# We skip obsolete entries that don't contain a mapping to a new entry.
|
| 140 |
+
if line.startswith("OBSLTE") and len(line) > 30:
|
| 141 |
+
# Format: Date From To
|
| 142 |
+
# 'OBSLTE 31-JUL-94 116L 216L'
|
| 143 |
+
from_id = line[20:24].lower()
|
| 144 |
+
to_id = line[29:33].lower()
|
| 145 |
+
result[from_id] = to_id
|
| 146 |
+
return result
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Builds a JSON cache mapping mmCIF file IDs to their release dates.

    Scans every ``.cif`` file in `mmcif_dir`, parses its header for the
    release date, and writes the resulting ``{file_id: release_date}``
    mapping to `out_path` as JSON. Files that fail to parse are skipped
    with a log message.

    Args:
        mmcif_dir: Directory containing ``.cif`` files to scan.
        out_path: Path the JSON cache is written to (overwritten if present).
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if not f.endswith(".cif"):
            continue
        path = os.path.join(mmcif_dir, f)
        with open(path, "r") as fp:
            mmcif_string = fp.read()

        file_id = os.path.splitext(f)[0]
        mmcif = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )
        if mmcif.mmcif_object is None:
            logging.info(f"Failed to parse {f}. Skipping...")
            continue

        dates[file_id] = mmcif.mmcif_object.header["release_date"]

    # BUG FIX: the output file was opened in read mode ("r") and then
    # written to, which raises io.UnsupportedOperation. Open for writing.
    with open(out_path, "w") as fp:
        fp.write(json.dumps(dates))
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
|
| 175 |
+
"""Parses release dates file, returns a mapping from PDBs to release dates."""
|
| 176 |
+
with open(path, "r") as fp:
|
| 177 |
+
data = json.load(fp)
|
| 178 |
+
|
| 179 |
+
return {
|
| 180 |
+
pdb.upper(): to_date(v)
|
| 181 |
+
for pdb, d in data.items()
|
| 182 |
+
for k, v in d.items()
|
| 183 |
+
if k == "release_date"
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    query_pdb_code: Optional[str],
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.1,
) -> bool:
    """Determines if template is valid (without parsing the template mmcif file).

    Args:
        hit: HhrHit for the template.
        hit_pdb_code: The 4 letter pdb code of the template hit. This might be
            different from the value in the actual hit since the original pdb
            might have become obsolete.
        query_sequence: Amino acid sequence of the query.
        query_pdb_code: 4 letter pdb code of the query.
        release_dates: Dictionary mapping pdb codes to their structure release
            dates.
        release_date_cutoff: Max release date that is valid for this query.
        max_subsequence_ratio: Exclude any exact matches with this much overlap.
        min_align_ratio: Minimum overlap between the template and query.

    Returns:
        True if the hit passed the prefilter. Raises an exception otherwise.

    Raises:
        DateError: If the hit date was after the max allowed date.
        PdbIdError: If the hit PDB ID was identical to the query.
        AlignRatioError: If the hit align ratio to the query was too small.
        DuplicateError: If the hit was an exact subsequence of the query.
        LengthError: If the hit was too short.
    """
    align_ratio = hit.aligned_cols / len(query_sequence)

    template_sequence = hit.hit_sequence.replace("-", "")
    length_ratio = float(len(template_sequence)) / len(query_sequence)

    # A template that is essentially the query itself (an exact subsequence
    # covering most of it) can arise from duplicate PDB entries and carries
    # no independent signal.
    is_duplicate = (
        template_sequence in query_sequence
        and length_ratio > max_subsequence_ratio
    )

    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        date = release_dates[hit_pdb_code.upper()]
        raise DateError(
            f"Date ({date}) > max template date "
            f"({release_date_cutoff})."
        )

    if (
        query_pdb_code is not None
        and query_pdb_code.lower() == hit_pdb_code.lower()
    ):
        raise PdbIdError("PDB code identical to Query PDB code.")

    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            "Proportion of residues aligned to query too small. "
            f"Align ratio: {align_ratio}."
        )

    if is_duplicate:
        raise DuplicateError(
            "Template is an exact subsequence of query with large "
            f"coverage. Length ratio: {length_ratio}."
        )

    if len(template_sequence) < 10:
        raise LengthError(
            f"Template too short. Length: {len(template_sequence)}."
        )

    return True
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _find_template_in_pdb(
|
| 267 |
+
template_chain_id: str,
|
| 268 |
+
template_sequence: str,
|
| 269 |
+
mmcif_object: mmcif_parsing.MmcifObject,
|
| 270 |
+
) -> Tuple[str, str, int]:
|
| 271 |
+
"""Tries to find the template chain in the given pdb file.
|
| 272 |
+
|
| 273 |
+
This method tries the three following things in order:
|
| 274 |
+
1. Tries if there is an exact match in both the chain ID and the sequence.
|
| 275 |
+
If yes, the chain sequence is returned. Otherwise:
|
| 276 |
+
2. Tries if there is an exact match only in the sequence.
|
| 277 |
+
If yes, the chain sequence is returned. Otherwise:
|
| 278 |
+
3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
|
| 279 |
+
If yes, the chain sequence is returned.
|
| 280 |
+
If none of these succeed, a SequenceNotInTemplateError is thrown.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
template_chain_id: The template chain ID.
|
| 284 |
+
template_sequence: The template chain sequence.
|
| 285 |
+
mmcif_object: The PDB object to search for the template in.
|
| 286 |
+
|
| 287 |
+
Returns:
|
| 288 |
+
A tuple with:
|
| 289 |
+
* The chain sequence that was found to match the template in the PDB object.
|
| 290 |
+
* The ID of the chain that is being returned.
|
| 291 |
+
* The offset where the template sequence starts in the chain sequence.
|
| 292 |
+
|
| 293 |
+
Raises:
|
| 294 |
+
SequenceNotInTemplateError: If no match is found after the steps described
|
| 295 |
+
above.
|
| 296 |
+
"""
|
| 297 |
+
# Try if there is an exact match in both the chain ID and the (sub)sequence.
|
| 298 |
+
pdb_id = mmcif_object.file_id
|
| 299 |
+
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
|
| 300 |
+
if chain_sequence and (template_sequence in chain_sequence):
|
| 301 |
+
logging.info(
|
| 302 |
+
"Found an exact template match %s_%s.", pdb_id, template_chain_id
|
| 303 |
+
)
|
| 304 |
+
mapping_offset = chain_sequence.find(template_sequence)
|
| 305 |
+
return chain_sequence, template_chain_id, mapping_offset
|
| 306 |
+
|
| 307 |
+
# Try if there is an exact match in the (sub)sequence only.
|
| 308 |
+
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
|
| 309 |
+
if chain_sequence and (template_sequence in chain_sequence):
|
| 310 |
+
logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
|
| 311 |
+
mapping_offset = chain_sequence.find(template_sequence)
|
| 312 |
+
return chain_sequence, chain_id, mapping_offset
|
| 313 |
+
|
| 314 |
+
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
|
| 315 |
+
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
|
| 316 |
+
regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
|
| 317 |
+
regex = re.compile("".join(regex))
|
| 318 |
+
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
|
| 319 |
+
match = re.search(regex, chain_sequence)
|
| 320 |
+
if match:
|
| 321 |
+
logging.info(
|
| 322 |
+
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
|
| 323 |
+
)
|
| 324 |
+
mapping_offset = match.start()
|
| 325 |
+
return chain_sequence, chain_id, mapping_offset
|
| 326 |
+
|
| 327 |
+
# No hits, raise an error.
|
| 328 |
+
raise SequenceNotInTemplateError(
|
| 329 |
+
"Could not find the template sequence in %s_%s. Template sequence: %s, "
|
| 330 |
+
"chain_to_seqres: %s"
|
| 331 |
+
% (
|
| 332 |
+
pdb_id,
|
| 333 |
+
template_chain_id,
|
| 334 |
+
template_sequence,
|
| 335 |
+
mmcif_object.chain_to_seqres,
|
| 336 |
+
)
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
    """Aligns template from the mmcif_object to the query.

    In case PDB70 contains a different version of the template sequence, we need
    to perform a realignment to the actual sequence that is in the mmCIF file.
    This method performs such realignment, but returns the new sequence and
    mapping only if the sequence in the mmCIF file is 90% identical to the old
    sequence.

    Note that the old_template_sequence comes from the hit, and contains only that
    part of the chain that matches with the query while the new_template_sequence
    is the full chain.

    Args:
        old_template_sequence: The template sequence that was returned by the PDB
            template search (typically done using HHSearch).
        template_chain_id: The template chain id was returned by the PDB template
            search (typically done using HHSearch). This is used to find the right
            chain in the mmcif_object chain_to_seqres mapping.
        mmcif_object: A mmcif_object which holds the actual template data.
        old_mapping: A mapping from the query sequence to the template sequence.
            This mapping will be used to compute the new mapping from the query
            sequence to the actual mmcif_object template sequence by aligning the
            old_template_sequence and the actual template sequence.
        kalign_binary_path: The path to a kalign executable.

    Returns:
        A tuple (new_template_sequence, new_query_to_template_mapping) where:
        * new_template_sequence is the actual template sequence that was found in
          the mmcif_object.
        * new_query_to_template_mapping is the new mapping from the query to the
          actual template found in the mmcif_object.

    Raises:
        QueryToTemplateAlignError:
        * If there was an error thrown by the alignment tool.
        * Or if the actual template sequence differs by more than 10% from the
          old_template_sequence.
    """
    aligner = kalign.Kalign(binary_path=kalign_binary_path)
    new_template_sequence = mmcif_object.chain_to_seqres.get(
        template_chain_id, ""
    )

    # Sometimes the template chain id is unknown. But if there is only a single
    # sequence within the mmcif_object, it is safe to assume it is that one.
    if not new_template_sequence:
        if len(mmcif_object.chain_to_seqres) == 1:
            logging.info(
                "Could not find %s in %s, but there is only 1 sequence, so "
                "using that one.",
                template_chain_id,
                mmcif_object.file_id,
            )
            new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
                0
            ]
        else:
            raise QueryToTemplateAlignError(
                f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
                "If there are no mmCIF parsing errors, it is possible it was not a "
                "protein chain."
            )

    # Align the hit's template sequence against the full chain sequence from
    # the mmCIF file; parse_a3m yields the two gapped alignment rows.
    try:
        (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
            aligner.align([old_template_sequence, new_template_sequence])
        )
    except Exception as e:
        raise QueryToTemplateAlignError(
            "Could not align old template %s to template %s (%s_%s). Error: %s"
            % (
                old_template_sequence,
                new_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                str(e),
            )
        )

    logging.info(
        "Old aligned template: %s\nNew aligned template: %s",
        old_aligned_template,
        new_aligned_template,
    )

    # Walk the two aligned rows in lockstep, tracking the ungapped residue
    # index within each sequence, to build the old->new index mapping and to
    # count identical aligned positions.
    old_to_new_template_mapping = {}
    old_template_index = -1
    new_template_index = -1
    num_same = 0
    for old_template_aa, new_template_aa in zip(
        old_aligned_template, new_aligned_template
    ):
        if old_template_aa != "-":
            old_template_index += 1
        if new_template_aa != "-":
            new_template_index += 1
        if old_template_aa != "-" and new_template_aa != "-":
            old_to_new_template_mapping[old_template_index] = new_template_index
            if old_template_aa == new_template_aa:
                num_same += 1

    # Require at least 90 % sequence identity wrt to the shorter of the sequences.
    if (
        float(num_same)
        / min(len(old_template_sequence), len(new_template_sequence))
        < 0.9
    ):
        raise QueryToTemplateAlignError(
            "Insufficient similarity of the sequence in the database: %s to the "
            "actual sequence in the mmCIF file %s_%s: %s. We require at least "
            "90 %% similarity wrt to the shorter of the sequences. This is not a "
            "problem unless you think this is a template that should be included."
            % (
                old_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                new_template_sequence,
            )
        )

    # Compose the old query->template mapping with the old->new template
    # mapping; query positions with no counterpart in the new sequence are
    # mapped to -1.
    new_query_to_template_mapping = {}
    for query_index, old_template_index in old_mapping.items():
        new_query_to_template_mapping[
            query_index
        ] = old_to_new_template_mapping.get(old_template_index, -1)

    new_template_sequence = new_template_sequence.replace("-", "")

    return new_template_sequence, new_query_to_template_mapping
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _check_residue_distances(
    all_positions: np.ndarray,
    all_positions_mask: np.ndarray,
    max_ca_ca_distance: float,
):
    """Checks if the distance between unmasked neighbor residues is ok.

    Args:
        all_positions: Per-residue atom coordinates; indexed first by residue,
            then by atom (via residue_constants.atom_order) — assumed shape
            [num_res, num_atoms, 3].
        all_positions_mask: Per-residue, per-atom presence mask aligned with
            all_positions.
        max_ca_ca_distance: Maximum allowed CA-CA distance between two
            consecutive unmasked residues.

    Raises:
        CaDistanceError: If consecutive residues with unmasked CA atoms are
            further apart than max_ca_ca_distance.
    """
    ca_position = residue_constants.atom_order["CA"]
    prev_is_unmasked = False
    prev_calpha = None
    for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
        this_is_unmasked = bool(mask[ca_position])
        if this_is_unmasked:
            this_calpha = coords[ca_position]
            # Only compare when the immediately preceding residue also had an
            # unmasked CA: prev_is_unmasked is overwritten every iteration, so
            # the check never spans a masked gap.
            if prev_is_unmasked:
                distance = np.linalg.norm(this_calpha - prev_calpha)
                if distance > max_ca_ca_distance:
                    # The reported indices are 1-based: 0-based residues
                    # i - 1 and i become i and i + 1 in the message.
                    raise CaDistanceError(
                        "The distance between residues %d and %d is %f > limit %f."
                        % (i, i + 1, distance, max_ca_ca_distance)
                    )
            prev_calpha = this_calpha
        prev_is_unmasked = this_is_unmasked
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask from a list of Biopython Residues."""
    all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object,
        chain_id=auth_chain_id,
        _zero_center_positions=_zero_center_positions,
    )
    # Sanity-check the chain: consecutive unmasked CA atoms must stay within
    # max_ca_ca_distance of each other (raises CaDistanceError otherwise).
    _check_residue_distances(
        all_atom_positions, all_atom_mask, max_ca_ca_distance
    )
    return all_atom_positions, all_atom_mask
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str,
    _zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
    """Parses atom positions in the target structure and aligns with the query.

    Atoms for each residue in the template structure are indexed to coincide
    with their corresponding residue in the query sequence, according to the
    alignment mapping provided.

    Args:
      mmcif_object: mmcif_parsing.MmcifObject representing the template.
      pdb_id: PDB code for the template.
      mapping: Dictionary mapping indices in the query sequence to indices in
        the template sequence.
      template_sequence: String describing the amino acid sequence for the
        template protein.
      query_sequence: String describing the amino acid sequence for the query
        protein.
      template_chain_id: String ID describing which chain in the structure proto
        should be used.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.

    Returns:
      A tuple with:
      * A dictionary containing the extra features derived from the template
        protein structure.
      * A warning message if the hit was realigned to the actual mmCIF sequence.
        Otherwise None.

    Raises:
      NoChainsError: If the mmcif object doesn't contain any chains.
      SequenceNotInTemplateError: If the given chain id / sequence can't
        be found in the mmcif object.
      QueryToTemplateAlignError: If the actual template in the mmCIF file
        can't be aligned to the query.
      NoAtomDataInTemplateError: If the mmcif object doesn't contain
        atom positions.
      TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
        unmasked residues.
    """
    if mmcif_object is None or not mmcif_object.chain_to_seqres:
        raise NoChainsError(
            "No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
        )

    warning = None
    try:
        # Fast path: the hit's sequence is present verbatim in the mmCIF entry.
        seqres, chain_id, mapping_offset = _find_template_in_pdb(
            template_chain_id=template_chain_id,
            template_sequence=template_sequence,
            mmcif_object=mmcif_object,
        )
    except SequenceNotInTemplateError:
        # If PDB70 contains a different version of the template, we use the sequence
        # from the mmcif_object.
        chain_id = template_chain_id
        warning = (
            f"The exact sequence {template_sequence} was not found in "
            f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
        )
        logging.warning(warning)
        # This throws an exception if it fails to realign the hit.
        seqres, mapping = _realign_pdb_template_to_query(
            old_template_sequence=template_sequence,
            template_chain_id=template_chain_id,
            mmcif_object=mmcif_object,
            old_mapping=mapping,
            kalign_binary_path=kalign_binary_path,
        )
        logging.info(
            "Sequence in %s_%s: %s successfully realigned to %s",
            pdb_id,
            chain_id,
            template_sequence,
            seqres,
        )
        # The template sequence changed.
        template_sequence = seqres
        # No mapping offset, the query is aligned to the actual sequence.
        mapping_offset = 0

    try:
        # Essentially set to infinity - we don't want to reject templates unless
        # they're really really bad.
        all_atom_positions, all_atom_mask = _get_atom_positions(
            mmcif_object,
            chain_id,
            max_ca_ca_distance=150.0,
            _zero_center_positions=_zero_center_positions,
        )
    except (CaDistanceError, KeyError) as ex:
        raise NoAtomDataInTemplateError(
            "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
        ) from ex

    # Split into per-residue arrays so individual residues can be re-indexed
    # into query positions below.
    all_atom_positions = np.split(
        all_atom_positions, all_atom_positions.shape[0]
    )
    all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

    output_templates_sequence = []
    templates_all_atom_positions = []
    templates_all_atom_masks = []

    # Pre-fill every query position with gap/zero features; aligned positions
    # are overwritten in the next loop.
    for _ in query_sequence:
        # Residues in the query_sequence that are not in the template_sequence:
        templates_all_atom_positions.append(
            np.zeros((residue_constants.atom_type_num, 3))
        )
        templates_all_atom_masks.append(
            np.zeros(residue_constants.atom_type_num)
        )
        output_templates_sequence.append("-")

    # Copy template residues into the query-indexed slots given by the
    # alignment mapping (query index k -> template index v).
    for k, v in mapping.items():
        template_index = v + mapping_offset
        templates_all_atom_positions[k] = all_atom_positions[template_index][0]
        templates_all_atom_masks[k] = all_atom_masks[template_index][0]
        output_templates_sequence[k] = template_sequence[v]

    # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
    if np.sum(templates_all_atom_masks) < 5:
        raise TemplateAtomMaskAllZerosError(
            "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
            % (
                pdb_id,
                chain_id,
                min(mapping.values()) + mapping_offset,
                max(mapping.values()) + mapping_offset,
            )
        )

    output_templates_sequence = "".join(output_templates_sequence)

    templates_aatype = residue_constants.sequence_to_onehot(
        output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    return (
        {
            "template_all_atom_positions": np.array(
                templates_all_atom_positions
            ),
            "template_all_atom_mask": np.array(templates_all_atom_masks),
            "template_sequence": output_templates_sequence.encode(),
            "template_aatype": np.array(templates_aatype),
            "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
        },
        warning,
    )
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def _build_query_to_hit_index_mapping(
|
| 682 |
+
hit_query_sequence: str,
|
| 683 |
+
hit_sequence: str,
|
| 684 |
+
indices_hit: Sequence[int],
|
| 685 |
+
indices_query: Sequence[int],
|
| 686 |
+
original_query_sequence: str,
|
| 687 |
+
) -> Mapping[int, int]:
|
| 688 |
+
"""Gets mapping from indices in original query sequence to indices in the hit.
|
| 689 |
+
|
| 690 |
+
hit_query_sequence and hit_sequence are two aligned sequences containing gap
|
| 691 |
+
characters. hit_query_sequence contains only the part of the original query
|
| 692 |
+
sequence that matched the hit. When interpreting the indices from the .hhr, we
|
| 693 |
+
need to correct for this to recover a mapping from original query sequence to
|
| 694 |
+
the hit sequence.
|
| 695 |
+
|
| 696 |
+
Args:
|
| 697 |
+
hit_query_sequence: The portion of the query sequence that is in the .hhr
|
| 698 |
+
hit
|
| 699 |
+
hit_sequence: The portion of the hit sequence that is in the .hhr
|
| 700 |
+
indices_hit: The indices for each aminoacid relative to the hit sequence
|
| 701 |
+
indices_query: The indices for each aminoacid relative to the original query
|
| 702 |
+
sequence
|
| 703 |
+
original_query_sequence: String describing the original query sequence.
|
| 704 |
+
|
| 705 |
+
Returns:
|
| 706 |
+
Dictionary with indices in the original query sequence as keys and indices
|
| 707 |
+
in the hit sequence as values.
|
| 708 |
+
"""
|
| 709 |
+
# If the hit is empty (no aligned residues), return empty mapping
|
| 710 |
+
if not hit_query_sequence:
|
| 711 |
+
return {}
|
| 712 |
+
|
| 713 |
+
# Remove gaps and find the offset of hit.query relative to original query.
|
| 714 |
+
hhsearch_query_sequence = hit_query_sequence.replace("-", "")
|
| 715 |
+
hit_sequence = hit_sequence.replace("-", "")
|
| 716 |
+
hhsearch_query_offset = original_query_sequence.find(
|
| 717 |
+
hhsearch_query_sequence
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
# Index of -1 used for gap characters. Subtract the min index ignoring gaps.
|
| 721 |
+
min_idx = min(x for x in indices_hit if x > -1)
|
| 722 |
+
fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
|
| 723 |
+
|
| 724 |
+
min_idx = min(x for x in indices_query if x > -1)
|
| 725 |
+
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
|
| 726 |
+
|
| 727 |
+
# Zip the corrected indices, ignore case where both seqs have gap characters.
|
| 728 |
+
mapping = {}
|
| 729 |
+
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
|
| 730 |
+
if q_t != -1 and q_i != -1:
|
| 731 |
+
if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
|
| 732 |
+
original_query_sequence
|
| 733 |
+
):
|
| 734 |
+
continue
|
| 735 |
+
mapping[q_i + hhsearch_query_offset] = q_t
|
| 736 |
+
|
| 737 |
+
return mapping
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
    """Outcome of running a single template hit through the prefilter."""

    # True if the hit passed the prefilter and may be processed further.
    valid: bool
    # Error message when the hit is rejected as an error (strict mode);
    # otherwise None.
    error: Optional[str]
    # Optional non-fatal message about the hit; otherwise None.
    warning: Optional[str]
|
| 745 |
+
|
| 746 |
+
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
    """Outcome of extracting template features from a single HHSearch hit."""

    # Extracted template feature dict, or None when the hit was skipped or
    # feature extraction failed.
    features: Optional[Mapping[str, Any]]
    # Fatal error message (e.g. in strict mode), else None.
    error: Optional[str]
    # Non-fatal issue encountered while extracting features, else None.
    warning: Optional[str]
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
def _prefilter_hit(
    query_sequence: str,
    query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    strict_error_check: bool = False,
) -> PrefilterResult:
    """Checks whether a single HHSearch hit may be used as a template.

    Delegates the actual checks to _assess_hhsearch_hit and converts the
    outcome into a PrefilterResult. In strict mode, date/PDB-ID/duplicate
    failures are reported as errors; all other failures are silently skipped.
    """
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Resolve obsolete PDB IDs to their replacement when no release date is
    # known for the original ID.
    if hit_pdb_code not in release_dates:
        if hit_pdb_code in obsolete_pdbs:
            hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    # Pass hit_pdb_code since it might have changed due to the pdb being
    # obsolete.
    try:
        _assess_hhsearch_hit(
            hit=hit,
            hit_pdb_code=hit_pdb_code,
            query_sequence=query_sequence,
            query_pdb_code=query_pdb_code,
            release_dates=release_dates,
            release_date_cutoff=max_template_date,
        )
    except PrefilterError as e:
        hit_name = f"{hit_pdb_code}_{hit_chain_id}"
        msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
        logging.info("%s: %s", query_pdb_code, msg)
        if strict_error_check and isinstance(
            e, (DateError, PdbIdError, DuplicateError)
        ):
            # In strict mode we treat some prefilter cases as errors.
            return PrefilterResult(valid=False, error=msg, warning=None)

        return PrefilterResult(valid=False, error=None, warning=None)

    return PrefilterResult(valid=True, error=None, warning=None)
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
def _process_single_hit(
    query_sequence: str,
    query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    kalign_binary_path: str,
    strict_error_check: bool = False,
    _zero_center_positions: bool = True,
) -> SingleHitResult:
    """Tries to extract template features from a single HHSearch hit."""
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Resolve obsolete PDB IDs to their replacement when no release date is
    # known for the original ID.
    if hit_pdb_code not in release_dates:
        if hit_pdb_code in obsolete_pdbs:
            hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    # Map each original-query index to the corresponding hit-sequence index.
    mapping = _build_query_to_hit_index_mapping(
        hit.query,
        hit.hit_sequence,
        hit.indices_hit,
        hit.indices_query,
        query_sequence,
    )

    # The mapping is from the query to the actual hit sequence, so we need to
    # remove gaps (which regardless have a missing confidence score).
    template_sequence = hit.hit_sequence.replace("-", "")

    cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
    logging.info(
        "Reading PDB entry from %s. Query: %s, template: %s",
        cif_path,
        query_sequence,
        template_sequence,
    )
    # Fail if we can't find the mmCIF file.
    with open(cif_path, "r") as cif_file:
        cif_string = cif_file.read()

    parsing_result = mmcif_parsing.parse(
        file_id=hit_pdb_code, mmcif_string=cif_string
    )

    # Reject templates released after the cutoff date. NOTE(review): if
    # parsing failed (mmcif_object is None) this check is skipped and
    # _extract_template_features raises NoChainsError, handled below.
    if parsing_result.mmcif_object is not None:
        hit_release_date = datetime.datetime.strptime(
            parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
        )
        if hit_release_date > max_template_date:
            error = "Template %s date (%s) > max template date (%s)." % (
                hit_pdb_code,
                hit_release_date,
                max_template_date,
            )
            if strict_error_check:
                return SingleHitResult(features=None, error=error, warning=None)
            else:
                logging.info(error)
                return SingleHitResult(features=None, error=None, warning=None)

    try:
        features, realign_warning = _extract_template_features(
            mmcif_object=parsing_result.mmcif_object,
            pdb_id=hit_pdb_code,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=hit_chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=_zero_center_positions,
        )
        features["template_sum_probs"] = [hit.sum_probs]

        # It is possible there were some errors when parsing the other chains in the
        # mmCIF file, but the template features for the chain we want were still
        # computed. In such case the mmCIF parsing errors are not relevant.
        return SingleHitResult(
            features=features, error=None, warning=realign_warning
        )
    except (
        NoChainsError,
        NoAtomDataInTemplateError,
        TemplateAtomMaskAllZerosError,
    ) as e:
        # These 3 errors indicate missing mmCIF experimental data rather than a
        # problem with the template search, so turn them into warnings.
        warning = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        if strict_error_check:
            return SingleHitResult(features=None, error=warning, warning=None)
        else:
            return SingleHitResult(features=None, error=None, warning=warning)
    except Error as e:
        # Any other featurization error is always fatal for this hit.
        error = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        return SingleHitResult(features=None, error=error, warning=None)
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    """Aggregated result of featurizing all template hits for one query."""

    # Stacked template features keyed by feature name.
    features: Mapping[str, Any]
    # Error messages collected while prefiltering/processing hits.
    errors: Sequence[str]
    # Warning messages collected while prefiltering/processing hits.
    warnings: Sequence[str]
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
class TemplateHitFeaturizer:
    """A class for turning hhr hits to template features."""

    def __init__(
        self,
        mmcif_dir: str,
        max_template_date: str,
        max_hits: int,
        kalign_binary_path: str,
        release_dates_path: Optional[str] = None,
        obsolete_pdbs_path: Optional[str] = None,
        strict_error_check: bool = False,
        _shuffle_top_k_prefiltered: Optional[int] = None,
        _zero_center_positions: bool = True,
    ):
        """Initializes the Template Search.

        Args:
          mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
            is found by HHSearch, this directory is used to retrieve the template
            data.
          max_template_date: The maximum date permitted for template structures. No
            template with date higher than this date will be returned. In ISO8601
            date format, YYYY-MM-DD.
          max_hits: The maximum number of templates that will be returned.
          kalign_binary_path: The path to a kalign executable used for template
            realignment.
          release_dates_path: An optional path to a file with a mapping from PDB IDs
            to their release dates. Thanks to this we don't have to redundantly
            parse mmCIF files to get that information.
          obsolete_pdbs_path: An optional path to a file containing a mapping from
            obsolete PDB IDs to the PDB IDs of their replacements.
          strict_error_check: If True, then the following will be treated as errors:
            * If any template date is after the max_template_date.
            * If any template has identical PDB ID to the query.
            * If any template is a duplicate of the query.
            * Any feature computation errors.
        """
        self._mmcif_dir = mmcif_dir
        # Fail fast when the template directory holds no structures at all.
        if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
            logging.error("Could not find CIFs in %s", self._mmcif_dir)
            raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")

        try:
            self._max_template_date = datetime.datetime.strptime(
                max_template_date, "%Y-%m-%d"
            )
        except ValueError as e:
            raise ValueError(
                "max_template_date must be set and have format YYYY-MM-DD."
            ) from e
        self.max_hits = max_hits
        self._kalign_binary_path = kalign_binary_path
        self._strict_error_check = strict_error_check

        if release_dates_path:
            logging.info(
                "Using precomputed release dates %s.", release_dates_path
            )
            self._release_dates = _parse_release_dates(release_dates_path)
        else:
            self._release_dates = {}

        if obsolete_pdbs_path:
            logging.info(
                "Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
            )
            self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
        else:
            self._obsolete_pdbs = {}

        self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
        self._zero_center_positions = _zero_center_positions

    def get_templates(
        self,
        query_sequence: str,
        query_pdb_code: Optional[str],
        query_release_date: Optional[datetime.datetime],
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for given query sequence (more details above)."""
        logging.info("Searching for template for: %s", query_pdb_code)

        # One (initially empty) feature list per known template feature name.
        template_features = {name: [] for name in TEMPLATE_FEATURES}

        # Always use a max_template_date. Set to query_release_date minus 60 days
        # if that's earlier.
        template_cutoff_date = self._max_template_date
        if query_release_date:
            delta = datetime.timedelta(days=60)
            if query_release_date - delta < template_cutoff_date:
                template_cutoff_date = query_release_date - delta
            assert template_cutoff_date < query_release_date
            assert template_cutoff_date <= self._max_template_date

        num_hits = 0
        errors = []
        warnings = []

        # Cheap prefilter pass: collect usable hits and any errors/warnings.
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                query_pdb_code=query_pdb_code,
                hit=hit,
                max_template_date=template_cutoff_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        # Process hits best-first (sorted() already returns a new list).
        filtered = sorted(filtered, key=lambda x: x.sum_probs, reverse=True)

        idx = list(range(len(filtered)))
        # Optionally randomize the order of the top-k prefiltered hits.
        if self._shuffle_top_k_prefiltered:
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            # We got all the templates we wanted, stop processing hits.
            if num_hits >= self.max_hits:
                break

            hit = filtered[i]

            result = _process_single_hit(
                query_sequence=query_sequence,
                query_pdb_code=query_pdb_code,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=template_cutoff_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path,
                _zero_center_positions=self._zero_center_positions,
            )

            if result.error:
                errors.append(result.error)

            # There could be an error even if there are some results, e.g. thrown by
            # other unparsable chains in the same mmCIF file.
            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.info(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name,
                    result.error,
                    result.warning,
                )
            else:
                # Increment the hit counter, since we got features out of this hit.
                num_hits += 1
                for k in template_features:
                    template_features[k].append(result.features[k])

        for name in template_features:
            if num_hits > 0:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
            else:
                # Make sure the feature has correct dtype even if empty.
                template_features[name] = np.array(
                    [], dtype=TEMPLATE_FEATURES[name]
                )

        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )
|
openfold/data/tools/__init__.py
ADDED
|
File without changes
|
openfold/data/tools/hhblits.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run HHblits from Python."""
|
| 17 |
+
import glob
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
from typing import Any, Mapping, Optional, Sequence
|
| 22 |
+
|
| 23 |
+
from openfold.data.tools import utils
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_HHBLITS_DEFAULT_P = 20
|
| 27 |
+
_HHBLITS_DEFAULT_Z = 500
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 4,
        n_iter: int = 3,
        e_value: float = 0.001,
        maxseq: int = 1_000_000,
        realign_max: int = 100_000,
        maxfilt: int = 100_000,
        min_prefilter_hits: int = 1000,
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
          binary_path: The path to the HHblits executable.
          databases: A sequence of HHblits database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to give HHblits.
          n_iter: The number of HHblits iterations.
          e_value: The E-value, see HHblits docs for more details.
          maxseq: The maximum number of rows in an input alignment. Note that this
            parameter is only supported in HHBlits version 3.1 and higher.
          realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
          maxfilt: Max number of hits allowed to pass the 2nd prefilter.
            HHblits default: 20000.
          min_prefilter_hits: Min number of hits to pass prefilter.
            HHblits default: 100.
          all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
            HHblits default: False.
          alt: Show up to this many alternative alignments.
          p: Minimum Prob for a hit to be included in the output hhr file.
            HHblits default: 20.
          z: Hard cap on number of hits reported in the hhr file.
            HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

        Raises:
          RuntimeError: If HHblits binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases

        # Fail fast if any database prefix has no matching files on disk.
        # NOTE(review): the docstring advertises RuntimeError, but a missing
        # database actually raises ValueError here.
        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHBlits database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHBlits database {database_path}"
                )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using HHblits.

        Runs the HHblits binary in a temporary directory and returns a dict
        with the resulting a3m alignment plus the raw stdout/stderr and the
        settings used. Raises RuntimeError on a non-zero exit code.
        """
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            # One "-d <path>" pair per configured database.
            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            # Only the a3m output is kept; the regular report goes to /dev/null.
            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-cpu",
                str(self.n_cpu),
                "-oa3m",
                a3m_path,
                "-o",
                "/dev/null",
                "-n",
                str(self.n_iter),
                "-e",
                str(self.e_value),
                "-maxseq",
                str(self.maxseq),
                "-realign_max",
                str(self.realign_max),
                "-maxfilt",
                str(self.maxfilt),
                "-min_prefilter_hits",
                str(self.min_prefilter_hits),
            ]
            if self.all_seqs:
                cmd += ["-all"]
            if self.alt:
                cmd += ["-alt", str(self.alt)]
            # Only pass -p/-Z when they differ from the HHblits defaults.
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ["-p", str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ["-Z", str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("HHblits query"):
                stdout, stderr = process.communicate()
                # communicate() already waits; wait() just fetches the code.
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error("HHblits failed. HHblits stderr begin:")
                for error_line in stderr.decode("utf-8").splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error("HHblits stderr end")
                # NOTE(review): slicing the bytes before decoding can split a
                # multi-byte UTF-8 character and raise UnicodeDecodeError on
                # this error path — confirm whether errors="replace" is wanted.
                raise RuntimeError(
                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
                )

            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )
            return raw_output
|
openfold/data/tools/hhsearch.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run HHsearch from Python."""
|
| 17 |
+
import glob
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
import subprocess
|
| 21 |
+
from typing import Sequence
|
| 22 |
+
|
| 23 |
+
from openfold.data.tools import utils
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be
                the common prefix for the database files (i.e. up to but not
                including _hhm.ffindex etc.)
            n_cpu: The number of CPUs to use.
            maxseq: The maximum number of rows in an input alignment. Note
                that this parameter is only supported in HHBlits version 3.1
                and higher.

        Raises:
            ValueError: If a database prefix matches no files on disk.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        # Each database path is a prefix shared by several sibling files
        # (e.g. <prefix>_hhm.ffindex), hence the glob check.
        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHsearch database {database_path}"
                )

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            # Every database contributes its own "-d <path>" flag pair.
            db_cmd = []
            for db_path in self.databases:
                db_cmd += ["-d", db_path]

            cmd = [
                self.binary_path,
                "-i", input_path,
                "-o", hhr_path,
                "-maxseq", str(self.maxseq),
                "-cpu", str(self.n_cpu),
                *db_cmd,
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
                )

            with open(hhr_path) as f:
                hhr = f.read()
        return hhr
|
openfold/data/tools/jackhmmer.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Library to run Jackhmmer from Python."""
|
| 17 |
+
|
| 18 |
+
from concurrent import futures
|
| 19 |
+
import glob
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import subprocess
|
| 23 |
+
from typing import Any, Callable, Mapping, Optional, Sequence
|
| 24 |
+
from urllib import request
|
| 25 |
+
|
| 26 |
+
from openfold.data.tools import utils
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
          binary_path: The path to the jackhmmer executable.
          database_path: The path to the jackhmmer database (FASTA format).
          n_cpu: The number of CPUs to give Jackhmmer.
          n_iter: The number of Jackhmmer iterations.
          e_value: The E-value, see Jackhmmer docs for more details.
          z_value: The Z-value, see Jackhmmer docs for more details.
          get_tblout: Whether to save tblout string.
          filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
          filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
          filter_f3: Forward pre-filter, set to >1.0 to turn off.
          incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
            round.
          dom_e: Domain e-value criteria for inclusion in tblout.
          num_streamed_chunks: Number of database chunks to stream over.
          streaming_callback: Callback function run after each chunk iteration with
            the iteration number as argument.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # In streaming mode the database chunks are fetched remotely, so a
        # missing local path is only an error when not streaming.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self, input_fasta_path: str, database_path: str
    ) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer.

        Returns:
          A dict with keys ``sto`` (Stockholm alignment text), ``tbl``
          (tblout text, empty unless ``get_tblout``), ``stderr`` (raw bytes),
          ``n_iter`` and ``e_value``.
        """
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expensive of sensitivity. They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                "--F1",
                str(self.filter_f1),
                "--F2",
                str(self.filter_f2),
                "--F3",
                str(self.filter_f3),
                "--incE",
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--cpu",
                str(self.n_cpu),
                "-N",
                str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            # NOTE(review): truthiness check means an explicit z_value of 0 is
            # silently dropped; `is not None` would be stricter — confirm intent.
            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Read the alignment before the tmpdir manager deletes it.
            with open(sto_path) as f:
                sto = f.read()

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )

            return raw_output

    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
        """Queries the database using Jackhmmer.

        Returns one `_query_chunk` result dict per database chunk (a single
        result when not streaming).
        """
        if self.num_streamed_chunks is None:
            return [self._query_chunk(input_fasta_path, self.database_path)]

        db_basename = os.path.basename(self.database_path)
        # Remote chunks are "<database_path>.<i>"; local copies land on the
        # ramdisk so repeated reads are cheap.
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_output = []
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    # Prefetch the next chunk in the background.
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                # Block until chunk i has finished downloading.
                future.result()
                chunked_output.append(
                    self._query_chunk(input_fasta_path, db_local_chunk(i))
                )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # On the last iteration this re-binds the stale next_future
                # from the previous pass; the loop exits before using it.
                future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_output
|
openfold/data/tools/kalign.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""A Python wrapper for Kalign."""
|
| 17 |
+
import os
|
| 18 |
+
import subprocess
|
| 19 |
+
from typing import Sequence
|
| 20 |
+
|
| 21 |
+
from absl import logging
|
| 22 |
+
|
| 23 |
+
from openfold.data.tools import utils
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _to_a3m(sequences: Sequence[str]) -> str:
|
| 27 |
+
"""Converts sequences to an a3m file."""
|
| 28 |
+
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
|
| 29 |
+
a3m = []
|
| 30 |
+
for sequence, name in zip(sequences, names):
|
| 31 |
+
a3m.append(u">" + name + u"\n")
|
| 32 |
+
a3m.append(sequence + u"\n")
|
| 33 |
+
return "".join(a3m)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary.

        Raises:
            RuntimeError: If Kalign binary not found within the path.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have to
                be at least 6 residues long (Kalign requires this). Note that
                the order in which you give the sequences might alter the
                output slightly as different alignment tree might get
                constructed.

        Returns:
            A string with the alignment in a3m format.

        Raises:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Fail fast before touching the filesystem or spawning a process.
        for seq in sequences:
            if len(seq) >= 6:
                continue
            raise ValueError(
                "Kalign requires all sequences to be at least 6 "
                "residues long. Got %s (%d residues)." % (seq, len(seq))
            )

        with utils.tmpdir_manager(base_dir="/tmp") as tmp_dir:
            fasta_in = os.path.join(tmp_dir, "input.fasta")
            a3m_out = os.path.join(tmp_dir, "output.a3m")

            with open(fasta_in, "w") as fasta_file:
                fasta_file.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i", fasta_in,
                "-o", a3m_out,
                "-format", "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                out_bytes, err_bytes = proc.communicate()
                retcode = proc.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    out_bytes.decode("utf-8"),
                    err_bytes.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (out_bytes.decode("utf-8"), err_bytes.decode("utf-8"))
                )

            with open(a3m_out) as a3m_file:
                return a3m_file.read()
|
openfold/data/tools/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Common utilities for data pipeline tools."""
|
| 17 |
+
import contextlib
|
| 18 |
+
import datetime
|
| 19 |
+
import logging
|
| 20 |
+
import shutil
|
| 21 |
+
import tempfile
|
| 22 |
+
import time
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit.

    Args:
        base_dir: Parent directory for the temp dir; system default if None.

    Yields:
        Path of a freshly created temporary directory, removed (best effort)
        when the context exits.
    """
    path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield path
    finally:
        # ignore_errors so a cleanup failure never masks the real exception.
        shutil.rmtree(path, ignore_errors=True)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@contextlib.contextmanager
def timing(msg: str):
    """Context manager that logs the wall-clock duration of its body.

    Note: intentionally no try/finally — on an exception the "Finished"
    message is not logged, matching the original behavior.
    """
    logging.info("Started %s", msg)
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    logging.info("Finished %s in %.3f seconds", msg, elapsed)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def to_date(s: str):
    """Parses a YYYY-MM-DD prefix into a midnight datetime.

    Only fixed character positions are read (0-3 year, 5-6 month, 8-9 day),
    so any trailing text after the date portion is ignored.
    """
    year, month, day = int(s[:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
|
openfold/np/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dynamically import every sibling module in this package and re-export it,
# so e.g. `openfold.np.protein` is reachable as an attribute of `openfold.np`
# without listing each module by hand.
import os
import glob
import importlib as importlib

# Absolute paths of all .py files next to this __init__.py.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Module names (file names minus ".py"), excluding this __init__ itself.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each module relative to this package and bind it into globals().
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
# NOTE(review): `del _m` raises NameError if the package directory contains
# no modules (the loop body never runs) — harmless in practice here.
del _files, _m, _modules
|
openfold/np/protein.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Protein data type."""
|
| 17 |
+
import dataclasses
|
| 18 |
+
import io
|
| 19 |
+
from typing import Any, Sequence, Mapping, Optional
|
| 20 |
+
import re
|
| 21 |
+
import string
|
| 22 |
+
|
| 23 |
+
from openfold.np import residue_constants
|
| 24 |
+
from Bio import PDB
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
FeatureDict = Mapping[str, np.ndarray]
|
| 29 |
+
ModelOutput = Mapping[str, Any] # Is a nested dict.
|
| 30 |
+
PICO_TO_ANGSTROM = 0.01
|
| 31 |
+
|
| 32 |
+
@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation.

    Immutable container of per-residue atom coordinates, identities, masks
    and B-factors, in the fixed atom ordering of
    ``residue_constants.atom_types``.
    """

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X'.
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain indices for multi-chain predictions
    chain_index: Optional[np.ndarray] = None

    # Optional remark about the protein. Included as a comment in output PDB
    # files
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only)
    parents: Optional[Sequence[str]] = None

    # Chain corresponding to each parent
    parents_chain_index: Optional[Sequence[int]] = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If None, then the pdb file must contain a single chain (which
        will be parsed). If chain_id is specified (e.g. A), then only that chain
        is parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.
    """
    # NOTE(review): despite the `_fh` name, this passes the raw string
    # through. Bio.PDB's PDBParser.get_structure treats a plain string as a
    # *file path*, so callers presumably pass a path here — the docstring
    # ("contents of the pdb file") and upstream OpenFold (which wraps in
    # io.StringIO) disagree. Confirm the caller contract before changing.
    pdb_fh = pdb_str
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    # Per-residue accumulators, filled in chain order.
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if(chain_id is not None and chain.id != chain_id):
            continue
        for res in chain:
            # res.id is (hetfield, resseq, icode); icode != " " means an
            # insertion code, which this representation cannot express.
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            # Unknown residue names fall back to 'X' / restype_num (UNK).
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                # Atoms outside the canonical atom_types set are dropped.
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[
                    residue_constants.atom_order[atom.name]
                ] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Parse optional PARENT remark lines (templates used for a prediction).
    parents = None
    parents_chain_index = None
    if("PARENT" in pdb_str):
        parents = []
        parents_chain_index = []
        # NOTE(review): the `chain_id` parameter is deliberately re-bound here
        # as an integer counter — one PARENT line per chain is assumed.
        chain_id = 0
        for l in pdb_str.split("\n"):
            if("PARENT" in l):
                if(not "N/A" in l):
                    parent_names = l.split()[1:]
                    parents.extend(parent_names)
                    parents_chain_index.extend([
                        chain_id for _ in parent_names
                    ])
                chain_id += 1

    # NOTE(review): `unique_chain_ids` is computed but unused, and the mapping
    # below covers only A-Z — chain ids outside ascii_uppercase (lowercase or
    # digits) would raise KeyError. Upstream OpenFold enumerates
    # unique_chain_ids instead; confirm which behavior is intended.
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    """Parses a ProteinNet-format record into a `Protein`.

    Only the [PRIMARY] (sequence), [TERTIARY] (N/CA/C coordinates) and
    [MASK] sections are consumed; other tags are ignored.

    Args:
        proteinnet_str: The contents of a single ProteinNet record.

    Returns:
        A `Protein` with backbone-only atom positions/mask and no B-factors.
    """
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    # Pair each "[TAG]" with the lines of its section body.
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if("[PRIMARY]" == g[0]):
            seq = g[1][0].strip()
            # BUG FIX: the original wrote `seq[i] = 'X'`, but Python strings
            # are immutable, so any non-standard residue raised a TypeError.
            # Rebuild the sequence instead, mapping unknowns to 'X'.
            seq = "".join(
                res if res in residue_constants.restypes else "X"
                for res in seq
            )
            aatype = np.array([
                residue_constants.restype_order.get(
                    res_symbol, residue_constants.restype_num
                ) for res_symbol in seq
            ])
        elif("[TERTIARY]" == g[0]):
            # Three lines of whitespace-separated floats: one per axis.
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0])//3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            # Coordinates are interleaved N, CA, C per residue.
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, i::3])
                )
            # ProteinNet stores picometers; convert to Angstroms.
            atom_positions *= PICO_TO_ANGSTROM
        elif("[MASK]" == g[0]):
            # '+' marks resolved residues, '-' unresolved ones.
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_pdb_headers(prot: Protein, chain_id: int = 0) -> Sequence[str]:
    """Builds the REMARK/PARENT header lines for one chain of a protein.

    Args:
        prot: The protein whose metadata is rendered.
        chain_id: Which chain's parents to select when per-chain parent
            indices are available.

    Returns:
        A list of PDB header lines (always includes a PARENT line).
    """
    headers = []

    if prot.remark is not None:
        headers.append(f"REMARK {prot.remark}")

    chain_parents = prot.parents
    if prot.parents_chain_index is not None:
        # Keep only the parents assigned to the requested chain.
        chain_parents = [
            parent
            for idx, parent in zip(prot.parents_chain_index, chain_parents)
            if idx == chain_id
        ]

    if not chain_parents:
        chain_parents = ["N/A"]
    headers.append(f"PARENT {' '.join(chain_parents)}")

    return headers
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
    """ Add pdb headers to an existing PDB string. Useful during multi-chain
        recycling

        Strips any pre-existing PARENT/REMARK lines from `pdb_str` and
        re-emits them from `prot`'s metadata, inserting a fresh PARENT line
        after each chain's TER record.
    """
    out_pdb_lines = []
    lines = pdb_str.split('\n')

    remark = prot.remark
    if(remark is not None):
        out_pdb_lines.append(f"REMARK {remark}")

    # parents_per_chain[c] is the list of parent names for chain c.
    parents_per_chain = None
    if(prot.parents is not None and len(prot.parents) > 0):
        parents_per_chain = []
        if(prot.parents_chain_index is not None):
            cur_chain = prot.parents_chain_index[0]
            # Group parents by their chain index (keys are stringified ints).
            parent_dict = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            # Fill every chain slot up to the largest index; chains with no
            # parents get the "N/A" placeholder.
            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            # No per-chain assignment: all parents belong to the first chain.
            parents_per_chain.append(prot.parents)
    else:
        parents_per_chain = [["N/A"]]

    make_parent_line = lambda p: f"PARENT {' '.join(p)}"

    # PARENT line for the first chain goes before any atom records.
    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))

    chain_counter = 0
    for i, l in enumerate(lines):
        # Drop stale PARENT/REMARK lines; keep everything else verbatim.
        if("PARENT" not in l and "REMARK" not in l):
            out_pdb_lines.append(l)
        # A TER not followed by END means another chain starts next:
        # emit that chain's PARENT line.
        # NOTE(review): `lines[i + 1]` assumes TER is never the final line
        # of `pdb_str` — holds for `to_pdb` output, which always ends with
        # END; verify for other producers.
        if("TER" in l and not "END" in lines[i + 1]):
            chain_counter += 1
            if(not chain_counter >= len(parents_per_chain)):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return '\n'.join(out_pdb_lines)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string.
    """
    restypes = residue_constants.restypes + ["X"]
    # Map a restype index to its 3-letter code; the catch-all "X" maps to UNK.
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    # REMARK/PARENT headers for the first chain.
    headers = get_pdb_headers(prot)
    if(len(headers) > 0):
        pdb_lines.extend(headers)

    n = aatype.shape[0]
    atom_index = 1
    prev_chain_index = 0
    chain_tags = string.ascii_uppercase
    # Add all atom sites.
    for i in range(n):
        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                # Atom absent: still refresh chain_tag so a TER record
                # emitted for this residue uses the right chain letter.
                chain_tag = "A"
                if(chain_index is not None):
                    chain_tag = chain_tags[chain_index[i]]
                continue

            record_type = "ATOM"
            # 4-char atom names start at column 13; shorter ones at 14.
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[
                0
            ]  # Protein supports only C, N, O, S, this works.
            charge = ""

            chain_tag = "A"
            if(chain_index is not None):
                chain_tag = chain_tags[chain_index[i]]

            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        # Terminate on the last residue, or whenever the next residue
        # belongs to a different chain.
        should_terminate = (i == n - 1)
        if(chain_index is not None):
            if(i != n - 1 and chain_index[i + 1] != prev_chain_index):
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if(should_terminate):
            # Close the chain.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      "
                f"{res_1to3(aatype[i]):>3} "
                f"{chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            if(i != n - 1):
                # "prev" is a misnomer here. This happens at the beginning of
                # each new chain.
                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      An ideal atom mask.
    """
    # Per-residue row lookup into the precomputed standard atom-mask table.
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    remark: Optional[str] = None,
    parents: Optional[Sequence[str]] = None,
    parents_chain_index: Optional[Sequence[int]] = None
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
      b_factors: (Optional) B-factors to use for the protein.
      chain_index: (Optional) Chain indices for multi-chain predictions
      remark: (Optional) Remark about the prediction
      parents: (Optional) List of template names
    Returns:
      A protein instance.
    """
    # Default to all-zero B-factors shaped like the predicted atom mask.
    final_b_factors = (
        np.zeros_like(result["final_atom_mask"])
        if b_factors is None
        else b_factors
    )

    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        residue_index=features["residue_index"] + 1,  # PDB residues are 1-based
        b_factors=final_b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
|
openfold/np/relax/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dynamically import every sibling module in this package and bind it as a
# package attribute, so e.g. `openfold.np.relax.relax` works without listing
# each module explicitly.
import os
import glob
import importlib as importlib

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
# Module names (``.py`` suffix stripped), excluding this initializer.
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
# NOTE(review): `del _m` would raise NameError if the package had no sibling
# modules (empty loop) — cannot happen in this repo, but worth knowing.
del _files, _m, _modules
|
openfold/np/relax/amber_minimize.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Restrained Amber Minimization of a structure."""
|
| 17 |
+
|
| 18 |
+
import io
|
| 19 |
+
import time
|
| 20 |
+
from typing import Collection, Optional, Sequence
|
| 21 |
+
|
| 22 |
+
from absl import logging
|
| 23 |
+
from openfold.np import (
|
| 24 |
+
protein,
|
| 25 |
+
residue_constants,
|
| 26 |
+
)
|
| 27 |
+
import openfold.utils.loss as loss
|
| 28 |
+
from openfold.np.relax import cleanup, utils
|
| 29 |
+
import ml_collections
|
| 30 |
+
import numpy as np
|
| 31 |
+
from simtk import openmm
|
| 32 |
+
from simtk import unit
|
| 33 |
+
from simtk.openmm import app as openmm_app
|
| 34 |
+
from simtk.openmm.app.internal.pdbstructure import PdbStructure
|
| 35 |
+
|
| 36 |
+
ENERGY = unit.kilocalories_per_mole
|
| 37 |
+
LENGTH = unit.angstroms
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
    """Returns True if the atom will be restrained by the given restraint set.

    Args:
        atom: An OpenMM topology atom.
        rset: Restraint set name — "non_hydrogen" restrains every heavy
            atom, "c_alpha" restrains only alpha carbons.

    Returns:
        Whether the atom belongs to the restraint set.

    Raises:
        ValueError: If `rset` is not a recognized restraint set.
    """
    if rset == "non_hydrogen":
        return atom.element.name != "hydrogen"
    elif rset == "c_alpha":
        return atom.name == "CA"
    else:
        # Previously an unknown rset silently fell through and returned None
        # (falsy), masking typos in the restraint-set name. Fail loudly.
        raise ValueError(f"Unknown restraint set: {rset}")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int],
):
    """Adds a harmonic potential that restrains the system to a structure.

    Mutates `system` in place by appending a CustomExternalForce whose
    per-particle minimum is the particle's position in `reference_pdb`.

    Args:
        system: OpenMM system to modify.
        reference_pdb: Structure providing the restraint center positions.
        stiffness: Spring constant `k` of the harmonic well.
        rset: Which atoms to restrain ("non_hydrogen" or "c_alpha").
        exclude_residues: Residue indices exempted from restraints.
    """
    assert rset in ["non_hydrogen", "c_alpha"]

    # Isotropic harmonic well: 0.5 * k * |r - r0|^2 per restrained particle.
    force = openmm.CustomExternalForce(
        "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
    )
    force.addGlobalParameter("k", stiffness)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)

    for i, atom in enumerate(reference_pdb.topology.atoms()):
        if atom.residue.index in exclude_residues:
            continue
        if will_restrain(atom, rset):
            # Reference position doubles as the (x0, y0, z0) parameters.
            force.addParticle(i, reference_pdb.positions[i])
    logging.info(
        "Restraining %d / %d particles.",
        force.getNumParticles(),
        system.getNumParticles(),
    )
    system.addForce(force)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool,
):
    """Minimize energy via openmm.

    Args:
        pdb_str: The structure to minimize, as a PDB-format string.
        max_iterations: L-BFGS iteration cap (0 = unlimited).
        tolerance: Energy tolerance for convergence.
        stiffness: Restraint spring constant; restraints are only added
            when this is strictly positive.
        restraint_set: Atom selection for restraints (see `will_restrain`).
        exclude_residues: Residue indices exempt from restraints.
        use_gpu: Run on the CUDA platform instead of CPU.

    Returns:
        Dict with initial/final energies ("einit"/"efinal"), initial/final
        positions ("posinit"/"pos"), and the minimized PDB ("min_pdb").
    """
    pdb_file = io.StringIO(pdb_str)
    pdb = openmm_app.PDBFile(pdb_file)

    force_field = openmm_app.ForceField("amber99sb.xml")
    # Constrain bonds involving hydrogens during minimization.
    constraints = openmm_app.HBonds
    system = force_field.createSystem(pdb.topology, constraints=constraints)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

    # Integrator is required by the Simulation API but unused: only
    # minimizeEnergy is called, no dynamics are run.
    integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
    platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
    simulation = openmm_app.Simulation(
        pdb.topology, system, integrator, platform
    )
    simulation.context.setPositions(pdb.positions)

    ret = {}
    # Snapshot energy/positions before minimization for reporting.
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
    return ret
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
    """Returns a pdb string provided OpenMM topology and positions."""
    buffer = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, positions, buffer)
        return buffer.getvalue()
    finally:
        buffer.close()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
    """Checks that no atom positions have been altered by cleaning.

    Args:
        pdb_cleaned_string: PDB string after cleanup (may contain atoms
            that were added by cleaning).
        pdb_ref_string: Original PDB string before cleanup.

    Raises:
        ValueError: If any atom present in both structures moved.
    """
    cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
    reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

    cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
    ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

    # Residues are assumed to be in the same order in both structures;
    # atoms within a residue are matched by name (cleaning may insert
    # atoms, so indices can differ).
    for ref_res, cl_res in zip(
        reference.topology.residues(), cleaned.topology.residues()
    ):
        assert ref_res.name == cl_res.name
        for rat in ref_res.atoms():
            for cat in cl_res.atoms():
                if cat.name == rat.name:
                    if not np.array_equal(
                        cl_xyz[cat.index], ref_xyz[rat.index]
                    ):
                        raise ValueError(
                            f"Coordinates of cleaned atom {cat} do not match "
                            f"coordinates of reference atom {rat}."
                        )
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _check_residues_are_well_defined(prot: protein.Protein):
    """Checks that all residues contain non-empty atom sets."""
    # A residue with a zero atom-mask row has no resolved atoms at all.
    atoms_per_residue = prot.atom_mask.sum(axis=-1)
    if np.any(atoms_per_residue == 0):
        raise ValueError(
            "Amber minimization can only be performed on proteins with"
            " well-defined residues. This protein contains at least"
            " one residue with no atoms."
        )
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _check_atom_mask_is_ideal(prot):
    """Sanity-check the atom mask is ideal, up to a possible OXT.

    Delegates to `utils.assert_equal_nonterminal_atom_types`, which raises
    if the observed mask disagrees with the sequence-derived ideal mask.
    """
    atom_mask = prot.atom_mask
    ideal_atom_mask = protein.ideal_atom_mask(prot)
    utils.assert_equal_nonterminal_atom_types(atom_mask, ideal_atom_mask)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def clean_protein(prot: protein.Protein, checks: bool = True):
    """Adds missing atoms to Protein instance.

    Args:
      prot: A `protein.Protein` instance.
      checks: A `bool` specifying whether to add additional checks to the cleaning
        process.

    Returns:
      pdb_string: A string of the cleaned protein.
    """
    # Refuse inputs whose mask already deviates from the ideal atom set.
    _check_atom_mask_is_ideal(prot)

    # Clean pdb.
    prot_pdb_string = protein.to_pdb(prot)
    pdb_file = io.StringIO(prot_pdb_string)
    # fix_pdb / clean_structure record what they changed here, for logging.
    alterations_info = {}
    fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
    fixed_pdb_file = io.StringIO(fixed_pdb)
    pdb_structure = PdbStructure(fixed_pdb_file)
    cleanup.clean_structure(pdb_structure, alterations_info)

    logging.info("alterations info: %s", alterations_info)

    # Write pdb file of cleaned structure.
    as_file = openmm_app.PDBFile(pdb_structure)
    pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
    if checks:
        # Cleaning must not move any pre-existing atom.
        _check_cleaned_atoms(pdb_string, prot_pdb_string)

    # Re-attach REMARK/PARENT headers, which the round-trip drops.
    headers = protein.get_pdb_headers(prot)
    if(len(headers) > 0):
        pdb_string = '\n'.join(['\n'.join(headers), pdb_string])

    return pdb_string
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def make_atom14_positions(prot):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Mutates and returns the feature dict `prot`, adding atom14 masks,
    ground-truth positions, atom14<->atom37 index maps, and alternative
    (atom-name-swapped) ground truth for ambiguous residues.

    Requires keys: "aatype", "all_atom_mask", "all_atom_positions".
    """
    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

    # Build per-restype lookup tables from the canonical atom14 name lists.
    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]
        ]

        restype_atom14_to_atom37.append(
            [
                (residue_constants.atom_order[name] if name else 0)
                for name in atom_names
            ]
        )

        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in residue_constants.atom_types
            ]
        )

        # Empty atom14 slots (padding names) get mask 0.
        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = np.array(
        restype_atom14_to_atom37, dtype=np.int32
    )
    restype_atom37_to_atom14 = np.array(
        restype_atom37_to_atom14, dtype=np.int32
    )
    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

    # Create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein.
    residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
    residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
    ).astype(np.float32)

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
        np.take_along_axis(
            prot["all_atom_positions"],
            residx_atom14_to_atom37[..., None],
            axis=1,
        )
    )

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions

    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)

    # Create the gather indices for mapping back.
    residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
    prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)

    # Create the corresponding mask.
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        atom_names = residue_constants.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = residue_constants.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
    prot["atom37_atom_exists"] = residx_atom37_mask

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [
        residue_constants.restype_1to3[res]
        for res in residue_constants.restypes
    ]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    # Identity for unambiguous residues; a permutation swapping the
    # ambiguous atom pair(s) for the 7 affected residue types.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(source_atom_swap)
            target_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix.astype(np.float32)
    renaming_matrices = np.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[prot["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
    )
    prot["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
    )

    prot["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = residue_constants.restype_order[
                residue_constants.restype_3to1[resname]
            ]
            atom_idx1 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name1)
            atom_idx2 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        prot["aatype"]
    ]

    return prot
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def find_violations(prot_np: protein.Protein):
    """Analyzes a protein and returns structural violation information.

    Args:
      prot_np: A protein.

    Returns:
      violations: A `dict` of structure components with structural violations.
      violation_metrics: A `dict` of violation metrics.
    """
    # Repackage the Protein as the feature dict the loss functions expect.
    batch = {
        "aatype": prot_np.aatype,
        "all_atom_positions": prot_np.atom_positions.astype(np.float32),
        "all_atom_mask": prot_np.atom_mask.astype(np.float32),
        "residue_index": prot_np.residue_index,
    }

    batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
    # Adds the atom14 representation the violation code operates on.
    batch = make_atom14_positions(batch)

    violations = loss.find_structural_violations_np(
        batch=batch,
        # Score the ground-truth atoms themselves (no prediction involved).
        atom14_pred_positions=batch["atom14_gt_positions"],
        config=ml_collections.ConfigDict(
            {
                "violation_tolerance_factor": 12,  # Taken from model config.
                "clash_overlap_tolerance": 1.5,  # Taken from model config.
            }
        ),
    )
    violation_metrics = loss.compute_violation_metrics_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        violations=violations,
    )

    return violations, violation_metrics
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def get_violation_metrics(prot: protein.Protein):
    """Computes violation and alignment metrics."""
    violations, metrics = find_violations(prot)

    # Indices of residues participating in at least one violation.
    offending = np.flatnonzero(violations["total_per_residue_violations_mask"])

    metrics["residue_violations"] = offending
    metrics["num_residue_violations"] = len(offending)
    metrics["structural_violations"] = violations
    return metrics
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
    use_gpu: bool,
):
    """Runs one restrained-minimization iteration, retrying on failure.

    Args:
      pdb_string: A pdb string.
      max_iterations: An `int` specifying the maximum number of L-BFGS
        iterations. A value of 0 specifies no limit.
      tolerance: kcal/mol, the energy tolerance of L-BFGS.
      stiffness: kcal/mol A**2, spring constant of heavy atom restraining
        potential.
      restraint_set: The set of atoms to restrain.
      max_attempts: The maximum number of minimization attempts.
      exclude_residues: An optional list of zero-indexed residues to exclude
        from restraints.
      use_gpu: Whether to run relaxation on GPU.

    Returns:
      A `dict` of minimization info.

    Raises:
      ValueError: If every attempt fails.
    """
    exclude_residues = exclude_residues or []

    # Attach physical units to the scalar settings.
    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH ** 2)

    start = time.perf_counter()
    ret = None
    attempts = 0
    for attempt in range(1, max_attempts + 1):
        attempts = attempt
        try:
            logging.info(
                "Minimizing protein, attempt %d of %d.", attempt, max_attempts
            )
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
                use_gpu=use_gpu,
            )
        except Exception as e:  # pylint: disable=broad-except
            # Minimization can fail transiently (e.g. NaN energies); log and
            # retry up to max_attempts.
            print(e)
            logging.info(e)
        else:
            break
    if ret is None:
        raise ValueError(f"Minimization failed after {max_attempts} attempts.")
    ret["opt_time"] = time.perf_counter() - start
    ret["min_attempts"] = attempts
    return ret
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Run iterative amber relax.

    Successive relax iterations are performed until all violations have been
    resolved. Each iteration involves a restrained Amber minimization, with
    restraint exclusions determined by violation-participating residues.

    Args:
      prot: A protein to be relaxed.
      stiffness: kcal/mol A**2, the restraint stiffness.
      use_gpu: Whether to run on GPU
      max_outer_iterations: The maximum number of iterative minimization.
      place_hydrogens_every_iteration: Whether hydrogens are re-initialized
        prior to every minimization.
      max_iterations: An `int` specifying the maximum number of L-BFGS steps
        per relax iteration. A value of 0 specifies no limit.
      tolerance: kcal/mol, the energy tolerance of L-BFGS.
        The default value is the OpenMM default.
      restraint_set: The set of atoms to restrain.
      max_attempts: The maximum number of minimization attempts per iteration.
      checks: Whether to perform cleaning checks.
      exclude_residues: An optional list of zero-indexed residues to exclude from
        restraints.

    Returns:
      out: A dictionary of output values.
    """

    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    # We keep the input around to restore metadata deleted by the relaxer
    input_prot = prot

    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    # Start with "infinite" violations so the loop runs at least once.
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=exclude_residues,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )

        # Re-attach any PDB headers (minimization output drops them) so that
        # from_pdb_string below sees the original metadata.
        headers = protein.get_pdb_headers(prot)
        if(len(headers) > 0):
            ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])

        # Re-parse the minimized structure; it becomes the input for the next
        # outer iteration.
        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            # Re-protonate from scratch before the next minimization.
            pdb_string = clean_protein(prot, checks=True)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update(
            {
                "num_exclusions": len(exclude_residues),
                "iteration": iteration,
            }
        )
        violations = ret["violations_per_residue"]
        # Residues that violated this round are excluded from restraints next
        # round, letting the minimizer move them freely.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Returns initial potential energies for a sequence of PDBs.

    Assumes the input PDBs are ready for minimization, and all have the same
    topology, which lets us build a single system/simulation and only swap
    positions between evaluations.

    Args:
      pdb_strs: List of PDB strings.
      stiffness: kcal/mol A**2, spring constant of heavy atom restraining
        potential.
      restraint_set: Which atom types to restrain.
      exclude_residues: An optional list of zero-indexed residues to exclude
        from restraints.

    Returns:
      A list of initial energies in the same order as pdb_strs.
    """
    exclude_residues = exclude_residues or []

    structures = [
        openmm_app.PDBFile(PdbStructure(io.StringIO(s))) for s in pdb_strs
    ]
    reference = structures[0]

    # Build one system from the shared topology.
    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        reference.topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH ** 2)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(
            system, reference, stiffness, restraint_set, exclude_residues
        )

    # The integrator is never stepped; we only query single-point energies.
    simulation = openmm_app.Simulation(
        reference.topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )

    energies = []
    for structure in structures:
        try:
            simulation.context.setPositions(structure.positions)
            state = simulation.context.getState(getEnergy=True)
            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
        except Exception as e:  # pylint: disable=broad-except
            # Keep list alignment with pdb_strs by recording a sentinel value.
            logging.error(
                "Error getting initial energy, returning large value %s", e
            )
            energies.append(unit.Quantity(1e20, ENERGY))
    return energies
|
openfold/np/relax/cleanup.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
|
| 16 |
+
|
| 17 |
+
fix_pdb uses a third-party tool. We also support fixing some additional edge
|
| 18 |
+
cases like removing chains of length one (see clean_structure).
|
| 19 |
+
"""
|
| 20 |
+
import io
|
| 21 |
+
|
| 22 |
+
import pdbfixer
|
| 23 |
+
from simtk.openmm import app
|
| 24 |
+
from simtk.openmm.app import element
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def fix_pdb(pdbfile, alterations_info):
    """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

    1) Replaces nonstandard residues.
    2) Removes heterogens (non protein residues) including water.
    3) Adds missing residues and missing atoms within existing residues.
    4) Adds hydrogens assuming pH=7.0.
    5) KeepIds is currently true, so the fixer must keep the existing chain and
       residue identifiers. This will fail for some files in wider PDB that
       have invalid IDs.

    Args:
      pdbfile: Input PDB file handle.
      alterations_info: A dict that will store details of changes made.

    Returns:
      A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)

    # Swap nonstandard residues for their standard equivalents.
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()

    # Drop heterogens, including water.
    _remove_heterogens(fixer, alterations_info, keep_water=False)

    # Rebuild missing residues/atoms, then protonate (pdbfixer default pH 7).
    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    fixer.addMissingAtoms(seed=0)
    fixer.addMissingHydrogens()

    # Serialize the repaired structure back to a PDB string.
    output = io.StringIO()
    app.PDBFile.writeFile(
        fixer.topology, fixer.positions, output, keepIds=True
    )
    return output.getvalue()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def clean_structure(pdb_structure, alterations_info):
    """Applies additional fixes to an OpenMM structure, to handle edge cases.

    Args:
      pdb_structure: An OpenMM structure to modify and fix.
      alterations_info: A dict that will store details of changes made.
    """
    # Selenomethionine: swap unmodified Se back to S in MET residues.
    _replace_met_se(pdb_structure, alterations_info)
    # Single-residue chains have no Amber force template; drop them.
    _remove_chains_of_length_one(pdb_structure, alterations_info)
|
| 74 |
+
|
| 75 |
+
def _remove_heterogens(fixer, alterations_info, keep_water):
    """Removes the residues that Pdbfixer considers to be heterogens.

    Args:
      fixer: A Pdbfixer instance.
      alterations_info: A dict that will store details of changes made.
      keep_water: If True, water (HOH) is not considered to be a heterogen.
    """
    # Snapshot residue names before and after removal so we can report the
    # set of names that disappeared.
    before = {
        residue.name
        for chain in fixer.topology.chains()
        for residue in chain.residues()
    }
    fixer.removeHeterogens(keepWater=keep_water)
    after = {
        residue.name
        for chain in fixer.topology.chains()
        for residue in chain.residues()
    }
    alterations_info["removed_heterogens"] = before - after
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _replace_met_se(pdb_structure, alterations_info):
    """Replace the Se in any MET residues that were not marked as modified."""
    modified = []
    for residue in pdb_structure.iter_residues():
        # Only methionine residues can carry a stray selenium on SD.
        if residue.get_name_with_spaces().strip() != "MET":
            continue
        sd_atom = residue.get_atom("SD")
        if sd_atom.element_symbol == "Se":
            sd_atom.element_symbol = "S"
            sd_atom.element = element.get_by_symbol("S")
            modified.append(sd_atom.residue_number)
    alterations_info["Se_in_MET"] = modified
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _remove_chains_of_length_one(pdb_structure, alterations_info):
    """Removes chains that correspond to a single amino acid.

    A single amino acid in a chain is both N and C terminus. There is no force
    template for this case.

    Args:
      pdb_structure: An OpenMM pdb_structure to modify and fix.
      alterations_info: A dict that will store details of changes made.
    """
    removed_chains = {}
    for model in pdb_structure.iter_models():
        # Partition chains in one pass: keep multi-residue chains, record
        # the ids of single-residue ones for removal.
        kept = []
        dropped_ids = []
        for chain in model.iter_chains():
            if len(chain) > 1:
                kept.append(chain)
            else:
                dropped_ids.append(chain.chain_id)
        model.chains = kept
        for chain_id in dropped_ids:
            model.chains_by_id.pop(chain_id)
        removed_chains[model.number] = dropped_ids
    alterations_info["removed_chains"] = removed_chains
|
openfold/np/relax/relax.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Amber relaxation."""
|
| 17 |
+
from typing import Any, Dict, Sequence, Tuple
|
| 18 |
+
from openfold.np import protein
|
| 19 |
+
from openfold.np.relax import amber_minimize, utils
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AmberRelaxation(object):
    """Runs restrained Amber minimization on a predicted structure."""

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Initialize Amber Relaxer.

        Args:
          max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
          tolerance: kcal/mol, the energy tolerance of L-BFGS.
          stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
          exclude_residues: Residues to exclude from per-atom restraining.
            Zero-indexed.
          max_outer_iterations: Maximum number of violation-informed relax
            iterations. A value of 1 will run the non-iterative procedure used
            in CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax
            finishes as soon as there are no violations, hence in most cases
            this causes no slowdown. In the worst case we do 20 outer
            iterations.
          use_gpu: Whether to run on GPU
        """
        self._max_iterations = max_iterations
        self._tolerance = tolerance
        self._stiffness = stiffness
        self._exclude_residues = exclude_residues
        self._max_outer_iterations = max_outer_iterations
        self._use_gpu = use_gpu

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
        out = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
            use_gpu=self._use_gpu,
        )

        # RMSD between the input and the minimized coordinates, for debugging.
        minimized_pos = out["pos"]
        initial_pos = out["posinit"]
        displacement_sq = np.sum((initial_pos - minimized_pos) ** 2)
        rmsd = np.sqrt(displacement_sq / initial_pos.shape[0])
        debug_data = {
            "initial_energy": out["einit"],
            "final_energy": out["efinal"],
            "attempts": out["min_attempts"],
            "rmsd": rmsd,
        }

        # Re-emit the cleaned structure with minimized coordinates and the
        # original per-residue B-factors, then sanity-check the atom set.
        pdb_str = amber_minimize.clean_protein(prot)
        min_pdb = utils.overwrite_pdb_coordinates(pdb_str, minimized_pos)
        min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
        )
        violations = out["structural_violations"][
            "total_per_residue_violations_mask"
        ]

        min_pdb = protein.add_pdb_headers(prot, min_pdb)

        return min_pdb, debug_data, violations
|
openfold/np/relax/utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Utils for minimization."""
|
| 17 |
+
import io
|
| 18 |
+
from openfold.np import residue_constants
|
| 19 |
+
from Bio import PDB
|
| 20 |
+
import numpy as np
|
| 21 |
+
from simtk.openmm import app as openmm_app
|
| 22 |
+
from simtk.openmm.app.internal.pdbstructure import PdbStructure
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
    """Re-emit `pdb_str` with its atom coordinates replaced by `pos`."""
    # Parse the original structure only to recover its topology; the
    # coordinates written out come entirely from `pos`.
    structure = PdbStructure(io.StringIO(pdb_str))
    topology = openmm_app.PDBFile(structure).getTopology()
    buffer = io.StringIO()
    openmm_app.PDBFile.writeFile(topology, pos, buffer)
    return buffer.getvalue()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
    """Overwrites the B-factors in pdb_str with contents of bfactors array.

    Args:
      pdb_str: An input PDB string.
      bfactors: A numpy array with shape [n_residues, 37] (atom37 layout). We
        assume that the B-factors are per residue; only the CA entry of each
        row is read and applied to every atom of that residue.

    Returns:
      A new PDB string with the B-factors replaced.

    Raises:
      ValueError: If the last dimension of `bfactors` is not atom_type_num,
        or if the PDB contains more residues than `bfactors` has rows.
    """
    if bfactors.shape[-1] != residue_constants.atom_type_num:
        raise ValueError(
            f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
        )

    parser = PDB.PDBParser(QUIET=True)
    handle = io.StringIO(pdb_str)
    structure = parser.get_structure("", handle)

    # Walk atoms in file order, advancing the residue index whenever the
    # parent residue id changes.
    curr_resid = ("", "", "")
    idx = -1
    for atom in structure.get_atoms():
        atom_resid = atom.parent.get_id()
        if atom_resid != curr_resid:
            idx += 1
            if idx >= bfactors.shape[0]:
                # BUG FIX: this message was previously a plain string with
                # literal "{shape}"/"{idx}" placeholders (and `shape` was
                # undefined); use a real f-string so the values interpolate.
                raise ValueError(
                    "Index into bfactors exceeds number of residues. "
                    f"B-factors shape: {bfactors.shape}, idx: {idx}."
                )
            curr_resid = atom_resid
        # Per-residue B-factor: every atom gets the value stored at CA.
        atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]

    new_pdb = io.StringIO()
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(new_pdb)
    return new_pdb.getvalue()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def assert_equal_nonterminal_atom_types(
    atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
    """Checks that pre- and post-minimized proteins have same atom set.

    Args:
      atom_mask: Atom mask of the minimized protein.
      ref_atom_mask: Atom mask of the reference (pre-minimization) protein.

    Raises:
      AssertionError: If the two masks differ at any non-OXT position.
    """
    # Ignore any terminal OXT atoms which may have been added by minimization.
    oxt = residue_constants.atom_order["OXT"]
    # BUG FIX: use the builtin `bool` — the deprecated `np.bool` alias was
    # removed in NumPy 1.24 and raises AttributeError there.
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
    no_oxt_mask[..., oxt] = False
    np.testing.assert_almost_equal(
        ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
    )
|
openfold/np/residue_constants.py
ADDED
|
@@ -0,0 +1,1310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Constants used in AlphaFold."""
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import functools
|
| 20 |
+
from typing import Mapping, List, Tuple
|
| 21 |
+
from importlib import resources
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tree
|
| 25 |
+
|
| 26 |
+
# Internal import (35fd).
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096

# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
# Each chi angle is defined by the four atoms spanning its dihedral.
chi_angles_atoms = {
    "ALA": [],
    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
    "ARG": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "NE"],
        ["CG", "CD", "NE", "CZ"],
    ],
    "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
    "CYS": [["N", "CA", "CB", "SG"]],
    "GLN": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "OE1"],
    ],
    "GLU": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "OE1"],
    ],
    "GLY": [],
    "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
    "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
    "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "LYS": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "CD"],
        ["CB", "CG", "CD", "CE"],
        ["CG", "CD", "CE", "NZ"],
    ],
    "MET": [
        ["N", "CA", "CB", "CG"],
        ["CA", "CB", "CG", "SD"],
        ["CB", "CG", "SD", "CE"],
    ],
    "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
    "SER": [["N", "CA", "CB", "OG"]],
    "THR": [["N", "CA", "CB", "OG1"]],
    "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
    "VAL": [["N", "CA", "CB", "CG1"]],
}

# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
# Entry [i][j] == 1.0 iff restype i has a chi(j+1) angle.
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]

# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
# NOTE: unlike chi_angles_mask, this table has a 21st row for UNK.
chi_pi_periodic = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]

# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-definiting atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
    "ALA": [
        ["N", 0, (-0.525, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.529, -0.774, -1.205)],
        ["O", 3, (0.627, 1.062, 0.000)],
    ],
    "ARG": [
        ["N", 0, (-0.524, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.524, -0.778, -1.209)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG", 4, (0.616, 1.390, -0.000)],
        ["CD", 5, (0.564, 1.414, 0.000)],
        ["NE", 6, (0.539, 1.357, -0.000)],
        ["NH1", 7, (0.206, 2.301, 0.000)],
        ["NH2", 7, (2.078, 0.978, -0.000)],
        ["CZ", 7, (0.758, 1.093, -0.000)],
    ],
    "ASN": [
        ["N", 0, (-0.536, 1.357, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.531, -0.787, -1.200)],
        ["O", 3, (0.625, 1.062, 0.000)],
        ["CG", 4, (0.584, 1.399, 0.000)],
        ["ND2", 5, (0.593, -1.188, 0.001)],
        ["OD1", 5, (0.633, 1.059, 0.000)],
    ],
    "ASP": [
        ["N", 0, (-0.525, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, 0.000, -0.000)],
        ["CB", 0, (-0.526, -0.778, -1.208)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.593, 1.398, -0.000)],
        ["OD1", 5, (0.610, 1.091, 0.000)],
        ["OD2", 5, (0.592, -1.101, -0.003)],
    ],
    "CYS": [
        ["N", 0, (-0.522, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, 0.000, 0.000)],
        ["CB", 0, (-0.519, -0.773, -1.212)],
        ["O", 3, (0.625, 1.062, -0.000)],
        ["SG", 4, (0.728, 1.653, 0.000)],
    ],
    "GLN": [
        ["N", 0, (-0.526, 1.361, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, 0.000)],
        ["CB", 0, (-0.525, -0.779, -1.207)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.615, 1.393, 0.000)],
        ["CD", 5, (0.587, 1.399, -0.000)],
        ["NE2", 6, (0.593, -1.189, -0.001)],
        ["OE1", 6, (0.634, 1.060, 0.000)],
    ],
    "GLU": [
        ["N", 0, (-0.528, 1.361, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, -0.000, -0.000)],
        ["CB", 0, (-0.526, -0.781, -1.207)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG", 4, (0.615, 1.392, 0.000)],
        ["CD", 5, (0.600, 1.397, 0.000)],
        ["OE1", 6, (0.607, 1.095, -0.000)],
        ["OE2", 6, (0.589, -1.104, -0.001)],
    ],
    "GLY": [
        ["N", 0, (-0.572, 1.337, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.517, -0.000, -0.000)],
        ["O", 3, (0.626, 1.062, -0.000)],
    ],
    "HIS": [
        ["N", 0, (-0.527, 1.360, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, 0.000, 0.000)],
        ["CB", 0, (-0.525, -0.778, -1.208)],
        ["O", 3, (0.625, 1.063, 0.000)],
        ["CG", 4, (0.600, 1.370, -0.000)],
        ["CD2", 5, (0.889, -1.021, 0.003)],
        ["ND1", 5, (0.744, 1.160, -0.000)],
        ["CE1", 5, (2.030, 0.851, 0.002)],
        ["NE2", 5, (2.145, -0.466, 0.004)],
    ],
    "ILE": [
        ["N", 0, (-0.493, 1.373, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, -0.000)],
        ["CB", 0, (-0.536, -0.793, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG1", 4, (0.534, 1.437, -0.000)],
        ["CG2", 4, (0.540, -0.785, -1.199)],
        ["CD1", 5, (0.619, 1.391, 0.000)],
    ],
    "LEU": [
        ["N", 0, (-0.520, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.522, -0.773, -1.214)],
        ["O", 3, (0.625, 1.063, -0.000)],
        ["CG", 4, (0.678, 1.371, 0.000)],
        ["CD1", 5, (0.530, 1.430, -0.000)],
        ["CD2", 5, (0.535, -0.774, 1.200)],
    ],
    "LYS": [
        ["N", 0, (-0.526, 1.362, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, 0.000)],
        ["CB", 0, (-0.524, -0.778, -1.208)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.619, 1.390, 0.000)],
        ["CD", 5, (0.559, 1.417, 0.000)],
        ["CE", 6, (0.560, 1.416, 0.000)],
        ["NZ", 7, (0.554, 1.387, 0.000)],
    ],
    "MET": [
        ["N", 0, (-0.521, 1.364, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, 0.000, 0.000)],
        ["CB", 0, (-0.523, -0.776, -1.210)],
        ["O", 3, (0.625, 1.062, -0.000)],
        ["CG", 4, (0.613, 1.391, -0.000)],
        ["SD", 5, (0.703, 1.695, 0.000)],
        ["CE", 6, (0.320, 1.786, -0.000)],
    ],
    "PHE": [
        ["N", 0, (-0.518, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, 0.000, -0.000)],
        ["CB", 0, (-0.525, -0.776, -1.212)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["CG", 4, (0.607, 1.377, 0.000)],
        ["CD1", 5, (0.709, 1.195, -0.000)],
        ["CD2", 5, (0.706, -1.196, 0.000)],
        ["CE1", 5, (2.102, 1.198, -0.000)],
        ["CE2", 5, (2.098, -1.201, -0.000)],
        ["CZ", 5, (2.794, -0.003, -0.001)],
    ],
    "PRO": [
        ["N", 0, (-0.566, 1.351, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, 0.000)],
        ["CB", 0, (-0.546, -0.611, -1.293)],
        ["O", 3, (0.621, 1.066, 0.000)],
        ["CG", 4, (0.382, 1.445, 0.0)],
        # ['CD', 5, (0.427, 1.440, 0.0)],
        ["CD", 5, (0.477, 1.424, 0.0)],  # manually made angle 2 degrees larger
    ],
    "SER": [
        ["N", 0, (-0.529, 1.360, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, -0.000)],
        ["CB", 0, (-0.518, -0.777, -1.211)],
        ["O", 3, (0.626, 1.062, -0.000)],
        ["OG", 4, (0.503, 1.325, 0.000)],
    ],
    "THR": [
        ["N", 0, (-0.517, 1.364, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.526, 0.000, -0.000)],
        ["CB", 0, (-0.516, -0.793, -1.215)],
        ["O", 3, (0.626, 1.062, 0.000)],
        ["CG2", 4, (0.550, -0.718, -1.228)],
        ["OG1", 4, (0.472, 1.353, 0.000)],
    ],
    "TRP": [
        ["N", 0, (-0.521, 1.363, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.525, -0.000, 0.000)],
        ["CB", 0, (-0.523, -0.776, -1.212)],
        ["O", 3, (0.627, 1.062, 0.000)],
        ["CG", 4, (0.609, 1.370, -0.000)],
        ["CD1", 5, (0.824, 1.091, 0.000)],
        ["CD2", 5, (0.854, -1.148, -0.005)],
        ["CE2", 5, (2.186, -0.678, -0.007)],
        ["CE3", 5, (0.622, -2.530, -0.007)],
        ["NE1", 5, (2.140, 0.690, -0.004)],
        ["CH2", 5, (3.028, -2.890, -0.013)],
        ["CZ2", 5, (3.283, -1.543, -0.011)],
        ["CZ3", 5, (1.715, -3.389, -0.011)],
    ],
    "TYR": [
        ["N", 0, (-0.522, 1.362, 0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.524, -0.000, -0.000)],
        ["CB", 0, (-0.522, -0.776, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG", 4, (0.607, 1.382, -0.000)],
        ["CD1", 5, (0.716, 1.195, -0.000)],
        ["CD2", 5, (0.713, -1.194, -0.001)],
        ["CE1", 5, (2.107, 1.200, -0.002)],
        ["CE2", 5, (2.104, -1.201, -0.003)],
        ["OH", 5, (4.168, -0.002, -0.005)],
        ["CZ", 5, (2.791, -0.001, -0.003)],
    ],
    "VAL": [
        ["N", 0, (-0.494, 1.373, -0.000)],
        ["CA", 0, (0.000, 0.000, 0.000)],
        ["C", 0, (1.527, -0.000, -0.000)],
        ["CB", 0, (-0.533, -0.795, -1.213)],
        ["O", 3, (0.627, 1.062, -0.000)],
        ["CG1", 4, (0.540, 1.429, -0.000)],
        ["CG2", 4, (0.533, -0.776, 1.203)],
    ],
}

# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
residue_atoms = {
    "ALA": ["C", "CA", "CB", "N", "O"],
    "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
    "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
    "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
    "CYS": ["C", "CA", "CB", "N", "O", "SG"],
    "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
    "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
    "GLY": ["C", "CA", "N", "O"],
    "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
    "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
    "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
    "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
    "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
    "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
    "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
    "SER": ["C", "CA", "CB", "N", "O", "OG"],
    "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
    "TRP": [
        "C",
        "CA",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "CE2",
        "CE3",
        "CZ2",
        "CZ3",
        "CH2",
        "N",
        "NE1",
        "O",
    ],
    "TYR": [
        "C",
        "CA",
        "CB",
        "CG",
        "CD1",
        "CD2",
        "CE1",
        "CE2",
        "CZ",
        "N",
        "O",
        "OH",
    ],
    "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
}

# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# TODO: ^ interpret this
residue_atom_renaming_swaps = {
    "ASP": {"OD1": "OD2"},
    "GLU": {"OE1": "OE2"},
    "PHE": {"CD1": "CD2", "CE1": "CE2"},
    "TYR": {"CD1": "CD2", "CE1": "CE2"},
}

# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
van_der_waals_radius = {
    "C": 1.7,
    "N": 1.55,
    "O": 1.52,
    "S": 1.8,
}

# A literature bond between two named atoms: length in Angstroem plus stddev.
Bond = collections.namedtuple(
    "Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
# A literature bond angle (in radians) between three named atoms.
# NOTE: the third field is spelled "atom3name" (no underscore); code below
# (e.g. load_stereo_chemical_props) accesses it as `ba.atom3name`.
BondAngle = collections.namedtuple(
    "BondAngle",
    ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[
    Mapping[str, List[Bond]],
    Mapping[str, List[Bond]],
    Mapping[str, List[BondAngle]],
]:
    """Load stereo_chemical_props.txt into a nice structure.

    Load literature values for bond lengths and bond angles and translate
    bond angles into the length of the opposite edge of the triangle
    ("residue_virtual_bonds").

    Returns:
        residue_bonds: dict that maps resname --> list of Bond tuples
        residue_virtual_bonds: dict that maps resname --> list of Bond tuples
        residue_bond_angles: dict that maps resname --> list of BondAngle tuples
    """
    # TODO: this file should be downloaded in a setup script
    # Text resource shipped with the package; parsed below in two sections
    # (bond lengths, then bond angles), each terminated by a lone "-" line.
    stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")

    lines_iter = iter(stereo_chemical_props.splitlines())
    # Load bond lengths.
    # Each line: "<atom1>-<atom2> <resname> <length> <stddev>".
    residue_bonds = {}
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, length, stddev = line.split()
        atom1, atom2 = bond.split("-")
        if resname not in residue_bonds:
            residue_bonds[resname] = []
        residue_bonds[resname].append(
            Bond(atom1, atom2, float(length), float(stddev))
        )
    residue_bonds["UNK"] = []

    # Load bond angles.
    # Each line: "<atom1>-<atom2>-<atom3> <resname> <angle_deg> <stddev_deg>";
    # angles are converted from degrees to radians here.
    residue_bond_angles = {}
    next(lines_iter)  # Skip empty line.
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        bond, resname, angle_degree, stddev_degree = line.split()
        atom1, atom2, atom3 = bond.split("-")
        if resname not in residue_bond_angles:
            residue_bond_angles[resname] = []
        residue_bond_angles[resname].append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []

    def make_bond_key(atom1_name, atom2_name):
        """Unique key to lookup bonds."""
        # Sorted so the key is order-independent (A-B == B-A).
        return "-".join(sorted([atom1_name, atom2_name]))

    # Translate bond angles into distances ("virtual bonds").
    residue_virtual_bonds = {}
    for resname, bond_angles in residue_bond_angles.items():
        # Create a fast lookup dict for bond lengths.
        bond_cache = {}
        for b in residue_bonds[resname]:
            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
        residue_virtual_bonds[resname] = []
        for ba in bond_angles:
            # The two real bonds forming the angle at ba.atom2_name.
            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

            # Compute distance between atom1 and atom3 using the law of cosines
            # c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = ba.angle_rad
            length = np.sqrt(
                bond1.length ** 2
                + bond2.length ** 2
                - 2 * bond1.length * bond2.length * np.cos(gamma)
            )

            # Propagation of uncertainty assuming uncorrelated errors.
            # Partial derivatives of `length` w.r.t. gamma and the two bond
            # lengths, combined in quadrature with the input stddevs.
            dl_outer = 0.5 / length
            dl_dgamma = (
                2 * bond1.length * bond2.length * np.sin(gamma)
            ) * dl_outer
            dl_db1 = (
                2 * bond1.length - 2 * bond2.length * np.cos(gamma)
            ) * dl_outer
            dl_db2 = (
                2 * bond2.length - 2 * bond1.length * np.cos(gamma)
            ) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * ba.stddev) ** 2
                + (dl_db1 * bond1.stddev) ** 2
                + (dl_db2 * bond2.stddev) ** 2
            )
            residue_virtual_bonds[resname].append(
                Bond(ba.atom1_name, ba.atom3name, length, stddev)
            )

    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]

# Between-residue cos_angles.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353]  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311]  # degrees: 116.568 +- 1.995

# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
    "N",
    "CA",
    "C",
    "CB",
    "O",
    "CG",
    "CG1",
    "CG2",
    "OG",
    "OG1",
    "SG",
    "CD",
    "CD1",
    "CD2",
    "ND1",
    "ND2",
    "OD1",
    "OD2",
    "SD",
    "CE",
    "CE1",
    "CE2",
    "CE3",
    "NE",
    "NE1",
    "NE2",
    "OE1",
    "OE2",
    "CH2",
    "NH1",
    "NH2",
    "OH",
    "CZ",
    "CZ2",
    "CZ3",
    "NZ",
    "OXT",
]
# Atom name -> column index in the fixed-size (atom37) representation.
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.

# A compact atom encoding with 14 columns
# Per-residue atom names for the dense "atom14" layout; unused slots are "".
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
    "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
    "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
    "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
    "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
    "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
    "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
    "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
    "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"],
    "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
    "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace


# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = [
    "A",
    "R",
    "N",
    "D",
    "C",
    "Q",
    "E",
    "G",
    "H",
    "I",
    "L",
    "K",
    "M",
    "F",
    "P",
    "S",
    "T",
    "W",
    "Y",
    "V",
]
# One-letter code -> integer id in the order above.
restype_order = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

# Same alphabet extended with "X" (unknown) as id 20.
restypes_with_x = restypes + ["X"]
restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def sequence_to_onehot(
    sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
      sequence: An amino acid sequence.
      mapping: A dictionary mapping amino acids to integers.
      map_unknown_to_x: If True, any amino acid that is not in the mapping will be
        mapped to the unknown amino acid 'X'. If the mapping doesn't contain
        amino acid 'X', an error will be thrown. If False, any amino acid not in
        the mapping will throw an error.

    Returns:
      A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
      the sequence.

    Raises:
      ValueError: If the mapping doesn't contain values from 0 to
        num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1

    # The ids must form a contiguous range starting at 0, otherwise the
    # one-hot columns would be ambiguous.
    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    def _aa_to_id(aa: str) -> int:
        """Resolve a single sequence character to its integer id."""
        if not map_unknown_to_x:
            return mapping[aa]
        if not (aa.isalpha() and aa.isupper()):
            raise ValueError(
                f"Invalid character in the sequence: {aa}"
            )
        # Unlisted (but valid) letters fall back to the 'X' id.
        return mapping.get(aa, mapping["X"])

    encoded = np.zeros((len(sequence), num_entries), dtype=np.int32)
    for row, aa in enumerate(sequence):
        encoded[row, _aa_to_id(aa)] = 1

    return encoded
|
| 928 |
+
|
| 929 |
+
|
| 930 |
+
# One-letter -> three-letter residue codes, in restype_order order.
restype_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1 = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = "UNK"

# Three-letter names in numeric order, with UNK appended as the last id.
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}


# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID = {
    "A": 0,
    "B": 2,
    "C": 1,
    "D": 2,
    "E": 3,
    "F": 4,
    "G": 5,
    "H": 6,
    "I": 7,
    "J": 20,
    "K": 8,
    "L": 9,
    "M": 10,
    "N": 11,
    "O": 20,
    "P": 12,
    "Q": 13,
    "R": 14,
    "S": 15,
    "T": 16,
    "U": 1,
    "V": 17,
    "W": 18,
    "X": 20,
    "Y": 19,
    "Z": 3,
    "-": 21,
}

# Partial inversion of HHBLITS_AA_TO_ID.
ID_TO_HHBLITS_AA = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}

restypes_with_x_and_gap = restypes + ["X", "-"]
# Permutation that converts an hhblits aatype id into this module's id order.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
    for i in range(len(restypes_with_x_and_gap))
)
|
| 1035 |
+
|
| 1036 |
+
|
| 1037 |
+
def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array.

    Entry [r, a] is 1 iff residue type ``r`` contains atom type ``a``
    (per ``residue_atoms``); the extra final row for the unknown residue
    stays all zeros.
    """
    # +1 to account for unknown (all 0s).
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    for restype, restype_letter in enumerate(restypes):
        restype_name = restype_1to3[restype_letter]
        atom_names = residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = atom_order[atom_name]
            mask[restype, atom_type] = 1
    return mask


# Precomputed once at import time; shape [restype_num + 1, atom_type_num].
STANDARD_ATOM_MASK = _make_standard_atom_mask()
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
# A one hot representation for the first and second atoms defining the axis
|
| 1054 |
+
# of rotation for each chi-angle in each residue.
|
| 1055 |
+
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    For every residue type (plus a zero entry for the unknown residue `X`),
    selects atom number ``atom_index`` of each chi-angle definition and
    encodes it as a one-hot vector over atom types.
    """
    # Per residue name: atom-type indices of the selected chi atom, padded
    # with -1 up to the maximum of four chi angles.
    padded_indices = {}
    for res_name, chi_groups in chi_angles_atoms.items():
        idx = [atom_types.index(group[atom_index]) for group in chi_groups]
        padded_indices[res_name] = idx + [-1] * (4 - len(idx))

    identity = np.eye(atom_type_num)
    stacked = [identity[padded_indices[restype_1to3[r]]] for r in restypes]
    # Residue `X` (unknown) contributes an all-zero entry.
    stacked.append(np.zeros([4, atom_type_num]))

    return np.transpose(np.stack(stacked, axis=0), [0, 2, 1])
|
| 1075 |
+
|
| 1076 |
+
|
| 1077 |
+
# One-hot encodings of the two axis-defining atoms of each chi angle
# (atoms at positions 1 and 2 of the four-atom chi definitions).
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
|
| 1079 |
+
|
| 1080 |
+
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
chi_angles_atom_indices = tree.map_structure(
    lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
# Pad every residue to exactly four chi angles of four atom indices each;
# missing chi angles are filled with the dummy entry [0, 0, 0, 0].
chi_angles_atom_indices = np.array(
    [
        chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
        for chi_atoms in chi_angles_atom_indices
    ]
)
|
| 1091 |
+
|
| 1092 |
+
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group. An atom may appear in several chi
# groups, hence the list of (chi_group_i, atom_i) pairs per key.
chi_groups_for_atom = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
    for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
        for atom_i, atom in enumerate(chi_group):
            chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
# Freeze into a plain dict so unknown keys raise instead of silently growing.
chi_groups_for_atom = dict(chi_groups_for_atom)
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
+
def _make_rigid_transformation_4x4(ex, ey, translation):
|
| 1103 |
+
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
|
| 1104 |
+
# Normalize ex.
|
| 1105 |
+
ex_normalized = ex / np.linalg.norm(ex)
|
| 1106 |
+
|
| 1107 |
+
# make ey perpendicular to ex
|
| 1108 |
+
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
|
| 1109 |
+
ey_normalized /= np.linalg.norm(ey_normalized)
|
| 1110 |
+
|
| 1111 |
+
# compute ez as cross product
|
| 1112 |
+
eznorm = np.cross(ex_normalized, ey_normalized)
|
| 1113 |
+
m = np.stack(
|
| 1114 |
+
[ex_normalized, ey_normalized, eznorm, translation]
|
| 1115 |
+
).transpose()
|
| 1116 |
+
m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
|
| 1117 |
+
return m
|
| 1118 |
+
|
| 1119 |
+
|
| 1120 |
+
# create an array with (restype, atomtype) --> rigid_group_idx
|
| 1121 |
+
# and an array with (restype, atomtype, coord) for the atom positions
|
| 1122 |
+
# and compute affine transformation matrices (4,4) from one rigid group to the
|
| 1123 |
+
# previous group
|
| 1124 |
+
# Arrays filled in below by _make_rigid_group_constants():
#   (restype, atom) -> rigid-group index, atom-existence mask, and the
#   atom's literature position within its rigid group, in both the atom37
#   and atom14 layouts, plus the default 4x4 frame of each of the 8 rigid
#   groups per residue type (index 20 = unknown residue stays zero).
# NOTE: `np.int` (a deprecated alias for the builtin `int`) was removed in
# NumPy 1.24, so `dtype=np.int` raises AttributeError on modern NumPy; use
# the builtin `int` directly, which is exactly what the alias resolved to.
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
def _make_rigid_group_constants():
    """Fill the arrays above.

    Populates, in place, the module-level ``restype_atom37_*``,
    ``restype_atom14_*`` and ``restype_rigid_group_default_frame`` arrays
    from the ``rigid_group_atom_positions`` table.
    """
    # First pass: record each atom's rigid-group index, existence flag, and
    # within-group position in both the atom37 and atom14 layouts.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[
            resname
        ]:
            atomtype = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_rigid_group_positions[
                restype, atomtype, :
            ] = atom_position

            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_rigid_group_positions[
                restype, atom14idx, :
            ] = atom_position

    # Second pass: default transform of each of the 8 rigid groups relative
    # to its parent frame.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions = {
            name: np.array(pos)
            for name, _, pos in rigid_group_atom_positions[resname]
        }

        # backbone to backbone is the identity transform
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=atom_positions["N"],
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],
            ey=atom_positions["CA"] - atom_positions["N"],
            translation=atom_positions["C"],
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone (only for residue types that have a chi1)
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [
                atom_positions[name] for name in base_atom_names
            ]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                restype_rigid_group_default_frame[
                    restype, 4 + chi_idx, :, :
                ] = mat


# Populate the module-level arrays once, at import time.
_make_rigid_group_constants()
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
def make_atom14_dists_bounds(
    overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
    """compute upper and lower bounds for bonds to assess violations.

    Args:
        overlap_tolerance: slack (in the same units as the van der Waals
            radii) subtracted from the sum of radii when forming the clash
            lower bound for non-bonded atom pairs.
        bond_length_tolerance_factor: number of standard deviations around
            the literature bond length used as lower/upper bounds for
            bonded (and virtually bonded) pairs.

    Returns:
        Dict with symmetric float32 arrays "lower_bound", "upper_bound"
        and "stddev", each of shape (21, 14, 14) over
        (residue type, atom14 index, atom14 index).
    """
    restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # create lower and upper bounds for clashes
        for atom1_idx, atom1_name in enumerate(atom_list):
            # Empty names mark unused atom14 slots for this residue type.
            if not atom1_name:
                continue
            # Element is inferred from the first character of the atom name.
            atom1_radius = van_der_waals_radius[atom1_name[0]]
            for atom2_idx, atom2_name in enumerate(atom_list):
                if (not atom2_name) or atom1_idx == atom2_idx:
                    continue
                atom2_radius = van_der_waals_radius[atom2_name[0]]
                lower = atom1_radius + atom2_radius - overlap_tolerance
                # No upper bound for non-bonded pairs.
                upper = 1e10
                restype_atom14_bond_lower_bound[
                    restype, atom1_idx, atom2_idx
                ] = lower
                restype_atom14_bond_lower_bound[
                    restype, atom2_idx, atom1_idx
                ] = lower
                restype_atom14_bond_upper_bound[
                    restype, atom1_idx, atom2_idx
                ] = upper
                restype_atom14_bond_upper_bound[
                    restype, atom2_idx, atom1_idx
                ] = upper

        # overwrite lower and upper bounds for bonds and angles
        for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
            atom1_idx = atom_list.index(b.atom1_name)
            atom2_idx = atom_list.index(b.atom2_name)
            lower = b.length - bond_length_tolerance_factor * b.stddev
            upper = b.length + bond_length_tolerance_factor * b.stddev
            restype_atom14_bond_lower_bound[
                restype, atom1_idx, atom2_idx
            ] = lower
            restype_atom14_bond_lower_bound[
                restype, atom2_idx, atom1_idx
            ] = lower
            restype_atom14_bond_upper_bound[
                restype, atom1_idx, atom2_idx
            ] = upper
            restype_atom14_bond_upper_bound[
                restype, atom2_idx, atom1_idx
            ] = upper
            restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
            restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
    return {
        "lower_bound": restype_atom14_bond_lower_bound,  # shape (21,14,14)
        "upper_bound": restype_atom14_bond_upper_bound,  # shape (21,14,14)
        "stddev": restype_atom14_bond_stddev,  # shape (21,14,14)
    }
|
| 1279 |
+
|
| 1280 |
+
|
| 1281 |
+
# 1.0 at atom14 slots whose atom names are swappable for the residue (per
# residue_atom_renaming_swaps); filled in by _make_atom14_ambiguity_feats().
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
# For each atom14 slot, the index it swaps with (identity for atoms with no
# ambiguity). `np.int` was a deprecated alias for the builtin `int` and was
# removed in NumPy 1.24, so `dtype=np.int` now raises AttributeError; the
# builtin `int` is the exact drop-in replacement.
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=int), (21, 1)
)
|
| 1285 |
+
|
| 1286 |
+
|
| 1287 |
+
def _make_atom14_ambiguity_feats():
    """Mark name-swappable atom pairs in the module-level ambiguity arrays.

    For every residue with renaming swaps (``residue_atom_renaming_swaps``),
    sets ``restype_atom14_ambiguous_atoms`` to 1 at both swappable slots and
    records each slot's swap partner in
    ``restype_atom14_ambiguous_atoms_swap_idx``. Mutates those arrays in
    place.
    """
    for res, pairs in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[res]]
        for atom1, atom2 in pairs.items():
            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
            restype_atom14_ambiguous_atoms_swap_idx[
                res_idx, atom1_idx
            ] = atom2_idx
            restype_atom14_ambiguous_atoms_swap_idx[
                res_idx, atom2_idx
            ] = atom1_idx


# Populate the ambiguity arrays once, at import time.
_make_atom14_ambiguity_feats()
|
| 1304 |
+
|
| 1305 |
+
|
| 1306 |
+
def aatype_to_str_sequence(aatype):
    """Convert a sequence of residue-type indices to a one-letter string."""
    return ''.join(restypes_with_x[idx] for idx in aatype)
|
openfold/resources/__init__.py
ADDED
|
File without changes
|
openfold/utils/feats.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
from typing import Dict
|
| 22 |
+
|
| 23 |
+
from openfold.np import protein
|
| 24 |
+
import openfold.np.residue_constants as rc
|
| 25 |
+
from openfold.utils.rigid_utils import Rotation, Rigid
|
| 26 |
+
from openfold.utils.tensor_utils import (
|
| 27 |
+
batched_gather,
|
| 28 |
+
one_hot,
|
| 29 |
+
tree_map,
|
| 30 |
+
tensor_tree_map,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Select the pseudo-beta atom per residue: CA for glycine, CB otherwise.

    Returns the selected positions, or a (positions, mask) pair when
    ``all_atom_masks`` is provided.
    """
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    is_gly = aatype == rc.restype_order["G"]

    # Broadcast the per-residue glycine flag over the xyz dimension.
    gly_over_xyz = is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3)
    pseudo_beta = torch.where(
        gly_over_xyz,
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly,
        all_atom_masks[..., ca_idx],
        all_atom_masks[..., cb_idx],
    )
    return pseudo_beta, pseudo_beta_mask
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def atom14_to_atom37(atom14, batch):
    """Reindex per-residue atom14 data into the atom37 layout.

    Gathers along the atom dimension using the precomputed index map
    ``batch["residx_atom37_to_atom14"]``, then zeroes out entries for atoms
    the residue doesn't have via ``batch["atom37_atom_exists"]``.
    """
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )

    # Mask out atom37 slots that do not exist for the residue type.
    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

    return atom37_data
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_template_angle_feat(template_feats):
    """Concatenate the template angle features along the last dimension.

    Combines a 22-class one-hot of the template residue types, the torsion
    sin/cos values and their alternates (last two dims flattened to 14
    channels each — presumably 7 angles x sin/cos; confirm against the
    feature pipeline), and the torsion-angle mask.
    """
    template_aatype = template_feats["template_aatype"]
    torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = template_feats[
        "template_alt_torsion_angles_sin_cos"
    ]
    torsion_angles_mask = template_feats["template_torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
            torsion_angles_sin_cos.reshape(
                *torsion_angles_sin_cos.shape[:-2], 14
            ),
            alt_torsion_angles_sin_cos.reshape(
                *alt_torsion_angles_sin_cos.shape[:-2], 14
            ),
            torsion_angles_mask,
        ],
        dim=-1,
    )

    return template_angle_feat
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def build_template_pair_feat(
    batch,
    min_bin, max_bin, no_bins,
    use_unit_vector=False,
    eps=1e-20, inf=1e8
):
    """Build the pairwise template feature tensor.

    Concatenates, along the last dimension: a one-hot distogram of
    pseudo-beta distances (``no_bins`` bins between ``min_bin`` and
    ``max_bin``), the 2D pseudo-beta mask, row- and column-broadcast
    residue-type one-hots, inter-residue unit vectors in each residue's
    backbone frame (zeroed unless ``use_unit_vector``), and the 2D
    backbone-frame mask; the result is multiplied by that mask.
    """
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    # Compute distogram (this seems to differ slightly from Alg. 5)
    tpb = batch["template_pseudo_beta"]
    # Squared pairwise distances; bin edges are squared too, so no sqrt is
    # needed.
    dgram = torch.sum(
        (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot = nn.functional.one_hot(
        batch["template_aatype"],
        rc.restype_num + 2,
    )

    n_res = batch["template_aatype"].shape[-1]
    # Residue-type one-hots broadcast across rows and across columns.
    to_concat.append(
        aatype_one_hot[..., None, :, :].expand(
            *aatype_one_hot.shape[:-2], n_res, -1, -1
        )
    )
    to_concat.append(
        aatype_one_hot[..., None, :].expand(
            *aatype_one_hot.shape[:-2], -1, n_res, -1
        )
    )

    # Backbone frames from N/CA/C positions; express every residue's CA in
    # every other residue's local frame.
    n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
    rigids = Rigid.make_transform_from_reference(
        n_xyz=batch["template_all_atom_positions"][..., n, :],
        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
        c_xyz=batch["template_all_atom_positions"][..., c, :],
        eps=eps,
    )
    points = rigids.get_trans()[..., None, :, :]
    rigid_vec = rigids[..., None].invert_apply(points)

    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))

    # Backbone-frame mask: all three frame atoms must be present.
    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = (
        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    )
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar[..., None]

    # Unit vectors are optional; zeroed (not dropped) to keep the channel
    # count constant.
    if(not use_unit_vector):
        unit_vector = unit_vector * 0.

    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])

    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]

    return act
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def build_extra_msa_feat(batch):
    """Assemble the extra-MSA feature tensor.

    Concatenates, along the last dimension, a 23-class one-hot of the
    extra-MSA residue ids with the has-deletion flag and deletion value as
    two trailing scalar channels.
    """
    one_hot_msa = nn.functional.one_hot(batch["extra_msa"], 23)
    has_deletion = batch["extra_has_deletion"].unsqueeze(-1)
    deletion_value = batch["extra_deletion_value"].unsqueeze(-1)
    return torch.cat([one_hot_msa, has_deletion, deletion_value], dim=-1)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
):
    """Compose per-residue torsion angles into global rigid-group frames.

    Args:
        r: global backbone frames, one per residue.
        alpha: [*, N, 7, 2] sin/cos of the 7 torsion angles — assumed
            layout; confirm against the caller.
        aatype: [*, N] residue-type indices.
        rrgdf: [21, 8, 4, 4] default frames per (residue type, rigid group).

    Returns:
        Rigid with 8 frames per residue, all expressed in the global frame.
    """
    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    # One [*, N, 8, 3, 3] rotation matrix and
    # One [*, N, 8, 3] translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    # Constant (sin, cos) = (0, 1), i.e. zero rotation, for the backbone
    # group, prepended so all 8 groups share one angle tensor.
    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat(
        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0 , 0 ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

    # Rotate each group's default frame by its torsion angle.
    all_frames = default_r.compose(all_rots)

    # chi2..chi4 frames are defined relative to the previous chi frame, so
    # chain the compositions down to the backbone.
    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    # Lift every backbone-relative frame into the global frame.
    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def frames_and_literature_positions_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
    default_frames,
    group_idx,
    atom_mask,
    lit_positions,
):
    """Place literature atom14 positions into the global frame.

    Each atom is assigned to its rigid group (``group_idx``), that group's
    global frame from ``r`` is selected via a one-hot sum, and the atom's
    literature position is transformed by it; non-existent atoms are zeroed
    by ``atom_mask``.
    """
    # [*, N, 14, 4, 4]
    # NOTE(review): default_4x4 is computed but never used below — confirm
    # whether it can be dropped.
    default_4x4 = default_frames[aatype, ...]

    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14] — the one-hot sum selects each atom's own group frame.
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions
|
openfold/utils/loss.py
ADDED
|
@@ -0,0 +1,1614 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
import logging
|
| 18 |
+
import ml_collections
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from torch.distributions.bernoulli import Bernoulli
|
| 23 |
+
from typing import Dict, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
from openfold.np import residue_constants
|
| 26 |
+
from openfold.utils import feats
|
| 27 |
+
from openfold.utils.rigid_utils import Rotation, Rigid
|
| 28 |
+
from openfold.utils.tensor_utils import (
|
| 29 |
+
tree_map,
|
| 30 |
+
tensor_tree_map,
|
| 31 |
+
masked_mean,
|
| 32 |
+
permute_final_dims,
|
| 33 |
+
batched_gather,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def softmax_cross_entropy(logits, labels):
    """Cross entropy between logits and a (soft) one-hot target.

    Args:
        logits: [*, no_classes] unnormalized class scores
        labels: [*, no_classes] target distribution (typically one-hot)
    Returns:
        [*] per-example cross-entropy
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    return -torch.sum(labels * log_probs, dim=-1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def sigmoid_cross_entropy(logits, labels):
    """Binary cross entropy computed from logits.

    The computation is performed in float64 for numerical stability and the
    result is cast back to the input dtype.

    Args:
        logits: [*] unnormalized scores
        labels: [*] binary (or soft) targets
    Returns:
        [*] per-element binary cross-entropy
    """
    in_dtype = logits.dtype
    logits64 = logits.double()
    labels64 = labels.double()
    log_p = torch.nn.functional.logsigmoid(logits64)
    log_not_p = torch.nn.functional.logsigmoid(-logits64)
    loss = -labels64 * log_p - (1.0 - labels64) * log_not_p
    return loss.to(dtype=in_dtype)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def torsion_angle_loss(
    a,  # [*, N, 7, 2]
    a_gt,  # [*, N, 7, 2]
    a_alt_gt,  # [*, N, 7, 2]
):
    """Torsion angle loss over (sin, cos) pairs.

    Compares predictions against both the ground truth and its alternative
    (pi-rotated) naming, taking the closer of the two, and adds a small
    penalty pushing the unnormalized predictions toward unit norm.

    Returns:
        [*] loss tensor
    """
    # [*, N, 7] length of each predicted (sin, cos) pair
    pred_norm = torch.norm(a, dim=-1)

    # Project the predictions onto the unit circle: [*, N, 7, 2]
    a_unit = a / pred_norm.unsqueeze(-1)

    # Squared chord distance to both namings of the ground truth: [*, N, 7]
    sq_diff_gt = torch.norm(a_unit - a_gt, dim=-1) ** 2
    sq_diff_alt = torch.norm(a_unit - a_alt_gt, dim=-1) ** 2
    closest = torch.minimum(sq_diff_gt, sq_diff_alt)

    # [*] mean over residues and the 7 angles
    l_torsion = torch.mean(closest, dim=(-1, -2))
    l_angle_norm = torch.mean(torch.abs(pred_norm - 1), dim=(-1, -2))

    return l_torsion + 0.02 * l_angle_norm
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def compute_fape(
    pred_frames: Rigid,
    target_frames: Rigid,
    frames_mask: torch.Tensor,
    pred_positions: torch.Tensor,
    target_positions: torch.Tensor,
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps=1e-8,
) -> torch.Tensor:
    """
    Computes FAPE loss.

    Args:
        pred_frames:
            [*, N_frames] Rigid object of predicted frames
        target_frames:
            [*, N_frames] Rigid object of ground truth frames
        frames_mask:
            [*, N_frames] binary mask for the frames
        pred_positions:
            [*, N_pts, 3] predicted atom positions
        target_positions:
            [*, N_pts, 3] ground truth positions
        positions_mask:
            [*, N_pts] positions mask
        length_scale:
            Length scale by which the loss is divided
        l1_clamp_distance:
            Cutoff above which distance errors are disregarded
        eps:
            Small value used to regularize denominators
    Returns:
        [*] loss tensor
    """
    # Express every atom position in the local coordinate system of every
    # frame: [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(
        pred_positions[..., None, :, :],
    )
    local_target_pos = target_frames.invert()[..., None].apply(
        target_positions[..., None, :, :],
    )

    # eps inside the sqrt keeps the gradient finite at zero distance
    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
    normed_error = normed_error * frames_mask[..., None]
    normed_error = normed_error * positions_mask[..., None, :]

    # FP16-friendly averaging. Roughly equivalent to:
    #
    # norm_factor = (
    #     torch.sum(frames_mask, dim=-1) *
    #     torch.sum(positions_mask, dim=-1)
    # )
    # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
    #
    # ("roughly" because eps is necessarily duplicated in the latter)
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

    return normed_error
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """
    Backbone FAPE loss.

    Args:
        backbone_rigid_tensor:
            [*, N, 4, 4] ground-truth backbone frames
        backbone_rigid_mask:
            [*, N] mask for the ground-truth frames
        traj:
            predicted frames in 7-component (quaternion + translation)
            format; assumes a leading trajectory/block dimension that
            broadcasts against ``gt_aff[None]`` below -- TODO confirm
        use_clamped_fape:
            optional weight blending clamped and unclamped FAPE
        clamp_distance:
            cutoff above which distance errors are disregarded
        loss_unit_distance:
            length scale by which the loss is divided
        eps:
            small value used to regularize denominators
    Returns:
        scalar loss tensor (mean over all leading dimensions)
    """
    pred_aff = Rigid.from_tensor_7(traj)
    # Rebuild the Rigid from explicit rotation matrices so downstream ops
    # do not rely on the quaternion representation
    pred_aff = Rigid(
        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
        pred_aff.get_trans(),
    )

    # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
    # backbone tensor, normalizes it, and then turns it back to a rotation
    # matrix. To avoid a potentially numerically unstable rotation matrix
    # to quaternion conversion, we just use the original rotation matrix
    # outright. This one hasn't been composed a bunch of times, though, so
    # it might be fine.
    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    # Clamped FAPE (errors beyond clamp_distance are capped)
    fape_loss = compute_fape(
        pred_aff,
        gt_aff[None],
        backbone_rigid_mask[None],
        pred_aff.get_trans(),
        gt_aff[None].get_trans(),
        backbone_rigid_mask[None],
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        # Unclamped variant, blended in with weight (1 - use_clamped_fape)
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[None],
            backbone_rigid_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            backbone_rigid_mask[None],
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Average over the batch dimension
    fape_loss = torch.mean(fape_loss)

    return fape_loss
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def sidechain_loss(
    sidechain_frames: torch.Tensor,
    sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """
    Sidechain FAPE loss over all rigid groups and atom14 positions.

    Args:
        sidechain_frames: predicted rigid-group frames; only the final
            element along the leading (structure-module block) dimension
            is scored -- see ``sidechain_frames[-1]`` below
        sidechain_atom_pos: predicted atom14 positions, same leading
            block dimension convention
        rigidgroups_gt_frames / rigidgroups_alt_gt_frames: ground-truth
            frames under the two symmetric atom namings
        rigidgroups_gt_exists: mask over ground-truth rigid groups
        renamed_atom14_gt_positions / renamed_atom14_gt_exists: ground
            truth atom positions/mask after ambiguity renaming
        alt_naming_is_better: per-residue indicator selecting the
            alternative ground-truth naming
    Returns:
        [*] FAPE loss tensor
    """
    # Per residue, pick whichever ground-truth naming (original or
    # symmetry-swapped) fits the prediction better
    renamed_gt_frames = (
        1.0 - alt_naming_is_better[..., None, None, None]
    ) * rigidgroups_gt_frames + alt_naming_is_better[
        ..., None, None, None
    ] * rigidgroups_alt_gt_frames

    # Steamroll the inputs: flatten residue and rigid-group dimensions
    sidechain_frames = sidechain_frames[-1]
    batch_dims = sidechain_frames.shape[:-4]
    sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
    sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
    renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
    rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    sidechain_atom_pos = sidechain_atom_pos[-1]
    sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
    renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
        *batch_dims, -1, 3
    )
    renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)

    fape = compute_fape(
        sidechain_frames,
        renamed_gt_frames,
        rigidgroups_gt_exists,
        sidechain_atom_pos,
        renamed_atom14_gt_positions,
        renamed_atom14_gt_exists,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )

    return fape
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """
    Combined backbone + sidechain FAPE loss.

    Args:
        out: model output dict; reads ``out["sm"]["frames"]``,
            ``out["sm"]["sidechain_frames"]`` and ``out["sm"]["positions"]``
        batch: feature dict forwarded as keyword arguments to the two
            sub-losses (must contain the ground-truth frame/position keys
            they expect)
        config: config node with ``backbone`` and ``sidechain`` sub-configs,
            each carrying a ``weight`` entry plus the sub-loss parameters
    Returns:
        scalar weighted FAPE loss
    """
    bb_loss = backbone_loss(
        traj=out["sm"]["frames"],
        **{**batch, **config.backbone},
    )

    sc_loss = sidechain_loss(
        out["sm"]["sidechain_frames"],
        out["sm"]["positions"],
        **{**batch, **config.sidechain},
    )

    # Weighted combination of the two components
    loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles_sin_cos: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """
    Implements Algorithm 27 (torsionAngleLoss)

    Args:
        angles_sin_cos:
            [*, N, 7, 2] predicted angles
        unnormalized_angles_sin_cos:
            The same angles, but unnormalized
        aatype:
            [*, N] residue indices
        seq_mask:
            [*, N] sequence mask
        chi_mask:
            [*, N, 7] angle mask
        chi_angles_sin_cos:
            [*, N, 7, 2] ground truth angles
        chi_weight:
            Weight for the angle component of the loss
        angle_norm_weight:
            Weight for the normalization component of the loss
    Returns:
        [*] loss tensor
    """
    # Keep only the last 4 of the 7 predicted torsions (the chi angles;
    # chi_pi_periodic below is defined per chi angle)
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    # [*, N, 4] indicator of chi angles that are pi-periodic for the
    # residue type (so a pi-rotated ground truth is equally valid)
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )

    true_chi = chi_angles_sin_cos[None]

    # Rotating an angle by pi negates both sin and cos
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    # Take the closer of the two symmetric ground truths
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo: move the leading (presumably structure-module
    # block) dimension in among the trailing dims so masked_mean can
    # reduce over it -- TODO confirm leading-dim semantics with caller
    sq_chi_error = sq_chi_error.permute(
        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
    )

    sq_chi_loss = masked_mean(
        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
    )

    loss = chi_weight * sq_chi_loss

    # Penalize unnormalized (sin, cos) pairs that stray from unit norm
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    """Convert pLDDT-head logits into pLDDT scores in [0, 100].

    Args:
        logits: [*, no_bins] per-residue bin logits
    Returns:
        [*] expected lDDT (times 100) under the predicted distribution
    """
    no_bins = logits.shape[-1]
    width = 1.0 / no_bins
    # Bin centers: width/2, 3*width/2, ..., 1 - width/2
    centers = torch.arange(
        start=0.5 * width, end=1.0, step=width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Broadcast the centers across all leading dimensions
    expected = torch.sum(
        probs * centers.view((1,) * (probs.dim() - 1) + (-1,)),
        dim=-1,
    )
    return expected * 100
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT score between predicted and ground-truth positions.

    Args:
        all_atom_pred_pos: [*, N, 3] predicted positions
        all_atom_positions: [*, N, 3] ground-truth positions
        all_atom_mask: [*, N, 1] position mask
        cutoff: only ground-truth pairs closer than this are scored
        per_residue: if True, return a per-position score; otherwise a
            single score per example
    Returns:
        [*, N] (per_residue) or [*] lDDT scores in [0, 1]
    """

    def _pairwise_dist(pos):
        delta = pos[..., None, :] - pos[..., None, :, :]
        return torch.sqrt(eps + torch.sum(delta ** 2, dim=-1))

    n = all_atom_mask.shape[-2]
    dmat_true = _pairwise_dist(all_atom_positions)
    dmat_pred = _pairwise_dist(all_atom_pred_pos)

    # Score pairs that are (a) within the cutoff in the ground truth,
    # (b) both unmasked, and (c) not self-pairs
    dists_to_score = (
        (dmat_true < cutoff)
        * all_atom_mask
        * all_atom_mask.transpose(-1, -2)
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    # Fraction of the tolerance thresholds (0.5, 1, 2, 4 Angstrom) met
    score = torch.zeros_like(dist_l1)
    for threshold in (0.5, 1.0, 2.0, 4.0):
        score = score + (dist_l1 < threshold).type(dist_l1.dtype)
    score = score * 0.25

    dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
    return norm * (eps + torch.sum(dists_to_score * score, dim=dims))
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT restricted to CA atoms (atom37/atom14 inputs)."""
    ca_idx = residue_constants.atom_order["CA"]

    # Slice out the CA channel; the mask keeps its trailing dim of 1
    pred_ca = all_atom_pred_pos[..., ca_idx, :]
    gt_ca = all_atom_positions[..., ca_idx, :]
    ca_mask = all_atom_mask[..., ca_idx : (ca_idx + 1)]

    return lddt(
        pred_ca,
        gt_ca,
        ca_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def lddt_loss(
    logits: torch.Tensor,
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy loss for the pLDDT head.

    Bins the (detached) per-residue CA lDDT of the prediction and trains the
    head logits against the resulting one-hot targets. Examples whose
    resolution falls outside [min_resolution, max_resolution] are zeroed out.

    Args:
        logits: [*, N, no_bins] pLDDT head logits
        all_atom_pred_pos: [*, N, 37/14, 3] predicted positions
        all_atom_positions: [*, N, 37/14, 3] ground-truth positions
        all_atom_mask: [*, N, 37/14] atom mask
        resolution: [*] per-example resolution
    Returns:
        scalar loss
    """
    # Fix: the original computed `n = all_atom_mask.shape[-2]` here and never
    # used it; the dead statement has been removed.
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    score = lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps
    )

    # The lDDT target is a label, not something to backprop through
    score = score.detach()

    # Discretize the score into no_bins equal-width bins on [0, 1]
    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
        bin_index, num_classes=no_bins
    )

    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    # Masked mean over residues
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    # Gate out examples whose resolution is outside the trusted range
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def distogram_loss(
    logits,
    pseudo_beta,
    pseudo_beta_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    """Cross-entropy loss for the distogram head.

    Args:
        logits: [*, N, N, no_bins] pairwise distance-bin logits
        pseudo_beta: [*, N, 3] ground-truth pseudo-beta coordinates
        pseudo_beta_mask: [*, N] residue mask
    Returns:
        scalar loss
    """
    # Work entirely in squared distances to avoid a sqrt
    sq_breaks = (
        torch.linspace(
            min_bin,
            max_bin,
            no_bins - 1,
            device=logits.device,
        )
        ** 2
    )

    delta = pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
    sq_dists = torch.sum(delta ** 2, dim=-1, keepdims=True)

    # Count how many breakpoints each distance exceeds -> bin index
    true_bins = torch.sum(sq_dists > sq_breaks, dim=-1)

    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_bins, no_bins),
    )

    pair_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    # FP16-friendly masked mean over both residue dimensions; equivalent to
    # sum(errors * pair_mask) / (eps + sum(pair_mask))
    denom = eps + torch.sum(pair_mask, dim=(-1, -2))
    mean = torch.sum(errors * pair_mask, dim=-1)
    mean = torch.sum(mean / denom[..., None], dim=-1)

    # Average over the batch dimensions
    return torch.mean(mean)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _calculate_bin_centers(boundaries: torch.Tensor):
    """Midpoints of the equally spaced bins defined by `boundaries`,
    with one extra center appended for the open-ended final bin."""
    width = boundaries[1] - boundaries[0]
    centers = boundaries + 0.5 * width
    last_center = (centers[-1] + width).unsqueeze(-1)
    return torch.cat([centers, last_center], dim=0)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expected aligned error per residue pair, plus the largest
    representable error (the final bin center)."""
    bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = torch.sum(
        aligned_distance_error_probs * bin_centers, dim=-1
    )
    return expected_error, bin_centers[-1]
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
      logits: [*, num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
      max_bin: Maximum bin value
      no_bins: Number of bins
    Returns:
      aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
        aligned error probabilities over bins for each residue pair.
      predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
        error for each pair of residues.
      max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    probs = torch.nn.functional.softmax(logits, dim=-1)
    expected_error, max_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=boundaries,
        aligned_distance_error_probs=probs,
    )

    return {
        "aligned_confidence_probs": probs,
        "predicted_aligned_error": expected_error,
        "max_predicted_aligned_error": max_error,
    }
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Predicted TM-score (pTM) from PAE-head logits.

    Args:
        logits: [*, N, N, no_bins] predicted aligned error logits
        residue_weights: [*, N] optional per-residue weights; defaults to
            all ones
        max_bin: largest error boundary (Angstrom)
        no_bins: number of error bins
    Returns:
        scalar pTM for the best-weighted alignment anchor
    """
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    # Fix: removed the original's dead statement `torch.sum(residue_weights)`
    # whose result was discarded.
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )

    bin_centers = _calculate_bin_centers(boundaries)
    n = logits.shape[-2]
    # d0 (Zhang & Skolnick 2004); clipping keeps it well-defined for
    # short chains
    clipped_n = max(n, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Expected TM kernel value per residue pair
    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Return the score at the best-weighted alignment anchor residue
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def tm_loss(
    logits,
    final_affine_tensor,
    backbone_rigid_tensor,
    backbone_rigid_mask,
    resolution,
    max_bin=31,
    no_bins=64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps=1e-8,
    **kwargs,
):
    """
    Cross-entropy loss for the predicted-aligned-error (pTM) head.

    Args:
        logits: [*, N, N, no_bins] PAE head logits
        final_affine_tensor: predicted frames in 7-component
            (quaternion + translation) format
        backbone_rigid_tensor: [*, N, 4, 4] ground-truth frames
        backbone_rigid_mask: [*, N] frame mask
        resolution: per-example resolution used to gate the loss
    Returns:
        scalar loss
    """
    pred_affine = Rigid.from_tensor_7(final_affine_tensor)
    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    def _points(affine):
        # Every frame origin expressed in every local frame:
        # yields the pairwise aligned coordinates
        pts = affine.get_trans()[..., None, :, :]
        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
    )

    # The binned targets are labels; no gradient flows into them
    sq_diff = sq_diff.detach()

    # Compare squared distances against squared boundaries (avoids a sqrt)
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (
        backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
    )

    loss = torch.sum(errors * square_mask, dim=-1)
    scale = 0.5  # hack to help FP16 training along
    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    # Gate out examples whose resolution is outside the trusted range
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the loss dimension
    loss = torch.mean(loss)

    return loss
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def between_residue_bond_loss(
    pred_atom_positions: torch.Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: torch.Tensor,  # (*, N, 37/14)
    residue_index: torch.Tensor,  # (*, N)
    aatype: torch.Tensor,  # (*, N)
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0,
    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
      pred_atom_positions: Atom positions in atom37/14 representation
      pred_atom_mask: Atom mask in atom37/14 representation
      residue_index: Residue index for given amino acid, this is assumed to be
        monotonically increasing.
      aatype: Amino acid type of given residue
      tolerance_factor_soft: soft tolerance factor measured in standard deviations
        of pdb distributions
      tolerance_factor_hard: hard tolerance factor measured in standard deviations
        of pdb distributions

    Returns:
      Dict containing:
        * 'c_n_loss_mean': Loss for peptide bond length violations
        * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
            by CA, C, N
        * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
            by C, N, CA
        * 'per_residue_loss_sum': sum of all losses for each residue
        * 'per_residue_violation_mask': mask denoting all residues with violation
            present.
    """
    # Get the positions of the relevant backbone atoms.
    # Atom indices 0/1/2 are N/CA/C in both atom37 and atom14 orderings.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    # Residue pairs separated by a chain break do not form a peptide bond.
    has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
        1
    ]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
        1
    ]
    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    # Flat-bottom: no penalty within tolerance_factor_soft stddevs of the mean.
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    # NOTE(review): this reuses the C-N bond-length stddev as the CA-C-N angle
    # stddev instead of between_res_cos_angles_ca_c_n[1]. This mirrors the
    # upstream AlphaFold/OpenFold implementation — confirm it is intentional
    # before changing it.
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues).
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    # A bond-pair violation marks both participating residues.
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
def between_residue_clash_loss(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_atom_radius: torch.Tensor,
    residue_index: torch.Tensor,
    overlap_tolerance_soft=1.5,
    overlap_tolerance_hard=1.5,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes between residues.

    This is a loss penalizing any steric clashes due to non bonded atoms in
    different peptides coming too close. This loss corresponds to the part with
    different residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
      atom14_pred_positions: Predicted positions of atoms in
        global prediction frame
      atom14_atom_exists: Mask denoting whether atom at positions exists for given
        amino acid type
      atom14_atom_radius: Van der Waals radius for each atom.
      residue_index: Residue index for given amino acid.
      overlap_tolerance_soft: Soft tolerance factor.
      overlap_tolerance_hard: Hard tolerance factor.

    Returns:
      Dict containing:
        * 'mean_loss': average clash loss
        * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
        * 'per_atom_clash_mask': mask whether atom clashes with any other atom
            shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype

    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Create the mask for valid distances.
    # shape (N, N, 14, 14)
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Mask out all the duplicate entries in the lower triangular matrix.
    # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
    # are handled separately.
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # Backbone C--N bond between subsequent residues is no clash.
    # Atom index 2 is C, atom index 0 is N in the atom14 ordering.
    c_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(2), num_classes=14
    )
    c_one_hot = c_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
    )
    c_one_hot = c_one_hot.type(fp_type)
    n_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(0), num_classes=14
    )
    n_one_hot = n_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
    )
    n_one_hot = n_one_hot.type(fp_type)

    neighbour_mask = (
        residue_index[..., :, None, None, None] + 1
    ) == residue_index[..., None, :, None, None]
    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
        * n_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - c_n_bonds)

    # Disulfide bridge between two cysteines is no clash.
    cys = residue_constants.restype_name_to_atom14_names["CYS"]
    cys_sg_idx = cys.index("SG")
    cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
    cys_sg_idx = cys_sg_idx.reshape(
        *((1,) * len(residue_index.shape[:-1])), 1
    ).squeeze(-1)
    cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
    disulfide_bonds = (
        cys_sg_one_hot[..., None, None, :, None]
        * cys_sg_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - disulfide_bonds)

    # Compute the lower bound for the allowed distances.
    # shape (N, N, 14, 14)
    dists_lower_bound = dists_mask * (
        atom14_atom_radius[..., :, None, :, None]
        + atom14_atom_radius[..., None, :, None, :]
    )

    # Compute the error.
    # shape (N, N, 14, 14)
    dists_to_low_error = dists_mask * torch.nn.functional.relu(
        dists_lower_bound - overlap_tolerance_soft - dists
    )

    # Compute the mean loss.
    # shape ()
    mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))

    # Compute the per atom loss sum.
    # shape (N, 14)
    # (Fix: use `dim=` consistently instead of mixing in the NumPy-style
    # `axis=` alias; behavior is unchanged.)
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, dim=(-3, -1)
    )

    # Compute the hard clash mask.
    # shape (N, N, 14, 14)
    clash_mask = dists_mask * (
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )

    # Compute the per atom clash.
    # shape (N, 14)
    per_atom_clash_mask = torch.maximum(
        torch.amax(clash_mask, dim=(-4, -2)),
        torch.amax(clash_mask, dim=(-3, -1)),
    )

    return {
        "mean_loss": mean_loss,  # shape ()
        "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
        "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
    }
def within_residue_violations(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_dists_lower_bound: torch.Tensor,
    atom14_dists_upper_bound: torch.Tensor,
    tighten_bounds_for_loss=0.0,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    This is a loss penalizing any steric violations or clashes of non-bonded atoms
    in a given peptide. This loss corresponds to the part with
    the same residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted positions of atoms in global prediction frame.
        atom14_atom_exists ([*, N, 14]):
            Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_dists_lower_bound ([*, N, 14]):
            Lower bound on allowed distances.
        atom14_dists_upper_bound ([*, N, 14]):
            Upper bound on allowed distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor to tighten loss

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]):
                sum of all clash losses per atom, shape
            * 'per_atom_violations' ([*, N, 14]):
                mask whether atom clashes with any other atom shape
    """
    # Compute the mask for each residue: both atoms must exist, and the
    # diagonal (an atom paired with itself) is excluded.
    dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    dists_masks = dists_masks.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
    )
    dists_masks = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * dists_masks
    )

    # Distance matrix between all atom pairs within each residue.
    # shape [*, N, 14, 14]
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, :, None, :]
                - atom14_pred_positions[..., :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Compute the loss: penalize distances below the (tightened) lower bound
    # or above the (tightened) upper bound.
    dists_to_low_error = torch.nn.functional.relu(
        atom14_dists_lower_bound + tighten_bounds_for_loss - dists
    )
    dists_to_high_error = torch.nn.functional.relu(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    loss = dists_masks * (dists_to_low_error + dists_to_high_error)

    # Compute the per atom loss sum.
    per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)

    # Compute the violations mask (hard bounds, no tightening).
    violations = dists_masks * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )

    # Compute the per atom violations.
    # (Fix: use `dim=` consistently instead of mixing in the NumPy-style
    # `axis=` alias; behavior is unchanged.)
    per_atom_violations = torch.maximum(
        torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    violation_tolerance_factor: float,
    clash_overlap_tolerance: float,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes several checks for structural violations.

    Combines three checks over the predicted atom14 coordinates:
      1. peptide-bond geometry between consecutive residues
         (between_residue_bond_loss),
      2. steric clashes between atoms of different residues
         (between_residue_clash_loss),
      3. distance-bound violations between atoms of the same residue
         (within_residue_violations).

    Args:
        batch: Feature dict; reads 'atom14_atom_exists', 'residue_index',
            'aatype' and 'residx_atom14_to_atom37'.
        atom14_pred_positions: Predicted atom positions in atom14 layout.
        violation_tolerance_factor: Soft/hard tolerance (in stddevs) for the
            bond-geometry checks.
        clash_overlap_tolerance: Allowed Van der Waals overlap for clashes.
    Returns:
        Nested dict of per-check losses/masks plus a combined per-residue
        violation mask under 'total_per_residue_violations_mask'.
    """

    # Compute between residue backbone violations of bonds and angles.
    connection_violations = between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Shape: (N, 14).
    atomtype_radius = [
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
    atom14_atom_radius = (
        batch["atom14_atom_exists"]
        * atomtype_radius[batch["residx_atom14_to_atom37"]]
    )

    # Compute the between residue clash loss.
    between_residue_clashes = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    atom14_atom_exists = batch["atom14_atom_exists"]
    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["lower_bound"]
    )[batch["aatype"]]
    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["upper_bound"]
    )[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )

    # Combine them to a single per-residue violation mask (used later for LDDT).
    per_residue_violations_mask = torch.max(
        torch.stack(
            [
                connection_violations["per_residue_violation_mask"],
                torch.max(
                    between_residue_clashes["per_atom_clash_mask"], dim=-1
                )[0],
                torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
            ],
            dim=-1,
        ),
        dim=-1,
    )[0]

    return {
        "between_residues": {
            "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],  # ()
            "angles_ca_c_n_loss_mean": connection_violations[
                "ca_c_n_loss_mean"
            ],  # ()
            "angles_c_n_ca_loss_mean": connection_violations[
                "c_n_ca_loss_mean"
            ],  # ()
            "connections_per_residue_loss_sum": connection_violations[
                "per_residue_loss_sum"
            ],  # (N)
            "connections_per_residue_violation_mask": connection_violations[
                "per_residue_violation_mask"
            ],  # (N)
            "clashes_mean_loss": between_residue_clashes["mean_loss"],  # ()
            "clashes_per_atom_loss_sum": between_residue_clashes[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "clashes_per_atom_clash_mask": between_residue_clashes[
                "per_atom_clash_mask"
            ],  # (N, 14)
        },
        "within_residues": {
            "per_atom_loss_sum": residue_violations[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "per_atom_violations": residue_violations[
                "per_atom_violations"
            ],  # (N, 14),
        },
        "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
    }
def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    """NumPy front-end for ``find_structural_violations``.

    Converts the numpy inputs to torch tensors, delegates to the torch
    implementation (unpacking ``config`` as keyword arguments), and converts
    every tensor leaf of the nested result back to numpy arrays.
    """
    torch_batch = tree_map(torch.tensor, batch, np.ndarray)
    torch_positions = torch.tensor(atom14_pred_positions)

    violations = find_structural_violations(
        torch_batch, torch_positions, **config
    )

    return tensor_tree_map(np.array, violations)
def extreme_ca_ca_distance_violations(
    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
    pred_atom_mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps=1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbour.

    Measures the fraction of CA-CA pairs between consecutive amino acids that are
    more than 'max_angstrom_tolerance' apart.

    Args:
      pred_atom_positions: Atom positions in atom37/14 representation
      pred_atom_mask: Atom mask in atom37/14 representation
      residue_index: Residue index for given amino acid, this is assumed to be
        monotonically increasing.
      max_angstrom_tolerance: Maximum distance allowed to not count as violation.
    Returns:
      Fraction of consecutive CA-CA pairs with violation.
    """
    # CA is atom index 1 in both the atom37 and atom14 layouts.
    ca_pos_i = pred_atom_positions[..., :-1, 1, :]
    ca_mask_i = pred_atom_mask[..., :-1, 1]
    ca_pos_j = pred_atom_positions[..., 1:, 1, :]
    ca_mask_j = pred_atom_mask[..., 1:, 1]

    # Pairs across a chain break are not consecutive and never count.
    consecutive = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    delta = ca_pos_i - ca_pos_j
    ca_ca_distance = torch.sqrt(eps + torch.sum(delta * delta, dim=-1))

    exceeds_tolerance = (
        ca_ca_distance - residue_constants.ca_ca
    ) > max_angstrom_tolerance

    pair_mask = ca_mask_i * ca_mask_j * consecutive
    return masked_mean(pair_mask, exceeds_tolerance, -1)
def compute_violation_metrics(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute several metrics to assess the structural violations.

    Args:
        batch: Feature dict; reads 'atom14_atom_exists', 'residue_index' and
            'seq_mask'.
        atom14_pred_positions: Predicted atom positions in atom14 layout.
        violations: Output of ``find_structural_violations``.
    Returns:
        Dict of scalar (per-example) violation fractions, each a masked mean
        over the sequence dimension.
    """
    ret = {}
    # Fraction of consecutive CA-CA pairs that are abnormally far apart.
    extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )
    ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
    # Fraction of residues with a peptide-bond geometry violation.
    ret["violations_between_residue_bond"] = masked_mean(
        batch["seq_mask"],
        violations["between_residues"][
            "connections_per_residue_violation_mask"
        ],
        dim=-1,
    )
    # Fraction of residues with at least one clashing atom (inter-residue).
    ret["violations_between_residue_clash"] = masked_mean(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["between_residues"]["clashes_per_atom_clash_mask"],
            dim=-1,
        )[0],
        dim=-1,
    )
    # Fraction of residues with an intra-residue distance-bound violation.
    ret["violations_within_residue"] = masked_mean(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["within_residues"]["per_atom_violations"], dim=-1
        )[0],
        dim=-1,
    )
    # Fraction of residues with any violation of any kind.
    ret["violations_per_residue"] = masked_mean(
        mask=batch["seq_mask"],
        value=violations["total_per_residue_violations_mask"],
        dim=-1,
    )
    return ret
def compute_violation_metrics_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
    """NumPy front-end for ``compute_violation_metrics``.

    Converts all numpy inputs (including the nested ``violations`` dict) to
    torch tensors, delegates to the torch implementation, and converts the
    resulting tensors back to numpy arrays.
    """
    torch_batch = tree_map(torch.tensor, batch, np.ndarray)
    torch_positions = torch.tensor(atom14_pred_positions)
    torch_violations = tree_map(torch.tensor, violations, np.ndarray)

    metrics = compute_violation_metrics(
        torch_batch, torch_positions, torch_violations
    )

    return tree_map(np.array, metrics, torch.Tensor)
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Aggregate the structural-violation terms into one scalar-per-example loss.

    Sums the mean bond-length and bond-angle penalties with the total clash
    penalty (inter- plus intra-residue), the latter normalized by the number
    of existing atoms.

    Args:
        violations: Output of ``find_structural_violations``.
        atom14_atom_exists: Existence mask over atom14 atoms.
        eps: Small constant guarding the division by the atom count.
    Returns:
        The combined violation loss tensor.
    """
    between = violations["between_residues"]
    within = violations["within_residues"]

    # Normalize the summed per-atom clash penalties by the atom count.
    num_atoms = torch.sum(atom14_atom_exists)
    clash_term = torch.sum(
        between["clashes_per_atom_loss_sum"] + within["per_atom_loss_sum"]
    ) / (eps + num_atoms)

    return (
        between["bonds_c_n_loss_mean"]
        + between["angles_ca_c_n_loss_mean"]
        + between["angles_c_n_ca_loss_mean"]
        + clash_term
    )
def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
      batch: Dictionary containing:
        * atom14_gt_positions: Ground truth positions.
        * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
        * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
            renaming swaps.
        * atom14_gt_exists: Mask for which atoms exist in ground truth.
        * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
            after renaming.
        * atom14_atom_exists: Mask for whether each atom is part of the given
            amino acid type.
      atom14_pred_positions: Array of atom positions in global frame with shape
    Returns:
      Dictionary containing:
        alt_naming_is_better: Array with 1.0 where alternative swap is better.
        renamed_atom14_gt_positions: Array of optimal ground truth positions
          after renaming swaps are performed.
        renamed_atom14_gt_exists: Mask after renaming swap is performed.
    """

    # Pairwise inter-residue atom distances for the predicted structure.
    # Broadcasting yields shape [*, N, N, 14, 14].
    pred_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Same pairwise distances for the ground truth...
    atom14_gt_positions = batch["atom14_gt_positions"]
    gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_gt_positions[..., None, :, None, :]
                - atom14_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # ...and for the alternative (swapped) ground truth.
    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_alt_gt_positions[..., None, :, None, :]
                - atom14_alt_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Per-pair absolute distance error against each naming (the name "lddt"
    # follows the upstream implementation; it is a plain distance error here).
    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    # Only compare an ambiguous atom of residue i against unambiguous,
    # existing atoms of other residues j.
    atom14_gt_exists = batch["atom14_gt_exists"]
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
    alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))

    # Per residue: 1.0 where the swapped naming has the lower total error.
    fp_type = atom14_pred_positions.dtype
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    # Blend positions/masks per residue according to the chosen naming.
    renamed_atom14_gt_positions = (
        1.0 - alt_naming_is_better[..., None, None]
    ) * atom14_gt_positions + alt_naming_is_better[
        ..., None, None
    ] * atom14_alt_gt_positions

    renamed_atom14_gt_mask = (
        1.0 - alt_naming_is_better[..., None]
    ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
        "atom14_alt_gt_exists"
    ]

    return {
        "alt_naming_is_better": alt_naming_is_better,
        "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
        "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
    }
def experimentally_resolved_loss(
    logits: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    min_resolution: float,
    max_resolution: float,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """
    Binary cross-entropy loss between per-atom "resolved" logits and the
    ground-truth atom mask, averaged over atoms that exist in the atom37
    representation.

    Args:
        logits: per-atom resolution logits, last dim is atoms
        atom37_atom_exists: atom37 existence mask, same trailing shape
        all_atom_mask: ground-truth atom mask used as the BCE target
        resolution: per-example experimental resolution
        min_resolution: examples below this resolution contribute zero loss
        max_resolution: examples above this resolution contribute zero loss
        eps: small constant guarding the division
    Returns:
        Scalar loss tensor (mean over the batch)
    """
    per_atom_errors = sigmoid_cross_entropy(logits, all_atom_mask)

    # Sum errors over atoms, normalize by the total number of existing atoms
    per_res = torch.sum(per_atom_errors * atom37_atom_exists, dim=-1)
    denom = eps + torch.sum(atom37_atom_exists, dim=(-1, -2))
    summed = torch.sum(per_res / denom, dim=-1)

    # Zero out examples whose resolution falls outside the allowed window
    in_range = (resolution >= min_resolution) & (resolution <= max_resolution)
    return torch.mean(summed * in_range)
|
| 1489 |
+
|
| 1490 |
+
|
| 1491 |
+
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """
    Computes BERT-style masked MSA loss. Implements subsection 1.9.9.

    Args:
        logits: [*, N_seq, N_res, 23] predicted residue distribution
        true_msa: [*, N_seq, N_res] true MSA
        bert_mask: [*, N_seq, N_res] MSA mask
        eps: small constant guarding the division
    Returns:
        Masked MSA loss (scalar, mean over the batch)
    """
    # Cross-entropy of the predicted distribution against the one-hot true MSA
    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
    )

    # FP16-friendly averaging. Equivalent to:
    # loss = (
    #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
    #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
    # )
    # The 0.5 scale shrinks the intermediate sums, presumably to keep them
    # within half-precision range — TODO confirm; it cancels out exactly
    # (denominator carries scale, result is multiplied back by scale).
    loss = errors * bert_mask
    loss = torch.sum(loss, dim=-1)
    scale = 0.5
    denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    loss = torch.mean(loss)

    return loss
|
| 1522 |
+
|
| 1523 |
+
|
| 1524 |
+
class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""
    def __init__(self, config):
        super(AlphaFoldLoss, self).__init__()
        # Config indexable by loss name; each per-loss sub-config provides at
        # least a "weight" attribute plus that loss function's kwargs
        self.config = config

    def forward(self, out, batch, _return_breakdown=False):
        """
        Computes the weighted sum of the individual AlphaFold losses.

        Args:
            out: dict of model outputs (structure-module results and logits)
            batch: dict of ground-truth features; NOTE: extended in-place
                with renamed ground-truth tensors if not already present
            _return_breakdown: if True, also return the per-loss breakdown
        Returns:
            The cumulative loss, or (cum_loss, losses dict) when
            _return_breakdown is True
        """
        # Compute structural violations lazily if the model didn't already
        if "violation" not in out.keys():
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        # Resolve the naming ambiguity of symmetric side chains lazily too
        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        # Each entry is a thunk so a loss is only evaluated when requested
        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "experimentally_resolved": lambda: experimentally_resolved_loss(
                logits=out["experimentally_resolved_logits"],
                **{**batch, **self.config.experimentally_resolved},
            ),
            "fape": lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            ),
            "lddt": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            ),
            "masked_msa": lambda: masked_msa_loss(
                logits=out["masked_msa_logits"],
                **{**batch, **self.config.masked_msa},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
            "violation": lambda: violation_loss(
                out["violation"],
                **batch,
            ),
        }

        if(self.config.tm.enabled):
            loss_fns["tm"] = lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            )

        cum_loss = 0.
        losses = {}
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            loss = loss_fn()
            # A NaN/inf loss is replaced by a differentiable zero so one bad
            # term doesn't poison the whole training step
            if(torch.isnan(loss) or torch.isinf(loss)):
                #for k,v in batch.items():
                #    if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
                #        logging.warning(f"{k}: is nan")
                #logging.warning(f"{loss_name}: {loss}")
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
            losses[loss_name] = loss.detach().clone()

        losses["unscaled_loss"] = cum_loss.detach().clone()

        # Scale the loss by the square root of the minimum of the crop size and
        # the (average) sequence length. See subsection 1.9.
        seq_len = torch.mean(batch["seq_length"].float())
        crop_len = batch["aatype"].shape[-1]
        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

        losses["loss"] = cum_loss.detach().clone()

        if(not _return_breakdown):
            return cum_loss

        return cum_loss, losses
|
openfold/utils/rigid_utils.py
ADDED
|
@@ -0,0 +1,1367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
from typing import Tuple, Any, Sequence, Callable, Optional
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
    Multiplies two [*, 3, 3] rotation-matrix tensors entry by entry, written
    out by hand (instead of torch.matmul) so AMP cannot downcast the result.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab as a [*, 3, 3] tensor
    """
    rows = []
    for i in range(3):
        # Dot product of row i of a with each column j of b
        cols = [
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
            for j in range(3)
        ]
        rows.append(torch.stack(cols, dim=-1))

    return torch.stack(rows, dim=-2)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def rot_vec_mul(
    r: torch.Tensor,
    t: torch.Tensor
) -> torch.Tensor:
    """
    Applies a rotation to a vector, written out by hand to avoid AMP
    downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    x, y, z = t[..., 0], t[..., 1], t[..., 2]
    rotated = [
        r[..., i, 0] * x + r[..., i, 1] * y + r[..., i, 2] * z
        for i in range(3)
    ]
    return torch.stack(rotated, dim=-1)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a batch of 3x3 identity rotation matrices.

    A single eye(3) is expanded over the batch dimensions, so all batch
    entries are views of the same underlying storage.

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: optional dtype of the result
        device: optional device of the result
        requires_grad: whether the result requires grad
    Returns:
        [*batch_dims, 3, 3] identity rotation matrices
    """
    eye = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    lead = (1,) * len(batch_dims)
    return eye.view(*lead, 3, 3).expand(*batch_dims, -1, -1)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a batch of zero (identity) translations.

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: optional dtype of the result
        device: optional device of the result
        requires_grad: whether the result requires grad
    Returns:
        [*batch_dims, 3] tensor of zeros
    """
    shape = (*batch_dims, 3)
    return torch.zeros(
        shape, dtype=dtype, device=device, requires_grad=requires_grad
    )
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
    Builds a batch of identity quaternions (1, 0, 0, 0).

    Args:
        batch_dims: leading batch dimensions of the result
        dtype: optional dtype of the result
        device: optional device of the result
        requires_grad: whether the result requires grad
    Returns:
        [*batch_dims, 4] identity quaternions
    """
    quats = torch.zeros(
        (*batch_dims, 4),
        dtype=dtype,
        device=device,
        requires_grad=requires_grad
    )

    # In-place write to the (possibly grad-requiring) leaf tensor is only
    # legal outside of autograd tracking
    with torch.no_grad():
        quats[..., 0] = 1

    return quats
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
_quat_elements = ["a", "b", "c", "d"]
|
| 138 |
+
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
|
| 139 |
+
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _to_mat(pairs):
|
| 143 |
+
mat = np.zeros((4, 4))
|
| 144 |
+
for pair in pairs:
|
| 145 |
+
key, value = pair
|
| 146 |
+
ind = _qtr_ind_dict[key]
|
| 147 |
+
mat[ind // 4][ind % 4] = value
|
| 148 |
+
|
| 149 |
+
return mat
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Quaternion-to-rotation coefficient table. Entry [i, j, r, c] is the weight
# of the outer-product term q_i * q_j in rotation-matrix element (r, c); it
# is consumed by quat_to_rot below via a broadcasted multiply-and-sum.
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # Pairwise outer product q_i * q_j: [*, 4, 4]
    outer = quat[..., None] * quat[..., None, :]

    # Constant coefficient table on the right device/dtype: [4, 4, 3, 3]
    coeffs = outer.new_tensor(_QTR_MAT, requires_grad=False)
    coeffs = coeffs.view((1,) * len(outer.shape[:-2]) + coeffs.shape)

    # Weighted terms: [*, 4, 4, 3, 3]; summing over both quaternion axes
    # yields the [*, 3, 3] rotation matrices
    weighted = outer[..., None, None] * coeffs
    return torch.sum(weighted, dim=(-3, -4))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def rot_to_quat(
    rot: torch.Tensor,
):
    """
    Converts [*, 3, 3] rotation matrices to [*, 4] quaternions by taking the
    eigenvector of the symmetric K matrix belonging to its largest
    eigenvalue. NOTE: the overall sign of the result is arbitrary (both q
    and -q represent the same rotation).

    Raises:
        ValueError: if the trailing dimensions are not (3, 3)
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]
    k = (1. / 3.) * torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    )

    # eigh returns eigenvalues in ascending order; the last eigenvector
    # corresponds to the largest eigenvalue
    _, eigvecs = torch.linalg.eigh(k)
    return eigvecs[..., -1]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# Hamilton-product coefficient tensor: component k of quat1 * quat2 equals
# sum_{i,j} _QUAT_MULTIPLY[i, j, k] * quat1[i] * quat2[j]. Consumed by
# quat_multiply below.
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
                           [ 0,-1, 0, 0],
                           [ 0, 0,-1, 0],
                           [ 0, 0, 0,-1]]

_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
                           [ 1, 0, 0, 0],
                           [ 0, 0, 0, 1],
                           [ 0, 0,-1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
                           [ 0, 0, 0,-1],
                           [ 1, 0, 0, 0],
                           [ 0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
                           [ 0, 0, 1, 0],
                           [ 0,-1, 0, 0],
                           [ 1, 0, 0, 0]]

# Columns 1: of the second axis: right multiplicand has zero scalar part
# (a pure-vector quaternion). Consumed by quat_multiply_by_vec.
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion (Hamilton product)."""
    # Coefficient table on the right device/dtype, broadcast over batch dims
    table = quat1.new_tensor(_QUAT_MULTIPLY)
    table = table.view((1,) * len(quat1.shape[:-1]) + table.shape)

    # Bilinear form: contract both input component axes
    terms = (
        table *
        quat1[..., :, None, None] *
        quat2[..., None, :, None]
    )
    return torch.sum(terms, dim=(-3, -2))
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion (zero scalar part)."""
    # Reduced coefficient table (vector part only), broadcast over batch dims
    table = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    table = table.view((1,) * len(quat.shape[:-1]) + table.shape)

    # Bilinear form: contract the quaternion axis and the 3-vector axis
    terms = (
        table *
        quat[..., :, None, None] *
        vec[..., None, :, None]
    )
    return torch.sum(terms, dim=(-3, -2))
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def invert_rot_mat(rot_mat: torch.Tensor):
    """Inverts [*, 3, 3] rotation matrices; for orthonormal rotation
    matrices the inverse is simply the transpose."""
    return torch.transpose(rot_mat, -1, -2)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def invert_quat(quat: torch.Tensor):
    """Inverts [*, 4] quaternions: the conjugate divided by the squared
    norm (reduces to the conjugate for unit quaternions)."""
    conj = quat.clone()
    conj[..., 1:] = conj[..., 1:] * -1
    sq_norm = torch.sum(quat ** 2, dim=-1, keepdim=True)
    return conj / sq_norm
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class Rotation:
|
| 269 |
+
"""
|
| 270 |
+
A 3D rotation. Depending on how the object is initialized, the
|
| 271 |
+
rotation is represented by either a rotation matrix or a
|
| 272 |
+
quaternion, though both formats are made available by helper functions.
|
| 273 |
+
To simplify gradient computation, the underlying format of the
|
| 274 |
+
rotation cannot be changed in-place. Like Rigid, the class is designed
|
| 275 |
+
to mimic the behavior of a torch Tensor, almost as if each Rotation
|
| 276 |
+
object were a tensor of rotations, in one format or another.
|
| 277 |
+
"""
|
| 278 |
+
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
        Args:
            rot_mats:
                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                quats
            quats:
                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                normalize_quats is not True, must be a unit quaternion
            normalize_quats:
                If quats is specified, whether to normalize quats
        Raises:
            ValueError: if neither or both of rot_mats/quats are provided,
                or if either has incorrect trailing dimensions
        """
        if((rot_mats is None and quats is None) or
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision (rotations are numerically sensitive under AMP)
        if(quats is not None):
            quats = quats.to(dtype=torch.float32)
        if(rot_mats is not None):
            rot_mats = rot_mats.to(dtype=torch.float32)

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        # Exactly one of these is non-None; it is the stored representation
        self._rot_mats = rot_mats
        self._quats = quats
|
| 315 |
+
|
| 316 |
+
@staticmethod
|
| 317 |
+
def identity(
|
| 318 |
+
shape,
|
| 319 |
+
dtype: Optional[torch.dtype] = None,
|
| 320 |
+
device: Optional[torch.device] = None,
|
| 321 |
+
requires_grad: bool = True,
|
| 322 |
+
fmt: str = "quat",
|
| 323 |
+
) -> Rotation:
|
| 324 |
+
"""
|
| 325 |
+
Returns an identity Rotation.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
shape:
|
| 329 |
+
The "shape" of the resulting Rotation object. See documentation
|
| 330 |
+
for the shape property
|
| 331 |
+
dtype:
|
| 332 |
+
The torch dtype for the rotation
|
| 333 |
+
device:
|
| 334 |
+
The torch device for the new rotation
|
| 335 |
+
requires_grad:
|
| 336 |
+
Whether the underlying tensors in the new rotation object
|
| 337 |
+
should require gradient computation
|
| 338 |
+
fmt:
|
| 339 |
+
One of "quat" or "rot_mat". Determines the underlying format
|
| 340 |
+
of the new object's rotation
|
| 341 |
+
Returns:
|
| 342 |
+
A new identity rotation
|
| 343 |
+
"""
|
| 344 |
+
if(fmt == "rot_mat"):
|
| 345 |
+
rot_mats = identity_rot_mats(
|
| 346 |
+
shape, dtype, device, requires_grad,
|
| 347 |
+
)
|
| 348 |
+
return Rotation(rot_mats=rot_mats, quats=None)
|
| 349 |
+
elif(fmt == "quat"):
|
| 350 |
+
quats = identity_quats(shape, dtype, device, requires_grad)
|
| 351 |
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
| 352 |
+
else:
|
| 353 |
+
raise ValueError(f"Invalid format: f{fmt}")
|
| 354 |
+
|
| 355 |
+
# Magic methods
|
| 356 |
+
|
| 357 |
+
def __getitem__(self, index: Any) -> Rotation:
|
| 358 |
+
"""
|
| 359 |
+
Allows torch-style indexing over the virtual shape of the rotation
|
| 360 |
+
object. See documentation for the shape property.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
index:
|
| 364 |
+
A torch index. E.g. (1, 3, 2), or (slice(None,))
|
| 365 |
+
Returns:
|
| 366 |
+
The indexed rotation
|
| 367 |
+
"""
|
| 368 |
+
if type(index) != tuple:
|
| 369 |
+
index = (index,)
|
| 370 |
+
|
| 371 |
+
if(self._rot_mats is not None):
|
| 372 |
+
rot_mats = self._rot_mats[index + (slice(None), slice(None))]
|
| 373 |
+
return Rotation(rot_mats=rot_mats)
|
| 374 |
+
elif(self._quats is not None):
|
| 375 |
+
quats = self._quats[index + (slice(None),)]
|
| 376 |
+
return Rotation(quats=quats, normalize_quats=False)
|
| 377 |
+
else:
|
| 378 |
+
raise ValueError("Both rotations are None")
|
| 379 |
+
|
| 380 |
+
def __mul__(self,
|
| 381 |
+
right: torch.Tensor,
|
| 382 |
+
) -> Rotation:
|
| 383 |
+
"""
|
| 384 |
+
Pointwise left multiplication of the rotation with a tensor. Can be
|
| 385 |
+
used to e.g. mask the Rotation.
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
right:
|
| 389 |
+
The tensor multiplicand
|
| 390 |
+
Returns:
|
| 391 |
+
The product
|
| 392 |
+
"""
|
| 393 |
+
if not(isinstance(right, torch.Tensor)):
|
| 394 |
+
raise TypeError("The other multiplicand must be a Tensor")
|
| 395 |
+
|
| 396 |
+
if(self._rot_mats is not None):
|
| 397 |
+
rot_mats = self._rot_mats * right[..., None, None]
|
| 398 |
+
return Rotation(rot_mats=rot_mats, quats=None)
|
| 399 |
+
elif(self._quats is not None):
|
| 400 |
+
quats = self._quats * right[..., None]
|
| 401 |
+
return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
|
| 402 |
+
else:
|
| 403 |
+
raise ValueError("Both rotations are None")
|
| 404 |
+
|
| 405 |
+
def __rmul__(self,
|
| 406 |
+
left: torch.Tensor,
|
| 407 |
+
) -> Rotation:
|
| 408 |
+
"""
|
| 409 |
+
Reverse pointwise multiplication of the rotation with a tensor.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
left:
|
| 413 |
+
The left multiplicand
|
| 414 |
+
Returns:
|
| 415 |
+
The product
|
| 416 |
+
"""
|
| 417 |
+
return self.__mul__(left)
|
| 418 |
+
|
| 419 |
+
# Properties
|
| 420 |
+
|
| 421 |
+
@property
|
| 422 |
+
def shape(self) -> torch.Size:
|
| 423 |
+
"""
|
| 424 |
+
Returns the virtual shape of the rotation object. This shape is
|
| 425 |
+
defined as the batch dimensions of the underlying rotation matrix
|
| 426 |
+
or quaternion. If the Rotation was initialized with a [10, 3, 3]
|
| 427 |
+
rotation matrix tensor, for example, the resulting shape would be
|
| 428 |
+
[10].
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
The virtual shape of the rotation object
|
| 432 |
+
"""
|
| 433 |
+
s = None
|
| 434 |
+
if(self._quats is not None):
|
| 435 |
+
s = self._quats.shape[:-1]
|
| 436 |
+
else:
|
| 437 |
+
s = self._rot_mats.shape[:-2]
|
| 438 |
+
|
| 439 |
+
return s
|
| 440 |
+
|
| 441 |
+
@property
|
| 442 |
+
def dtype(self) -> torch.dtype:
|
| 443 |
+
"""
|
| 444 |
+
Returns the dtype of the underlying rotation.
|
| 445 |
+
|
| 446 |
+
Returns:
|
| 447 |
+
The dtype of the underlying rotation
|
| 448 |
+
"""
|
| 449 |
+
if(self._rot_mats is not None):
|
| 450 |
+
return self._rot_mats.dtype
|
| 451 |
+
elif(self._quats is not None):
|
| 452 |
+
return self._quats.dtype
|
| 453 |
+
else:
|
| 454 |
+
raise ValueError("Both rotations are None")
|
| 455 |
+
|
| 456 |
+
@property
def device(self) -> torch.device:
    """
    The device of the underlying rotation

    Returns:
        The device of the underlying rotation
    """
    # Exactly one of the two internal representations (_rot_mats/_quats)
    # is populated; whichever it is determines the device.
    if(self._rot_mats is not None):
        return self._rot_mats.device
    elif(self._quats is not None):
        return self._quats.device
    else:
        raise ValueError("Both rotations are None")

@property
def requires_grad(self) -> bool:
    """
    Returns the requires_grad property of the underlying rotation

    Returns:
        The requires_grad property of the underlying tensor
    """
    if(self._rot_mats is not None):
        return self._rot_mats.requires_grad
    elif(self._quats is not None):
        return self._quats.requires_grad
    else:
        raise ValueError("Both rotations are None")

def get_rot_mats(self) -> torch.Tensor:
    """
    Returns the underlying rotation as a rotation matrix tensor.

    Returns:
        The rotation as a [*, 3, 3] rotation matrix tensor
    """
    rot_mats = self._rot_mats
    if(rot_mats is None):
        if(self._quats is None):
            raise ValueError("Both rotations are None")
        else:
            # Convert lazily from the quaternion representation; the
            # stored representation itself is left untouched.
            rot_mats = quat_to_rot(self._quats)

    return rot_mats

def get_quats(self) -> torch.Tensor:
    """
    Returns the underlying rotation as a quaternion tensor.

    Depending on whether the Rotation was initialized with a
    quaternion, this function may call torch.linalg.eigh.

    Returns:
        The rotation as a [*, 4] quaternion tensor.
    """
    quats = self._quats
    if(quats is None):
        if(self._rot_mats is None):
            raise ValueError("Both rotations are None")
        else:
            # Matrix -> quaternion conversion (the potentially expensive
            # direction mentioned in the docstring above).
            quats = rot_to_quat(self._rot_mats)

    return quats

def get_cur_rot(self) -> torch.Tensor:
    """
    Return the underlying rotation in its current form

    Returns:
        The stored rotation tensor (rotation-matrix or quaternion,
        whichever representation is currently populated)
    """
    if(self._rot_mats is not None):
        return self._rot_mats
    elif(self._quats is not None):
        return self._quats
    else:
        raise ValueError("Both rotations are None")
|
| 534 |
+
|
| 535 |
+
# Rotation functions
|
| 536 |
+
|
| 537 |
+
def compose_q_update_vec(self,
    q_update_vec: torch.Tensor,
    normalize_quats: bool = True
) -> Rotation:
    """
    Returns a new quaternion Rotation after updating the current
    object's underlying rotation with a quaternion update, formatted
    as a [*, 3] tensor whose final three columns represent x, y, z such
    that (1, x, y, z) is the desired (not necessarily unit) quaternion
    update.

    Args:
        q_update_vec:
            A [*, 3] quaternion update tensor
        normalize_quats:
            Whether to normalize the output quaternion
    Returns:
        An updated Rotation
    """
    quats = self.get_quats()
    # q * (1, x, y, z) == q + q * (0, x, y, z); the right-hand product
    # with the pure-vector part is computed by quat_multiply_by_vec.
    new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
    return Rotation(
        rot_mats=None,
        quats=new_quats,
        normalize_quats=normalize_quats,
    )

def compose_r(self, r: Rotation) -> Rotation:
    """
    Compose the rotation matrices of the current Rotation object with
    those of another.

    Args:
        r:
            An update rotation object
    Returns:
        An updated rotation object (always matrix-backed)
    """
    r1 = self.get_rot_mats()
    r2 = r.get_rot_mats()
    new_rot_mats = rot_matmul(r1, r2)
    return Rotation(rot_mats=new_rot_mats, quats=None)

def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
    """
    Compose the quaternions of the current Rotation object with those
    of another.

    Depending on whether either Rotation was initialized with
    quaternions, this function may call torch.linalg.eigh.

    Args:
        r:
            An update rotation object
        normalize_quats:
            Whether to normalize the resulting quaternion
    Returns:
        An updated rotation object (always quaternion-backed)
    """
    q1 = self.get_quats()
    q2 = r.get_quats()
    new_quats = quat_multiply(q1, q2)
    return Rotation(
        rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
    )

def apply(self, pts: torch.Tensor) -> torch.Tensor:
    """
    Apply the current Rotation as a rotation matrix to a set of 3D
    coordinates.

    Args:
        pts:
            A [*, 3] set of points
    Returns:
        [*, 3] rotated points
    """
    rot_mats = self.get_rot_mats()
    return rot_vec_mul(rot_mats, pts)

def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
    """
    The inverse of the apply() method.

    Args:
        pts:
            A [*, 3] set of points
    Returns:
        [*, 3] inverse-rotated points
    """
    rot_mats = self.get_rot_mats()
    inv_rot_mats = invert_rot_mat(rot_mats)
    return rot_vec_mul(inv_rot_mats, pts)

def invert(self) -> Rotation:
    """
    Returns the inverse of the current Rotation.

    Returns:
        The inverse of the current Rotation, in the same representation
        as the original
    """
    if(self._rot_mats is not None):
        return Rotation(
            rot_mats=invert_rot_mat(self._rot_mats),
            quats=None
        )
    elif(self._quats is not None):
        return Rotation(
            rot_mats=None,
            quats=invert_quat(self._quats),
            # NOTE(review): normalization is skipped here — presumably
            # invert_quat preserves the input norm; confirm against its
            # implementation.
            normalize_quats=False,
        )
    else:
        raise ValueError("Both rotations are None")
|
| 649 |
+
|
| 650 |
+
# "Tensor" stuff
|
| 651 |
+
|
| 652 |
+
def unsqueeze(self,
    dim: int,
) -> Rotation:
    """
    Analogous to torch.unsqueeze. The dimension is relative to the
    shape of the Rotation object.

    Args:
        dim: A positive or negative dimension index.
    Returns:
        The unsqueezed Rotation.
    """
    if dim >= len(self.shape):
        raise ValueError("Invalid dimension")

    if(self._rot_mats is not None):
        # Negative indices are shifted past the trailing (3, 3) matrix dims
        rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
        return Rotation(rot_mats=rot_mats, quats=None)
    elif(self._quats is not None):
        # Negative indices are shifted past the trailing quaternion dim
        quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
        return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
    else:
        raise ValueError("Both rotations are None")

@staticmethod
def cat(
    rs: Sequence[Rotation],
    dim: int,
) -> Rotation:
    """
    Concatenates rotations along one of the batch dimensions. Analogous
    to torch.cat().

    Note that the output of this operation is always a rotation matrix,
    regardless of the format of input rotations.

    Args:
        rs:
            A list of rotation objects
        dim:
            The dimension along which the rotations should be
            concatenated
    Returns:
        A concatenated Rotation object in rotation matrix format
    """
    rot_mats = [r.get_rot_mats() for r in rs]
    rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

    return Rotation(rot_mats=rot_mats, quats=None)

def map_tensor_fn(self,
    fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rotation:
    """
    Apply a Tensor -> Tensor function to underlying rotation tensors,
    mapping over the rotation dimension(s). Can be used e.g. to sum out
    a one-hot batch dimension.

    Args:
        fn:
            A Tensor -> Tensor function to be mapped over the Rotation
    Returns:
        The transformed Rotation object
    """
    if(self._rot_mats is not None):
        # Flatten the (3, 3) matrix dims so fn sees only batch dims,
        # apply fn per matrix entry, then restore the matrix shape.
        rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
        rot_mats = torch.stack(
            list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
        )
        rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
        return Rotation(rot_mats=rot_mats, quats=None)
    elif(self._quats is not None):
        quats = torch.stack(
            list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
        )
        return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
    else:
        raise ValueError("Both rotations are None")
|
| 730 |
+
|
| 731 |
+
def cuda(self) -> Rotation:
    """
    Analogous to the cuda() method of torch Tensors

    Returns:
        A copy of the Rotation in CUDA memory
    """
    if(self._rot_mats is not None):
        return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
    elif(self._quats is not None):
        return Rotation(
            rot_mats=None,
            quats=self._quats.cuda(),
            # Moving devices does not change values; skip re-normalization
            normalize_quats=False
        )
    else:
        raise ValueError("Both rotations are None")

def to(self,
    device: Optional[torch.device],
    dtype: Optional[torch.dtype]
) -> Rotation:
    """
    Analogous to the to() method of torch Tensors

    Args:
        device:
            A torch device
        dtype:
            A torch dtype
    Returns:
        A copy of the Rotation using the new device and dtype
    """
    if(self._rot_mats is not None):
        return Rotation(
            rot_mats=self._rot_mats.to(device=device, dtype=dtype),
            quats=None,
        )
    elif(self._quats is not None):
        return Rotation(
            rot_mats=None,
            quats=self._quats.to(device=device, dtype=dtype),
            normalize_quats=False,
        )
    else:
        raise ValueError("Both rotations are None")

def detach(self) -> Rotation:
    """
    Returns a copy of the Rotation whose underlying Tensor has been
    detached from its torch graph.

    Returns:
        A copy of the Rotation whose underlying Tensor has been detached
        from its torch graph
    """
    if(self._rot_mats is not None):
        return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
    elif(self._quats is not None):
        return Rotation(
            rot_mats=None,
            quats=self._quats.detach(),
            normalize_quats=False,
        )
    else:
        raise ValueError("Both rotations are None")
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
class Rigid:
    """
    A class representing a rigid transformation. Little more than a wrapper
    around two objects: a Rotation object and a [*, 3] translation
    Designed to behave approximately like a single torch tensor with the
    shape of the shared batch dimensions of its component parts.
    """
    def __init__(self,
        rots: Optional[Rotation],
        trans: Optional[torch.Tensor],
    ):
        """
        Args:
            rots: A Rotation object of shape [*]
            trans: A corresponding [*, 3] translation tensor
        """
        # (we need device, dtype, etc. from at least one input)

        batch_dims, dtype, device, requires_grad = None, None, None, None
        if(trans is not None):
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif(rots is not None):
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        # A missing component defaults to the identity with matching
        # shape/device/dtype.
        if(rots is None):
            rots = Rotation.identity(
                batch_dims, dtype, device, requires_grad,
            )
        elif(trans is None):
            trans = identity_trans(
                batch_dims, dtype, device, requires_grad,
            )

        if((rots.shape != trans.shape[:-1]) or
           (rots.device != trans.device)):
            raise ValueError("Rots and trans incompatible")

        # Force full precision. Happens to the rotations automatically.
        self._rots = rots
        self._trans = trans = trans.to(dtype=torch.float32) if False else trans.to(dtype=torch.float32)
|
| 849 |
+
|
| 850 |
+
@staticmethod
def identity(
    shape: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
    fmt: str = "quat",
) -> Rigid:
    """
    Constructs an identity transformation.

    Args:
        shape:
            The desired shape
        dtype:
            The dtype of both internal tensors
        device:
            The device of both internal tensors
        requires_grad:
            Whether grad should be enabled for the internal tensors
        fmt:
            Internal representation of the identity rotation
            (forwarded to Rotation.identity)
    Returns:
        The identity transformation
    """
    return Rigid(
        Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
        identity_trans(shape, dtype, device, requires_grad),
    )
|
| 877 |
+
|
| 878 |
+
def __getitem__(self,
    index: Any,
) -> Rigid:
    """
    Indexes the affine transformation with PyTorch-style indices.
    The index is applied to the shared dimensions of both the rotation
    and the translation.

    E.g.::

        r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
        t = Rigid(r, torch.rand(10, 10, 3))
        indexed = t[3, 4:6]
        assert(indexed.shape == (2,))
        assert(indexed.get_rots().shape == (2,))
        assert(indexed.get_trans().shape == (2, 3))

    Args:
        index: A standard torch tensor index. E.g. 8, (10, None, 3),
               or (3, slice(0, 1, None))
    Returns:
        The indexed tensor
    """
    if type(index) != tuple:
        index = (index,)

    return Rigid(
        self._rots[index],
        # Preserve the final xyz dim when indexing the translation
        self._trans[index + (slice(None),)],
    )

def __mul__(self,
    right: torch.Tensor,
) -> Rigid:
    """
    Pointwise left multiplication of the transformation with a tensor.
    Can be used to e.g. mask the Rigid.

    Args:
        right:
            The tensor multiplicand
    Returns:
        The product
    """
    if not(isinstance(right, torch.Tensor)):
        raise TypeError("The other multiplicand must be a Tensor")

    new_rots = self._rots * right
    new_trans = self._trans * right[..., None]

    return Rigid(new_rots, new_trans)

def __rmul__(self,
    left: torch.Tensor,
) -> Rigid:
    """
    Reverse pointwise multiplication of the transformation with a
    tensor.

    Args:
        left:
            The left multiplicand
    Returns:
        The product
    """
    return self.__mul__(left)
|
| 944 |
+
|
| 945 |
+
@property
def shape(self) -> torch.Size:
    """
    Returns the shape of the shared dimensions of the rotation and
    the translation.

    Returns:
        The shape of the transformation
    """
    # The translation always carries the batch dims, so derive from it
    s = self._trans.shape[:-1]
    return s

@property
def device(self) -> torch.device:
    """
    Returns the device on which the Rigid's tensors are located.

    Returns:
        The device on which the Rigid's tensors are located
    """
    return self._trans.device

def get_rots(self) -> Rotation:
    """
    Getter for the rotation.

    Returns:
        The rotation object
    """
    return self._rots

def get_trans(self) -> torch.Tensor:
    """
    Getter for the translation.

    Returns:
        The stored [*, 3] translation
    """
    return self._trans
|
| 984 |
+
|
| 985 |
+
def compose_q_update_vec(self,
    q_update_vec: torch.Tensor,
) -> Rigid:
    """
    Composes the transformation with a quaternion update vector of
    shape [*, 6], where the first 3 columns represent the x, y, and
    z values of a quaternion of form (1, x, y, z) and the final 3
    columns a 3D translation.

    Args:
        q_update_vec: The [*, 6] quaternion update vector.
    Returns:
        The composed transformation.
    """
    q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
    new_rots = self._rots.compose_q_update_vec(q_vec)

    # The translation update is expressed in the local frame, so rotate
    # it into the global frame before adding.
    trans_update = self._rots.apply(t_vec)
    new_translation = self._trans + trans_update

    return Rigid(new_rots, new_translation)

def compose(self,
    r: Rigid,
) -> Rigid:
    """
    Composes the current rigid object with another.

    Args:
        r:
            Another Rigid object
    Returns:
        The composition of the two transformations
    """
    new_rot = self._rots.compose_r(r._rots)
    new_trans = self._rots.apply(r._trans) + self._trans
    return Rigid(new_rot, new_trans)

def apply(self,
    pts: torch.Tensor,
) -> torch.Tensor:
    """
    Applies the transformation to a coordinate tensor.

    Args:
        pts: A [*, 3] coordinate tensor.
    Returns:
        The transformed points.
    """
    # Rotate first, then translate
    rotated = self._rots.apply(pts)
    return rotated + self._trans

def invert_apply(self,
    pts: torch.Tensor
) -> torch.Tensor:
    """
    Applies the inverse of the transformation to a coordinate tensor.

    Args:
        pts: A [*, 3] coordinate tensor
    Returns:
        The transformed points.
    """
    # Undo the translation first, then the rotation
    pts = pts - self._trans
    return self._rots.invert_apply(pts)

def invert(self) -> Rigid:
    """
    Inverts the transformation.

    Returns:
        The inverse transformation.
    """
    # (R, t)^-1 = (R^-1, -R^-1 t)
    rot_inv = self._rots.invert()
    trn_inv = rot_inv.apply(self._trans)

    return Rigid(rot_inv, -1 * trn_inv)
|
| 1062 |
+
|
| 1063 |
+
def map_tensor_fn(self,
    fn: Callable[[torch.Tensor], torch.Tensor]
) -> Rigid:
    """
    Apply a Tensor -> Tensor function to underlying translation and
    rotation tensors, mapping over the translation/rotation dimensions
    respectively.

    Args:
        fn:
            A Tensor -> Tensor function to be mapped over the Rigid
    Returns:
        The transformed Rigid object
    """
    new_rots = self._rots.map_tensor_fn(fn)
    # Apply fn to each of the x/y/z translation components separately
    new_trans = torch.stack(
        list(map(fn, torch.unbind(self._trans, dim=-1))),
        dim=-1
    )

    return Rigid(new_rots, new_trans)
|
| 1084 |
+
|
| 1085 |
+
def to_tensor_4x4(self) -> torch.Tensor:
    """
    Converts a transformation to a homogenous transformation tensor.

    Returns:
        A [*, 4, 4] homogenous transformation tensor
    """
    tensor = self._trans.new_zeros((*self.shape, 4, 4))
    tensor[..., :3, :3] = self._rots.get_rot_mats()
    tensor[..., :3, 3] = self._trans
    tensor[..., 3, 3] = 1
    return tensor

@staticmethod
def from_tensor_4x4(
    t: torch.Tensor
) -> Rigid:
    """
    Constructs a transformation from a homogenous transformation
    tensor.

    Args:
        t: [*, 4, 4] homogenous transformation tensor
    Returns:
        T object with shape [*]
    """
    if(t.shape[-2:] != (4, 4)):
        raise ValueError("Incorrectly shaped input tensor")

    rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
    trans = t[..., :3, 3]

    return Rigid(rots, trans)

def to_tensor_7(self) -> torch.Tensor:
    """
    Converts a transformation to a tensor with 7 final columns, four
    for the quaternion followed by three for the translation.

    Returns:
        A [*, 7] tensor representation of the transformation
    """
    tensor = self._trans.new_zeros((*self.shape, 7))
    tensor[..., :4] = self._rots.get_quats()
    tensor[..., 4:] = self._trans

    return tensor

@staticmethod
def from_tensor_7(
    t: torch.Tensor,
    normalize_quats: bool = False,
) -> Rigid:
    """
    Constructs a transformation from a [*, 7] tensor whose final
    columns are a quaternion (4) followed by a translation (3) —
    the inverse of to_tensor_7().

    Args:
        t: A [*, 7] tensor
        normalize_quats: Whether to normalize the quaternion part
    Returns:
        A transformation object of shape [*]
    """
    if(t.shape[-1] != 7):
        raise ValueError("Incorrectly shaped input tensor")

    quats, trans = t[..., :4], t[..., 4:]

    rots = Rotation(
        rot_mats=None,
        quats=quats,
        normalize_quats=normalize_quats
    )

    return Rigid(rots, trans)
|
| 1150 |
+
|
| 1151 |
+
@staticmethod
def from_3_points(
    p_neg_x_axis: torch.Tensor,
    origin: torch.Tensor,
    p_xy_plane: torch.Tensor,
    eps: float = 1e-8
) -> Rigid:
    """
    Implements algorithm 21. Constructs transformations from sets of 3
    points using the Gram-Schmidt algorithm.

    Args:
        p_neg_x_axis: [*, 3] coordinates
        origin: [*, 3] coordinates used as frame origins
        p_xy_plane: [*, 3] coordinates
        eps: Small epsilon value guarding the normalizations
    Returns:
        A transformation object of shape [*]
    """
    # Work on per-component lists of [*] tensors
    p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
    origin = torch.unbind(origin, dim=-1)
    p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

    e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
    e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

    # Normalize e0, then orthogonalize e1 against it (Gram-Schmidt)
    denom = torch.sqrt(sum((c * c for c in e0)) + eps)
    e0 = [c / denom for c in e0]
    dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
    e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
    denom = torch.sqrt(sum((c * c for c in e1)) + eps)
    e1 = [c / denom for c in e1]
    # e2 = e0 x e1 completes the right-handed basis
    e2 = [
        e0[1] * e1[2] - e0[2] * e1[1],
        e0[2] * e1[0] - e0[0] * e1[2],
        e0[0] * e1[1] - e0[1] * e1[0],
    ]

    # Columns of the rotation matrix are e0, e1, e2
    rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
    rots = rots.reshape(rots.shape[:-1] + (3, 3))

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, torch.stack(origin, dim=-1))
|
| 1195 |
+
|
| 1196 |
+
def unsqueeze(self,
    dim: int,
) -> Rigid:
    """
    Analogous to torch.unsqueeze. The dimension is relative to the
    shared dimensions of the rotation/translation.

    Args:
        dim: A positive or negative dimension index.
    Returns:
        The unsqueezed transformation.
    """
    if dim >= len(self.shape):
        raise ValueError("Invalid dimension")
    rots = self._rots.unsqueeze(dim)
    # Negative indices must skip the trailing xyz dim of the translation
    trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

    return Rigid(rots, trans)

@staticmethod
def cat(
    ts: Sequence[Rigid],
    dim: int,
) -> Rigid:
    """
    Concatenates transformations along one of the shared batch
    dimensions (analogous to torch.cat; no new dimension is created).

    Args:
        ts:
            A list of T objects
        dim:
            The dimension along which the transformations should be
            concatenated
    Returns:
        A concatenated transformation object
    """
    rots = Rotation.cat([t._rots for t in ts], dim)
    trans = torch.cat(
        [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
    )

    return Rigid(rots, trans)
|
| 1238 |
+
|
| 1239 |
+
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
    """
    Applies a Rotation -> Rotation function to the stored rotation
    object.

    Args:
        fn: A function of type Rotation -> Rotation
    Returns:
        A transformation object with a transformed rotation.
    """
    return Rigid(fn(self._rots), self._trans)

def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
    """
    Applies a Tensor -> Tensor function to the stored translation.

    Args:
        fn:
            A function of type Tensor -> Tensor to be applied to the
            translation
    Returns:
        A transformation object with a transformed translation.
    """
    return Rigid(self._rots, fn(self._trans))

def scale_translation(self, trans_scale_factor: float) -> Rigid:
    """
    Scales the translation by a constant factor.

    Args:
        trans_scale_factor:
            The constant factor
    Returns:
        A transformation object with a scaled translation.
    """
    fn = lambda t: t * trans_scale_factor
    return self.apply_trans_fn(fn)

def stop_rot_gradient(self) -> Rigid:
    """
    Detaches the underlying rotation object

    Returns:
        A transformation object with detached rotations (the
        translation keeps its gradient)
    """
    fn = lambda r: r.detach()
    return self.apply_rot_fn(fn)
|
| 1286 |
+
|
| 1287 |
+
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
    """
    Returns a transformation object from reference coordinates.

    Note that this method does not take care of symmetries. If you
    provide the atom positions in the non-standard way, the N atom will
    end up not at [-0.527250, 1.359329, 0.0] but instead at
    [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
    your code.

    Args:
        n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
        ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
        c_xyz: A [*, 3] tensor of carbon xyz coordinates.
        eps: Small value guarding the normalizations below.
    Returns:
        A transformation object. After applying the translation and
        rotation to the reference backbone, the coordinates will
        approximately equal to the input coordinates.
    """
    # Move CA to the origin
    translation = -1 * ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # Rotation about the z-axis that zeroes the y-component of C
    c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
    sin_c1 = -c_y / norm
    cos_c1 = c_x / norm
    zeros = sin_c1.new_zeros(sin_c1.shape)
    ones = sin_c1.new_ones(sin_c1.shape)

    c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
    c1_rots[..., 0, 0] = cos_c1
    c1_rots[..., 0, 1] = -1 * sin_c1
    c1_rots[..., 1, 0] = sin_c1
    c1_rots[..., 1, 1] = cos_c1
    c1_rots[..., 2, 2] = 1

    # Rotation about the y-axis that zeroes the z-component of C,
    # placing C on the positive x-axis
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
    sin_c2 = c_z / norm
    cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

    c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    c2_rots[..., 0, 0] = cos_c2
    c2_rots[..., 0, 2] = sin_c2
    c2_rots[..., 1, 1] = 1
    c2_rots[..., 2, 0] = -1 * sin_c2
    c2_rots[..., 2, 2] = cos_c2

    c_rots = rot_matmul(c2_rots, c1_rots)
    n_xyz = rot_vec_mul(c_rots, n_xyz)

    # Rotation about the x-axis that zeroes the z-component of N,
    # placing N in the xy-plane
    _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
    sin_n = -n_z / norm
    cos_n = n_y / norm

    n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    n_rots[..., 0, 0] = 1
    n_rots[..., 1, 1] = cos_n
    n_rots[..., 1, 2] = -1 * sin_n
    n_rots[..., 2, 1] = sin_n
    n_rots[..., 2, 2] = cos_n

    rots = rot_matmul(n_rots, c_rots)

    # Invert: the returned transform maps the reference frame onto the
    # input coordinates (transpose == inverse for rotation matrices)
    rots = rots.transpose(-1, -2)
    translation = -1 * translation

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, translation)
|
| 1359 |
+
|
| 1360 |
+
def cuda(self) -> Rigid:
    """
    Moves the transformation object to GPU memory

    Returns:
        A version of the transformation on GPU
    """
    return Rigid(self._rots.cuda(), self._trans.cuda())
|
openfold/utils/tensor_utils.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from functools import partial
|
| 17 |
+
import logging
|
| 18 |
+
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def add(m1, m2, inplace):
    """Add ``m2`` to ``m1``, optionally in place.

    The first operation in a checkpointed segment can't be in-place, but
    in-place accumulation is desirable during inference to save memory.

    Args:
        m1: Left-hand tensor. Mutated when ``inplace`` is True.
        m2: Right-hand tensor (broadcastable against ``m1``).
        inplace: If True, accumulate into ``m1``'s storage via ``+=``.

    Returns:
        The sum; this is ``m1`` itself when ``inplace`` is True.
    """
    if not inplace:
        m1 = m1 + m2
    else:
        m1 += m2

    return m1
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    ``inds`` are positions relative to the trailing block, e.g.
    ``(1, 0)`` swaps the last two dims; leading (batch) dims are kept.
    """
    n_moved = len(inds)
    n_leading = tensor.dim() - n_moved
    leading = list(range(n_leading))
    trailing = [n_leading + i for i in inds]
    return tensor.permute(leading + trailing)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last ``no_dims`` dimensions of ``t`` into one."""
    leading_shape = t.shape[:-no_dims]
    return t.reshape(*leading_shape, -1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of ``value`` along ``dim``, counting only masked entries.

    ``eps`` guards the denominator against an all-zero mask along ``dim``.
    """
    broadcast_mask = mask.expand(*value.shape)
    weighted_sum = (broadcast_mask * value).sum(dim=dim)
    denom = broadcast_mask.sum(dim=dim) + eps
    return weighted_sum / denom
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bin all pairwise point distances into distogram bucket indices.

    Distances below ``min_bin`` map to bucket 0 and distances above
    ``max_bin`` to bucket ``no_bins - 1``.
    """
    bin_edges = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    diffs = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    pairwise_dists = diffs.pow(2).sum(dim=-1).sqrt()
    return torch.bucketize(pairwise_dists, bin_edges)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def dict_multimap(fn, dicts):
    """Apply ``fn`` across corresponding leaves of a list of nested dicts.

    The first dict's structure decides which keys are visited; for each
    leaf key, ``fn`` receives the list of values gathered from every dict.
    """
    template = dicts[0]
    out = {}
    for key, sample_value in template.items():
        gathered = [d[key] for d in dicts]
        if type(sample_value) is dict:
            out[key] = dict_multimap(fn, gathered)
        else:
            out[key] = fn(gathered)

    return out
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def one_hot(x, v_bins):
    """One-hot encode each value of ``x`` against its nearest bin.

    For every element of ``x`` the closest entry of ``v_bins`` (by
    absolute difference) is selected, and a float one-hot vector over
    the bins is returned as a trailing dimension.
    """
    n_bins = len(v_bins)
    bins = v_bins.view((1,) * x.dim() + (n_bins,))
    nearest = (x.unsqueeze(-1) - bins).abs().argmin(dim=-1)
    return nn.functional.one_hot(nearest, num_classes=n_bins).float()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    # Gather along `dim` of `data` using `inds`, broadcasting over the
    # leading `no_batch_dims` batch dimensions (similar to torch.gather,
    # but implemented with advanced indexing over explicit batch ranges).
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        # Shape the batch-index vector so it broadcasts against `inds`:
        # -1 at position i, singleton dims everywhere else.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    # Full slices for the non-batch dims of `data`...
    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    # ...except the target dim (batch-relative when `dim` is non-negative),
    # which is replaced by the gather indices.
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    return data[ranges]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# With tree_map, a poor man's JAX tree_map
|
| 96 |
+
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """Recursively apply ``fn`` to every ``leaf_type`` leaf of a nested dict."""
    out = {}
    for key, val in dic.items():
        if type(val) is dict:
            out[key] = dict_map(fn, val, leaf_type)
        else:
            out[key] = tree_map(fn, val, leaf_type)

    return out
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def tree_map(fn, tree, leaf_type):
    """Apply ``fn`` to every leaf of type ``leaf_type`` in a nested container.

    Supports dicts (via ``dict_map``), lists and tuples; any other
    non-leaf type raises.

    Raises:
        ValueError: If ``tree`` is neither a supported container nor a leaf.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Fold the offending type into the error rather than printing it.
        raise ValueError(f"Not supported: {type(tree)}")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# Convenience specialization of tree_map that treats torch.Tensors as leaves.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
requirements.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
biopython==1.79
|
| 2 |
+
filelock==3.13.1
|
| 3 |
+
fsspec==2024.3.1
|
| 4 |
+
Jinja2==3.1.3
|
| 5 |
+
MarkupSafe==2.1.5
|
| 6 |
+
mpmath==1.3.0
|
| 7 |
+
networkx==3.2.1
|
| 8 |
+
numpy==1.23.5
|
| 9 |
+
nvidia-cublas-cu12==12.1.3.1
|
| 10 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
| 11 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
| 12 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
| 13 |
+
nvidia-cudnn-cu12==8.9.2.26
|
| 14 |
+
nvidia-cufft-cu12==11.0.2.54
|
| 15 |
+
nvidia-curand-cu12==10.3.2.106
|
| 16 |
+
nvidia-cusolver-cu12==11.4.5.107
|
| 17 |
+
nvidia-cusparse-cu12==12.1.0.106
|
| 18 |
+
nvidia-nccl-cu12==2.19.3
|
| 19 |
+
nvidia-nvjitlink-cu12==12.4.99
|
| 20 |
+
nvidia-nvtx-cu12==12.1.105
|
| 21 |
+
ProDy==2.4.1
|
| 22 |
+
pyparsing==3.1.1
|
| 23 |
+
scipy==1.12.0
|
| 24 |
+
sympy==1.12
|
| 25 |
+
torch==2.2.1
|
| 26 |
+
triton==2.2.0
|
| 27 |
+
typing_extensions==4.10.0
|
| 28 |
+
ml-collections==0.1.1
|
| 29 |
+
dm-tree==0.1.8
|
run.py
ADDED
|
@@ -0,0 +1,990 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import os.path
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from data_utils import (
|
| 11 |
+
alphabet,
|
| 12 |
+
element_dict_rev,
|
| 13 |
+
featurize,
|
| 14 |
+
get_score,
|
| 15 |
+
get_seq_rec,
|
| 16 |
+
parse_PDB,
|
| 17 |
+
restype_1to3,
|
| 18 |
+
restype_int_to_str,
|
| 19 |
+
restype_str_to_int,
|
| 20 |
+
write_full_PDB,
|
| 21 |
+
)
|
| 22 |
+
from model_utils import ProteinMPNN
|
| 23 |
+
from prody import writePDB
|
| 24 |
+
from sc_utils import Packer, pack_side_chains
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main(args) -> None:
|
| 28 |
+
"""
|
| 29 |
+
Inference function
|
| 30 |
+
"""
|
| 31 |
+
if args.seed:
|
| 32 |
+
seed = args.seed
|
| 33 |
+
else:
|
| 34 |
+
seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
|
| 35 |
+
torch.manual_seed(seed)
|
| 36 |
+
random.seed(seed)
|
| 37 |
+
np.random.seed(seed)
|
| 38 |
+
device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
|
| 39 |
+
folder_for_outputs = args.out_folder
|
| 40 |
+
base_folder = folder_for_outputs
|
| 41 |
+
if base_folder[-1] != "/":
|
| 42 |
+
base_folder = base_folder + "/"
|
| 43 |
+
if not os.path.exists(base_folder):
|
| 44 |
+
os.makedirs(base_folder, exist_ok=True)
|
| 45 |
+
if not os.path.exists(base_folder + "seqs"):
|
| 46 |
+
os.makedirs(base_folder + "seqs", exist_ok=True)
|
| 47 |
+
if not os.path.exists(base_folder + "backbones"):
|
| 48 |
+
os.makedirs(base_folder + "backbones", exist_ok=True)
|
| 49 |
+
if not os.path.exists(base_folder + "packed"):
|
| 50 |
+
os.makedirs(base_folder + "packed", exist_ok=True)
|
| 51 |
+
if args.save_stats:
|
| 52 |
+
if not os.path.exists(base_folder + "stats"):
|
| 53 |
+
os.makedirs(base_folder + "stats", exist_ok=True)
|
| 54 |
+
if args.model_type == "protein_mpnn":
|
| 55 |
+
checkpoint_path = args.checkpoint_protein_mpnn
|
| 56 |
+
elif args.model_type == "ligand_mpnn":
|
| 57 |
+
checkpoint_path = args.checkpoint_ligand_mpnn
|
| 58 |
+
elif args.model_type == "per_residue_label_membrane_mpnn":
|
| 59 |
+
checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
|
| 60 |
+
elif args.model_type == "global_label_membrane_mpnn":
|
| 61 |
+
checkpoint_path = args.checkpoint_global_label_membrane_mpnn
|
| 62 |
+
elif args.model_type == "soluble_mpnn":
|
| 63 |
+
checkpoint_path = args.checkpoint_soluble_mpnn
|
| 64 |
+
else:
|
| 65 |
+
print("Choose one of the available models")
|
| 66 |
+
sys.exit()
|
| 67 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
| 68 |
+
if args.model_type == "ligand_mpnn":
|
| 69 |
+
atom_context_num = checkpoint["atom_context_num"]
|
| 70 |
+
ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
|
| 71 |
+
k_neighbors = checkpoint["num_edges"]
|
| 72 |
+
else:
|
| 73 |
+
atom_context_num = 1
|
| 74 |
+
ligand_mpnn_use_side_chain_context = 0
|
| 75 |
+
k_neighbors = checkpoint["num_edges"]
|
| 76 |
+
|
| 77 |
+
model = ProteinMPNN(
|
| 78 |
+
node_features=128,
|
| 79 |
+
edge_features=128,
|
| 80 |
+
hidden_dim=128,
|
| 81 |
+
num_encoder_layers=3,
|
| 82 |
+
num_decoder_layers=3,
|
| 83 |
+
k_neighbors=k_neighbors,
|
| 84 |
+
device=device,
|
| 85 |
+
atom_context_num=atom_context_num,
|
| 86 |
+
model_type=args.model_type,
|
| 87 |
+
ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 91 |
+
model.to(device)
|
| 92 |
+
model.eval()
|
| 93 |
+
|
| 94 |
+
if args.pack_side_chains:
|
| 95 |
+
model_sc = Packer(
|
| 96 |
+
node_features=128,
|
| 97 |
+
edge_features=128,
|
| 98 |
+
num_positional_embeddings=16,
|
| 99 |
+
num_chain_embeddings=16,
|
| 100 |
+
num_rbf=16,
|
| 101 |
+
hidden_dim=128,
|
| 102 |
+
num_encoder_layers=3,
|
| 103 |
+
num_decoder_layers=3,
|
| 104 |
+
atom_context_num=16,
|
| 105 |
+
lower_bound=0.0,
|
| 106 |
+
upper_bound=20.0,
|
| 107 |
+
top_k=32,
|
| 108 |
+
dropout=0.0,
|
| 109 |
+
augment_eps=0.0,
|
| 110 |
+
atom37_order=False,
|
| 111 |
+
device=device,
|
| 112 |
+
num_mix=3,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device)
|
| 116 |
+
model_sc.load_state_dict(checkpoint_sc["model_state_dict"])
|
| 117 |
+
model_sc.to(device)
|
| 118 |
+
model_sc.eval()
|
| 119 |
+
|
| 120 |
+
if args.pdb_path_multi:
|
| 121 |
+
with open(args.pdb_path_multi, "r") as fh:
|
| 122 |
+
pdb_paths = list(json.load(fh))
|
| 123 |
+
else:
|
| 124 |
+
pdb_paths = [args.pdb_path]
|
| 125 |
+
|
| 126 |
+
if args.fixed_residues_multi:
|
| 127 |
+
with open(args.fixed_residues_multi, "r") as fh:
|
| 128 |
+
fixed_residues_multi = json.load(fh)
|
| 129 |
+
fixed_residues_multi = {key:value.split() for key,value in fixed_residues_multi.items()}
|
| 130 |
+
else:
|
| 131 |
+
fixed_residues = [item for item in args.fixed_residues.split()]
|
| 132 |
+
fixed_residues_multi = {}
|
| 133 |
+
for pdb in pdb_paths:
|
| 134 |
+
fixed_residues_multi[pdb] = fixed_residues
|
| 135 |
+
|
| 136 |
+
if args.redesigned_residues_multi:
|
| 137 |
+
with open(args.redesigned_residues_multi, "r") as fh:
|
| 138 |
+
redesigned_residues_multi = json.load(fh)
|
| 139 |
+
redesigned_residues_multi = {key:value.split() for key,value in redesigned_residues_multi.items()}
|
| 140 |
+
else:
|
| 141 |
+
redesigned_residues = [item for item in args.redesigned_residues.split()]
|
| 142 |
+
redesigned_residues_multi = {}
|
| 143 |
+
for pdb in pdb_paths:
|
| 144 |
+
redesigned_residues_multi[pdb] = redesigned_residues
|
| 145 |
+
|
| 146 |
+
bias_AA = torch.zeros([21], device=device, dtype=torch.float32)
|
| 147 |
+
if args.bias_AA:
|
| 148 |
+
tmp = [item.split(":") for item in args.bias_AA.split(",")]
|
| 149 |
+
a1 = [b[0] for b in tmp]
|
| 150 |
+
a2 = [float(b[1]) for b in tmp]
|
| 151 |
+
for i, AA in enumerate(a1):
|
| 152 |
+
bias_AA[restype_str_to_int[AA]] = a2[i]
|
| 153 |
+
|
| 154 |
+
if args.bias_AA_per_residue_multi:
|
| 155 |
+
with open(args.bias_AA_per_residue_multi, "r") as fh:
|
| 156 |
+
bias_AA_per_residue_multi = json.load(
|
| 157 |
+
fh
|
| 158 |
+
) # {"pdb_path" : {"A12": {"G": 1.1}}}
|
| 159 |
+
else:
|
| 160 |
+
if args.bias_AA_per_residue:
|
| 161 |
+
with open(args.bias_AA_per_residue, "r") as fh:
|
| 162 |
+
bias_AA_per_residue = json.load(fh) # {"A12": {"G": 1.1}}
|
| 163 |
+
bias_AA_per_residue_multi = {}
|
| 164 |
+
for pdb in pdb_paths:
|
| 165 |
+
bias_AA_per_residue_multi[pdb] = bias_AA_per_residue
|
| 166 |
+
|
| 167 |
+
if args.omit_AA_per_residue_multi:
|
| 168 |
+
with open(args.omit_AA_per_residue_multi, "r") as fh:
|
| 169 |
+
omit_AA_per_residue_multi = json.load(
|
| 170 |
+
fh
|
| 171 |
+
) # {"pdb_path" : {"A12": "PQR", "A13": "QS"}}
|
| 172 |
+
else:
|
| 173 |
+
if args.omit_AA_per_residue:
|
| 174 |
+
with open(args.omit_AA_per_residue, "r") as fh:
|
| 175 |
+
omit_AA_per_residue = json.load(fh) # {"A12": "PG"}
|
| 176 |
+
omit_AA_per_residue_multi = {}
|
| 177 |
+
for pdb in pdb_paths:
|
| 178 |
+
omit_AA_per_residue_multi[pdb] = omit_AA_per_residue
|
| 179 |
+
omit_AA_list = args.omit_AA
|
| 180 |
+
omit_AA = torch.tensor(
|
| 181 |
+
np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32),
|
| 182 |
+
device=device,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
if len(args.parse_these_chains_only) != 0:
|
| 186 |
+
parse_these_chains_only_list = args.parse_these_chains_only.split(",")
|
| 187 |
+
else:
|
| 188 |
+
parse_these_chains_only_list = []
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# loop over PDB paths
|
| 192 |
+
for pdb in pdb_paths:
|
| 193 |
+
if args.verbose:
|
| 194 |
+
print("Designing protein from this path:", pdb)
|
| 195 |
+
fixed_residues = fixed_residues_multi[pdb]
|
| 196 |
+
redesigned_residues = redesigned_residues_multi[pdb]
|
| 197 |
+
parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or (
|
| 198 |
+
args.pack_side_chains and not args.repack_everything
|
| 199 |
+
)
|
| 200 |
+
protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
|
| 201 |
+
pdb,
|
| 202 |
+
device=device,
|
| 203 |
+
chains=parse_these_chains_only_list,
|
| 204 |
+
parse_all_atoms=parse_all_atoms_flag,
|
| 205 |
+
parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
|
| 206 |
+
)
|
| 207 |
+
# make chain_letter + residue_idx + insertion_code mapping to integers
|
| 208 |
+
R_idx_list = list(protein_dict["R_idx"].cpu().numpy()) # residue indices
|
| 209 |
+
chain_letters_list = list(protein_dict["chain_letters"]) # chain letters
|
| 210 |
+
encoded_residues = []
|
| 211 |
+
for i, R_idx_item in enumerate(R_idx_list):
|
| 212 |
+
tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
|
| 213 |
+
encoded_residues.append(tmp)
|
| 214 |
+
encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
|
| 215 |
+
encoded_residue_dict_rev = dict(
|
| 216 |
+
zip(list(range(len(encoded_residues))), encoded_residues)
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
bias_AA_per_residue = torch.zeros(
|
| 220 |
+
[len(encoded_residues), 21], device=device, dtype=torch.float32
|
| 221 |
+
)
|
| 222 |
+
if args.bias_AA_per_residue_multi or args.bias_AA_per_residue:
|
| 223 |
+
bias_dict = bias_AA_per_residue_multi[pdb]
|
| 224 |
+
for residue_name, v1 in bias_dict.items():
|
| 225 |
+
if residue_name in encoded_residues:
|
| 226 |
+
i1 = encoded_residue_dict[residue_name]
|
| 227 |
+
for amino_acid, v2 in v1.items():
|
| 228 |
+
if amino_acid in alphabet:
|
| 229 |
+
j1 = restype_str_to_int[amino_acid]
|
| 230 |
+
bias_AA_per_residue[i1, j1] = v2
|
| 231 |
+
|
| 232 |
+
omit_AA_per_residue = torch.zeros(
|
| 233 |
+
[len(encoded_residues), 21], device=device, dtype=torch.float32
|
| 234 |
+
)
|
| 235 |
+
if args.omit_AA_per_residue_multi or args.omit_AA_per_residue:
|
| 236 |
+
omit_dict = omit_AA_per_residue_multi[pdb]
|
| 237 |
+
for residue_name, v1 in omit_dict.items():
|
| 238 |
+
if residue_name in encoded_residues:
|
| 239 |
+
i1 = encoded_residue_dict[residue_name]
|
| 240 |
+
for amino_acid in v1:
|
| 241 |
+
if amino_acid in alphabet:
|
| 242 |
+
j1 = restype_str_to_int[amino_acid]
|
| 243 |
+
omit_AA_per_residue[i1, j1] = 1.0
|
| 244 |
+
|
| 245 |
+
fixed_positions = torch.tensor(
|
| 246 |
+
[int(item not in fixed_residues) for item in encoded_residues],
|
| 247 |
+
device=device,
|
| 248 |
+
)
|
| 249 |
+
redesigned_positions = torch.tensor(
|
| 250 |
+
[int(item not in redesigned_residues) for item in encoded_residues],
|
| 251 |
+
device=device,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
# specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
|
| 255 |
+
if args.transmembrane_buried:
|
| 256 |
+
buried_residues = [item for item in args.transmembrane_buried.split()]
|
| 257 |
+
buried_positions = torch.tensor(
|
| 258 |
+
[int(item in buried_residues) for item in encoded_residues],
|
| 259 |
+
device=device,
|
| 260 |
+
)
|
| 261 |
+
else:
|
| 262 |
+
buried_positions = torch.zeros_like(fixed_positions)
|
| 263 |
+
|
| 264 |
+
if args.transmembrane_interface:
|
| 265 |
+
interface_residues = [item for item in args.transmembrane_interface.split()]
|
| 266 |
+
interface_positions = torch.tensor(
|
| 267 |
+
[int(item in interface_residues) for item in encoded_residues],
|
| 268 |
+
device=device,
|
| 269 |
+
)
|
| 270 |
+
else:
|
| 271 |
+
interface_positions = torch.zeros_like(fixed_positions)
|
| 272 |
+
protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
|
| 273 |
+
1 - interface_positions
|
| 274 |
+
) + 1 * interface_positions * (1 - buried_positions)
|
| 275 |
+
|
| 276 |
+
if args.model_type == "global_label_membrane_mpnn":
|
| 277 |
+
protein_dict["membrane_per_residue_labels"] = (
|
| 278 |
+
args.global_transmembrane_label + 0 * fixed_positions
|
| 279 |
+
)
|
| 280 |
+
if len(args.chains_to_design) != 0:
|
| 281 |
+
chains_to_design_list = args.chains_to_design.split(",")
|
| 282 |
+
else:
|
| 283 |
+
chains_to_design_list = protein_dict["chain_letters"]
|
| 284 |
+
|
| 285 |
+
chain_mask = torch.tensor(
|
| 286 |
+
np.array(
|
| 287 |
+
[
|
| 288 |
+
item in chains_to_design_list
|
| 289 |
+
for item in protein_dict["chain_letters"]
|
| 290 |
+
],
|
| 291 |
+
dtype=np.int32,
|
| 292 |
+
),
|
| 293 |
+
device=device,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
|
| 297 |
+
if redesigned_residues:
|
| 298 |
+
protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
|
| 299 |
+
elif fixed_residues:
|
| 300 |
+
protein_dict["chain_mask"] = chain_mask * fixed_positions
|
| 301 |
+
else:
|
| 302 |
+
protein_dict["chain_mask"] = chain_mask
|
| 303 |
+
|
| 304 |
+
if args.verbose:
|
| 305 |
+
PDB_residues_to_be_redesigned = [
|
| 306 |
+
encoded_residue_dict_rev[item]
|
| 307 |
+
for item in range(protein_dict["chain_mask"].shape[0])
|
| 308 |
+
if protein_dict["chain_mask"][item] == 1
|
| 309 |
+
]
|
| 310 |
+
PDB_residues_to_be_fixed = [
|
| 311 |
+
encoded_residue_dict_rev[item]
|
| 312 |
+
for item in range(protein_dict["chain_mask"].shape[0])
|
| 313 |
+
if protein_dict["chain_mask"][item] == 0
|
| 314 |
+
]
|
| 315 |
+
print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
|
| 316 |
+
print("These residues will be fixed: ", PDB_residues_to_be_fixed)
|
| 317 |
+
|
| 318 |
+
# specify which residues are linked
|
| 319 |
+
if args.symmetry_residues:
|
| 320 |
+
symmetry_residues_list_of_lists = [
|
| 321 |
+
x.split(",") for x in args.symmetry_residues.split("|")
|
| 322 |
+
]
|
| 323 |
+
remapped_symmetry_residues = []
|
| 324 |
+
for t_list in symmetry_residues_list_of_lists:
|
| 325 |
+
tmp_list = []
|
| 326 |
+
for t in t_list:
|
| 327 |
+
tmp_list.append(encoded_residue_dict[t])
|
| 328 |
+
remapped_symmetry_residues.append(tmp_list)
|
| 329 |
+
else:
|
| 330 |
+
remapped_symmetry_residues = [[]]
|
| 331 |
+
|
| 332 |
+
# specify linking weights
|
| 333 |
+
if args.symmetry_weights:
|
| 334 |
+
symmetry_weights = [
|
| 335 |
+
[float(item) for item in x.split(",")]
|
| 336 |
+
for x in args.symmetry_weights.split("|")
|
| 337 |
+
]
|
| 338 |
+
else:
|
| 339 |
+
symmetry_weights = [[]]
|
| 340 |
+
|
| 341 |
+
if args.homo_oligomer:
|
| 342 |
+
if args.verbose:
|
| 343 |
+
print("Designing HOMO-OLIGOMER")
|
| 344 |
+
chain_letters_set = list(set(chain_letters_list))
|
| 345 |
+
reference_chain = chain_letters_set[0]
|
| 346 |
+
lc = len(reference_chain)
|
| 347 |
+
residue_indices = [
|
| 348 |
+
item[lc:] for item in encoded_residues if item[:lc] == reference_chain
|
| 349 |
+
]
|
| 350 |
+
remapped_symmetry_residues = []
|
| 351 |
+
symmetry_weights = []
|
| 352 |
+
for res in residue_indices:
|
| 353 |
+
tmp_list = []
|
| 354 |
+
tmp_w_list = []
|
| 355 |
+
for chain in chain_letters_set:
|
| 356 |
+
name = chain + res
|
| 357 |
+
tmp_list.append(encoded_residue_dict[name])
|
| 358 |
+
tmp_w_list.append(1 / len(chain_letters_set))
|
| 359 |
+
remapped_symmetry_residues.append(tmp_list)
|
| 360 |
+
symmetry_weights.append(tmp_w_list)
|
| 361 |
+
|
| 362 |
+
# set other atom bfactors to 0.0
|
| 363 |
+
if other_atoms:
|
| 364 |
+
other_bfactors = other_atoms.getBetas()
|
| 365 |
+
other_atoms.setBetas(other_bfactors * 0.0)
|
| 366 |
+
|
| 367 |
+
# adjust input PDB name by dropping .pdb if it does exist
|
| 368 |
+
name = pdb[pdb.rfind("/") + 1 :]
|
| 369 |
+
if name[-4:] == ".pdb":
|
| 370 |
+
name = name[:-4]
|
| 371 |
+
|
| 372 |
+
with torch.no_grad():
|
| 373 |
+
# run featurize to remap R_idx and add batch dimension
|
| 374 |
+
if args.verbose:
|
| 375 |
+
if "Y" in list(protein_dict):
|
| 376 |
+
atom_coords = protein_dict["Y"].cpu().numpy()
|
| 377 |
+
atom_types = list(protein_dict["Y_t"].cpu().numpy())
|
| 378 |
+
atom_mask = list(protein_dict["Y_m"].cpu().numpy())
|
| 379 |
+
number_of_atoms_parsed = np.sum(atom_mask)
|
| 380 |
+
else:
|
| 381 |
+
print("No ligand atoms parsed")
|
| 382 |
+
number_of_atoms_parsed = 0
|
| 383 |
+
atom_types = ""
|
| 384 |
+
atom_coords = []
|
| 385 |
+
if number_of_atoms_parsed == 0:
|
| 386 |
+
print("No ligand atoms parsed")
|
| 387 |
+
elif args.model_type == "ligand_mpnn":
|
| 388 |
+
print(
|
| 389 |
+
f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
|
| 390 |
+
)
|
| 391 |
+
for i, atom_type in enumerate(atom_types):
|
| 392 |
+
print(
|
| 393 |
+
f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
|
| 394 |
+
)
|
| 395 |
+
feature_dict = featurize(
|
| 396 |
+
protein_dict,
|
| 397 |
+
cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
|
| 398 |
+
use_atom_context=args.ligand_mpnn_use_atom_context,
|
| 399 |
+
number_of_ligand_atoms=atom_context_num,
|
| 400 |
+
model_type=args.model_type,
|
| 401 |
+
)
|
| 402 |
+
feature_dict["batch_size"] = args.batch_size
|
| 403 |
+
B, L, _, _ = feature_dict["X"].shape # batch size should be 1 for now.
|
| 404 |
+
# add additional keys to the feature dictionary
|
| 405 |
+
feature_dict["temperature"] = args.temperature
|
| 406 |
+
feature_dict["bias"] = (
|
| 407 |
+
(-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1])
|
| 408 |
+
+ bias_AA_per_residue[None]
|
| 409 |
+
- 1e8 * omit_AA_per_residue[None]
|
| 410 |
+
)
|
| 411 |
+
feature_dict["symmetry_residues"] = remapped_symmetry_residues
|
| 412 |
+
feature_dict["symmetry_weights"] = symmetry_weights
|
| 413 |
+
|
| 414 |
+
sampling_probs_list = []
|
| 415 |
+
log_probs_list = []
|
| 416 |
+
decoding_order_list = []
|
| 417 |
+
S_list = []
|
| 418 |
+
loss_list = []
|
| 419 |
+
loss_per_residue_list = []
|
| 420 |
+
loss_XY_list = []
|
| 421 |
+
for _ in range(args.number_of_batches):
|
| 422 |
+
feature_dict["randn"] = torch.randn(
|
| 423 |
+
[feature_dict["batch_size"], feature_dict["mask"].shape[1]],
|
| 424 |
+
device=device,
|
| 425 |
+
)
|
| 426 |
+
output_dict = model.sample(feature_dict)
|
| 427 |
+
|
| 428 |
+
# compute confidence scores
|
| 429 |
+
loss, loss_per_residue = get_score(
|
| 430 |
+
output_dict["S"],
|
| 431 |
+
output_dict["log_probs"],
|
| 432 |
+
feature_dict["mask"] * feature_dict["chain_mask"],
|
| 433 |
+
)
|
| 434 |
+
if args.model_type == "ligand_mpnn":
|
| 435 |
+
combined_mask = (
|
| 436 |
+
feature_dict["mask"]
|
| 437 |
+
* feature_dict["mask_XY"]
|
| 438 |
+
* feature_dict["chain_mask"]
|
| 439 |
+
)
|
| 440 |
+
else:
|
| 441 |
+
combined_mask = feature_dict["mask"] * feature_dict["chain_mask"]
|
| 442 |
+
loss_XY, _ = get_score(
|
| 443 |
+
output_dict["S"], output_dict["log_probs"], combined_mask
|
| 444 |
+
)
|
| 445 |
+
# -----
|
| 446 |
+
S_list.append(output_dict["S"])
|
| 447 |
+
log_probs_list.append(output_dict["log_probs"])
|
| 448 |
+
sampling_probs_list.append(output_dict["sampling_probs"])
|
| 449 |
+
decoding_order_list.append(output_dict["decoding_order"])
|
| 450 |
+
loss_list.append(loss)
|
| 451 |
+
loss_per_residue_list.append(loss_per_residue)
|
| 452 |
+
loss_XY_list.append(loss_XY)
|
| 453 |
+
S_stack = torch.cat(S_list, 0)
|
| 454 |
+
log_probs_stack = torch.cat(log_probs_list, 0)
|
| 455 |
+
sampling_probs_stack = torch.cat(sampling_probs_list, 0)
|
| 456 |
+
decoding_order_stack = torch.cat(decoding_order_list, 0)
|
| 457 |
+
loss_stack = torch.cat(loss_list, 0)
|
| 458 |
+
loss_per_residue_stack = torch.cat(loss_per_residue_list, 0)
|
| 459 |
+
loss_XY_stack = torch.cat(loss_XY_list, 0)
|
| 460 |
+
rec_mask = feature_dict["mask"][:1] * feature_dict["chain_mask"][:1]
|
| 461 |
+
rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask)
|
| 462 |
+
|
| 463 |
+
native_seq = "".join(
|
| 464 |
+
[restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()]
|
| 465 |
+
)
|
| 466 |
+
seq_np = np.array(list(native_seq))
|
| 467 |
+
seq_out_str = []
|
| 468 |
+
for mask in protein_dict["mask_c"]:
|
| 469 |
+
seq_out_str += list(seq_np[mask.cpu().numpy()])
|
| 470 |
+
seq_out_str += [args.fasta_seq_separation]
|
| 471 |
+
seq_out_str = "".join(seq_out_str)[:-1]
|
| 472 |
+
|
| 473 |
+
output_fasta = base_folder + "/seqs/" + name + args.file_ending + ".fa"
|
| 474 |
+
output_backbones = base_folder + "/backbones/"
|
| 475 |
+
output_packed = base_folder + "/packed/"
|
| 476 |
+
output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt"
|
| 477 |
+
|
| 478 |
+
out_dict = {}
|
| 479 |
+
out_dict["generated_sequences"] = S_stack.cpu()
|
| 480 |
+
out_dict["sampling_probs"] = sampling_probs_stack.cpu()
|
| 481 |
+
out_dict["log_probs"] = log_probs_stack.cpu()
|
| 482 |
+
out_dict["decoding_order"] = decoding_order_stack.cpu()
|
| 483 |
+
out_dict["native_sequence"] = feature_dict["S"][0].cpu()
|
| 484 |
+
out_dict["mask"] = feature_dict["mask"][0].cpu()
|
| 485 |
+
out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
|
| 486 |
+
out_dict["seed"] = seed
|
| 487 |
+
out_dict["temperature"] = args.temperature
|
| 488 |
+
if args.save_stats:
|
| 489 |
+
torch.save(out_dict, output_stats_path)
|
| 490 |
+
|
| 491 |
+
if args.pack_side_chains:
|
| 492 |
+
if args.verbose:
|
| 493 |
+
print("Packing side chains...")
|
| 494 |
+
feature_dict_ = featurize(
|
| 495 |
+
protein_dict,
|
| 496 |
+
cutoff_for_score=8.0,
|
| 497 |
+
use_atom_context=args.pack_with_ligand_context,
|
| 498 |
+
number_of_ligand_atoms=16,
|
| 499 |
+
model_type="ligand_mpnn",
|
| 500 |
+
)
|
| 501 |
+
sc_feature_dict = copy.deepcopy(feature_dict_)
|
| 502 |
+
B = args.batch_size
|
| 503 |
+
for k, v in sc_feature_dict.items():
|
| 504 |
+
if k != "S":
|
| 505 |
+
try:
|
| 506 |
+
num_dim = len(v.shape)
|
| 507 |
+
if num_dim == 2:
|
| 508 |
+
sc_feature_dict[k] = v.repeat(B, 1)
|
| 509 |
+
elif num_dim == 3:
|
| 510 |
+
sc_feature_dict[k] = v.repeat(B, 1, 1)
|
| 511 |
+
elif num_dim == 4:
|
| 512 |
+
sc_feature_dict[k] = v.repeat(B, 1, 1, 1)
|
| 513 |
+
elif num_dim == 5:
|
| 514 |
+
sc_feature_dict[k] = v.repeat(B, 1, 1, 1, 1)
|
| 515 |
+
except:
|
| 516 |
+
pass
|
| 517 |
+
X_stack_list = []
|
| 518 |
+
X_m_stack_list = []
|
| 519 |
+
b_factor_stack_list = []
|
| 520 |
+
for _ in range(args.number_of_packs_per_design):
|
| 521 |
+
X_list = []
|
| 522 |
+
X_m_list = []
|
| 523 |
+
b_factor_list = []
|
| 524 |
+
for c in range(args.number_of_batches):
|
| 525 |
+
sc_feature_dict["S"] = S_list[c]
|
| 526 |
+
sc_dict = pack_side_chains(
|
| 527 |
+
sc_feature_dict,
|
| 528 |
+
model_sc,
|
| 529 |
+
args.sc_num_denoising_steps,
|
| 530 |
+
args.sc_num_samples,
|
| 531 |
+
args.repack_everything,
|
| 532 |
+
)
|
| 533 |
+
X_list.append(sc_dict["X"])
|
| 534 |
+
X_m_list.append(sc_dict["X_m"])
|
| 535 |
+
b_factor_list.append(sc_dict["b_factors"])
|
| 536 |
+
|
| 537 |
+
X_stack = torch.cat(X_list, 0)
|
| 538 |
+
X_m_stack = torch.cat(X_m_list, 0)
|
| 539 |
+
b_factor_stack = torch.cat(b_factor_list, 0)
|
| 540 |
+
|
| 541 |
+
X_stack_list.append(X_stack)
|
| 542 |
+
X_m_stack_list.append(X_m_stack)
|
| 543 |
+
b_factor_stack_list.append(b_factor_stack)
|
| 544 |
+
|
| 545 |
+
with open(output_fasta, "w") as f:
|
| 546 |
+
f.write(
|
| 547 |
+
">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format(
|
| 548 |
+
name,
|
| 549 |
+
args.temperature,
|
| 550 |
+
seed,
|
| 551 |
+
torch.sum(rec_mask).cpu().numpy(),
|
| 552 |
+
torch.sum(combined_mask[:1]).cpu().numpy(),
|
| 553 |
+
bool(args.ligand_mpnn_use_atom_context),
|
| 554 |
+
float(args.ligand_mpnn_cutoff_for_score),
|
| 555 |
+
args.batch_size,
|
| 556 |
+
args.number_of_batches,
|
| 557 |
+
checkpoint_path,
|
| 558 |
+
seq_out_str,
|
| 559 |
+
)
|
| 560 |
+
)
|
| 561 |
+
for ix in range(S_stack.shape[0]):
|
| 562 |
+
ix_suffix = ix
|
| 563 |
+
if not args.zero_indexed:
|
| 564 |
+
ix_suffix += 1
|
| 565 |
+
seq_rec_print = np.format_float_positional(
|
| 566 |
+
rec_stack[ix].cpu().numpy(), unique=False, precision=4
|
| 567 |
+
)
|
| 568 |
+
loss_np = np.format_float_positional(
|
| 569 |
+
np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4
|
| 570 |
+
)
|
| 571 |
+
loss_XY_np = np.format_float_positional(
|
| 572 |
+
np.exp(-loss_XY_stack[ix].cpu().numpy()),
|
| 573 |
+
unique=False,
|
| 574 |
+
precision=4,
|
| 575 |
+
)
|
| 576 |
+
seq = "".join(
|
| 577 |
+
[restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()]
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
# write new sequences into PDB with backbone coordinates
|
| 581 |
+
seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[
|
| 582 |
+
None,
|
| 583 |
+
].repeat(4, 1)
|
| 584 |
+
bfactor_prody = (
|
| 585 |
+
loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1)
|
| 586 |
+
)
|
| 587 |
+
backbone.setResnames(seq_prody)
|
| 588 |
+
backbone.setBetas(
|
| 589 |
+
np.exp(-bfactor_prody)
|
| 590 |
+
* (bfactor_prody > 0.01).astype(np.float32)
|
| 591 |
+
)
|
| 592 |
+
if other_atoms:
|
| 593 |
+
writePDB(
|
| 594 |
+
output_backbones
|
| 595 |
+
+ name
|
| 596 |
+
+ "_"
|
| 597 |
+
+ str(ix_suffix)
|
| 598 |
+
+ args.file_ending
|
| 599 |
+
+ ".pdb",
|
| 600 |
+
backbone + other_atoms,
|
| 601 |
+
)
|
| 602 |
+
else:
|
| 603 |
+
writePDB(
|
| 604 |
+
output_backbones
|
| 605 |
+
+ name
|
| 606 |
+
+ "_"
|
| 607 |
+
+ str(ix_suffix)
|
| 608 |
+
+ args.file_ending
|
| 609 |
+
+ ".pdb",
|
| 610 |
+
backbone,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# write full PDB files
|
| 614 |
+
if args.pack_side_chains:
|
| 615 |
+
for c_pack in range(args.number_of_packs_per_design):
|
| 616 |
+
X_stack = X_stack_list[c_pack]
|
| 617 |
+
X_m_stack = X_m_stack_list[c_pack]
|
| 618 |
+
b_factor_stack = b_factor_stack_list[c_pack]
|
| 619 |
+
write_full_PDB(
|
| 620 |
+
output_packed
|
| 621 |
+
+ name
|
| 622 |
+
+ args.packed_suffix
|
| 623 |
+
+ "_"
|
| 624 |
+
+ str(ix_suffix)
|
| 625 |
+
+ "_"
|
| 626 |
+
+ str(c_pack + 1)
|
| 627 |
+
+ args.file_ending
|
| 628 |
+
+ ".pdb",
|
| 629 |
+
X_stack[ix].cpu().numpy(),
|
| 630 |
+
X_m_stack[ix].cpu().numpy(),
|
| 631 |
+
b_factor_stack[ix].cpu().numpy(),
|
| 632 |
+
feature_dict["R_idx_original"][0].cpu().numpy(),
|
| 633 |
+
protein_dict["chain_letters"],
|
| 634 |
+
S_stack[ix].cpu().numpy(),
|
| 635 |
+
other_atoms=other_atoms,
|
| 636 |
+
icodes=icodes,
|
| 637 |
+
force_hetatm=args.force_hetatm,
|
| 638 |
+
)
|
| 639 |
+
# -----
|
| 640 |
+
|
| 641 |
+
# write fasta lines
|
| 642 |
+
seq_np = np.array(list(seq))
|
| 643 |
+
seq_out_str = []
|
| 644 |
+
for mask in protein_dict["mask_c"]:
|
| 645 |
+
seq_out_str += list(seq_np[mask.cpu().numpy()])
|
| 646 |
+
seq_out_str += [args.fasta_seq_separation]
|
| 647 |
+
seq_out_str = "".join(seq_out_str)[:-1]
|
| 648 |
+
if ix == S_stack.shape[0] - 1:
|
| 649 |
+
# final 2 lines
|
| 650 |
+
f.write(
|
| 651 |
+
">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format(
|
| 652 |
+
name,
|
| 653 |
+
ix_suffix,
|
| 654 |
+
args.temperature,
|
| 655 |
+
seed,
|
| 656 |
+
loss_np,
|
| 657 |
+
loss_XY_np,
|
| 658 |
+
seq_rec_print,
|
| 659 |
+
seq_out_str,
|
| 660 |
+
)
|
| 661 |
+
)
|
| 662 |
+
else:
|
| 663 |
+
f.write(
|
| 664 |
+
">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}\n".format(
|
| 665 |
+
name,
|
| 666 |
+
ix_suffix,
|
| 667 |
+
args.temperature,
|
| 668 |
+
seed,
|
| 669 |
+
loss_np,
|
| 670 |
+
loss_XY_np,
|
| 671 |
+
seq_rec_print,
|
| 672 |
+
seq_out_str,
|
| 673 |
+
)
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
if __name__ == "__main__":
|
| 678 |
+
argparser = argparse.ArgumentParser(
|
| 679 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
argparser.add_argument(
|
| 683 |
+
"--model_type",
|
| 684 |
+
type=str,
|
| 685 |
+
default="protein_mpnn",
|
| 686 |
+
help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
|
| 687 |
+
)
|
| 688 |
+
# protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
|
| 689 |
+
# ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
|
| 690 |
+
# per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
|
| 691 |
+
# global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
|
| 692 |
+
# soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
|
| 693 |
+
argparser.add_argument(
|
| 694 |
+
"--checkpoint_protein_mpnn",
|
| 695 |
+
type=str,
|
| 696 |
+
default="./model_params/proteinmpnn_v_48_020.pt",
|
| 697 |
+
help="Path to model weights.",
|
| 698 |
+
)
|
| 699 |
+
argparser.add_argument(
|
| 700 |
+
"--checkpoint_ligand_mpnn",
|
| 701 |
+
type=str,
|
| 702 |
+
default="./model_params/ligandmpnn_v_32_010_25.pt",
|
| 703 |
+
help="Path to model weights.",
|
| 704 |
+
)
|
| 705 |
+
argparser.add_argument(
|
| 706 |
+
"--checkpoint_per_residue_label_membrane_mpnn",
|
| 707 |
+
type=str,
|
| 708 |
+
default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
|
| 709 |
+
help="Path to model weights.",
|
| 710 |
+
)
|
| 711 |
+
argparser.add_argument(
|
| 712 |
+
"--checkpoint_global_label_membrane_mpnn",
|
| 713 |
+
type=str,
|
| 714 |
+
default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
|
| 715 |
+
help="Path to model weights.",
|
| 716 |
+
)
|
| 717 |
+
argparser.add_argument(
|
| 718 |
+
"--checkpoint_soluble_mpnn",
|
| 719 |
+
type=str,
|
| 720 |
+
default="./model_params/solublempnn_v_48_020.pt",
|
| 721 |
+
help="Path to model weights.",
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
+
argparser.add_argument(
|
| 725 |
+
"--fasta_seq_separation",
|
| 726 |
+
type=str,
|
| 727 |
+
default=":",
|
| 728 |
+
help="Symbol to use between sequences from different chains",
|
| 729 |
+
)
|
| 730 |
+
argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
|
| 731 |
+
|
| 732 |
+
argparser.add_argument(
|
| 733 |
+
"--pdb_path", type=str, default="", help="Path to the input PDB."
|
| 734 |
+
)
|
| 735 |
+
argparser.add_argument(
|
| 736 |
+
"--pdb_path_multi",
|
| 737 |
+
type=str,
|
| 738 |
+
default="",
|
| 739 |
+
help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
argparser.add_argument(
|
| 743 |
+
"--fixed_residues",
|
| 744 |
+
type=str,
|
| 745 |
+
default="",
|
| 746 |
+
help="Provide fixed residues, A12 A13 A14 B2 B25",
|
| 747 |
+
)
|
| 748 |
+
argparser.add_argument(
|
| 749 |
+
"--fixed_residues_multi",
|
| 750 |
+
type=str,
|
| 751 |
+
default="",
|
| 752 |
+
help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
| 753 |
+
)
|
| 754 |
+
|
| 755 |
+
argparser.add_argument(
|
| 756 |
+
"--redesigned_residues",
|
| 757 |
+
type=str,
|
| 758 |
+
default="",
|
| 759 |
+
help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
|
| 760 |
+
)
|
| 761 |
+
argparser.add_argument(
|
| 762 |
+
"--redesigned_residues_multi",
|
| 763 |
+
type=str,
|
| 764 |
+
default="",
|
| 765 |
+
help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
argparser.add_argument(
|
| 769 |
+
"--bias_AA",
|
| 770 |
+
type=str,
|
| 771 |
+
default="",
|
| 772 |
+
help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'",
|
| 773 |
+
)
|
| 774 |
+
argparser.add_argument(
|
| 775 |
+
"--bias_AA_per_residue",
|
| 776 |
+
type=str,
|
| 777 |
+
default="",
|
| 778 |
+
help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}",
|
| 779 |
+
)
|
| 780 |
+
argparser.add_argument(
|
| 781 |
+
"--bias_AA_per_residue_multi",
|
| 782 |
+
type=str,
|
| 783 |
+
default="",
|
| 784 |
+
help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}",
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
argparser.add_argument(
|
| 788 |
+
"--omit_AA",
|
| 789 |
+
type=str,
|
| 790 |
+
default="",
|
| 791 |
+
help="Bias generation of amino acids, e.g. 'ACG'",
|
| 792 |
+
)
|
| 793 |
+
argparser.add_argument(
|
| 794 |
+
"--omit_AA_per_residue",
|
| 795 |
+
type=str,
|
| 796 |
+
default="",
|
| 797 |
+
help="Path to json mapping of bias {'A12': 'APQ', 'A13': 'QST'}",
|
| 798 |
+
)
|
| 799 |
+
argparser.add_argument(
|
| 800 |
+
"--omit_AA_per_residue_multi",
|
| 801 |
+
type=str,
|
| 802 |
+
default="",
|
| 803 |
+
help="Path to json mapping of bias {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}",
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
argparser.add_argument(
|
| 807 |
+
"--symmetry_residues",
|
| 808 |
+
type=str,
|
| 809 |
+
default="",
|
| 810 |
+
help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
|
| 811 |
+
)
|
| 812 |
+
argparser.add_argument(
|
| 813 |
+
"--symmetry_weights",
|
| 814 |
+
type=str,
|
| 815 |
+
default="",
|
| 816 |
+
help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'",
|
| 817 |
+
)
|
| 818 |
+
argparser.add_argument(
|
| 819 |
+
"--homo_oligomer",
|
| 820 |
+
type=int,
|
| 821 |
+
default=0,
|
| 822 |
+
help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
argparser.add_argument(
|
| 826 |
+
"--out_folder",
|
| 827 |
+
type=str,
|
| 828 |
+
help="Path to a folder to output sequences, e.g. /home/out/",
|
| 829 |
+
)
|
| 830 |
+
argparser.add_argument(
|
| 831 |
+
"--file_ending", type=str, default="", help="adding_string_to_the_end"
|
| 832 |
+
)
|
| 833 |
+
argparser.add_argument(
|
| 834 |
+
"--zero_indexed",
|
| 835 |
+
type=str,
|
| 836 |
+
default=0,
|
| 837 |
+
help="1 - to start output PDB numbering with 0",
|
| 838 |
+
)
|
| 839 |
+
argparser.add_argument(
|
| 840 |
+
"--seed",
|
| 841 |
+
type=int,
|
| 842 |
+
default=0,
|
| 843 |
+
help="Set seed for torch, numpy, and python random.",
|
| 844 |
+
)
|
| 845 |
+
argparser.add_argument(
|
| 846 |
+
"--batch_size",
|
| 847 |
+
type=int,
|
| 848 |
+
default=1,
|
| 849 |
+
help="Number of sequence to generate per one pass.",
|
| 850 |
+
)
|
| 851 |
+
argparser.add_argument(
|
| 852 |
+
"--number_of_batches",
|
| 853 |
+
type=int,
|
| 854 |
+
default=1,
|
| 855 |
+
help="Number of times to design sequence using a chosen batch size.",
|
| 856 |
+
)
|
| 857 |
+
argparser.add_argument(
|
| 858 |
+
"--temperature",
|
| 859 |
+
type=float,
|
| 860 |
+
default=0.1,
|
| 861 |
+
help="Temperature to sample sequences.",
|
| 862 |
+
)
|
| 863 |
+
argparser.add_argument(
|
| 864 |
+
"--save_stats", type=int, default=0, help="Save output statistics"
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
argparser.add_argument(
|
| 868 |
+
"--ligand_mpnn_use_atom_context",
|
| 869 |
+
type=int,
|
| 870 |
+
default=1,
|
| 871 |
+
help="1 - use atom context, 0 - do not use atom context.",
|
| 872 |
+
)
|
| 873 |
+
argparser.add_argument(
|
| 874 |
+
"--ligand_mpnn_cutoff_for_score",
|
| 875 |
+
type=float,
|
| 876 |
+
default=8.0,
|
| 877 |
+
help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
|
| 878 |
+
)
|
| 879 |
+
argparser.add_argument(
|
| 880 |
+
"--ligand_mpnn_use_side_chain_context",
|
| 881 |
+
type=int,
|
| 882 |
+
default=0,
|
| 883 |
+
help="Flag to use side chain atoms as ligand context for the fixed residues",
|
| 884 |
+
)
|
| 885 |
+
argparser.add_argument(
|
| 886 |
+
"--chains_to_design",
|
| 887 |
+
type=str,
|
| 888 |
+
default="",
|
| 889 |
+
help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'",
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
argparser.add_argument(
|
| 893 |
+
"--parse_these_chains_only",
|
| 894 |
+
type=str,
|
| 895 |
+
default="",
|
| 896 |
+
help="Provide chains letters for parsing backbones, 'A,B,C,F'",
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
argparser.add_argument(
|
| 900 |
+
"--transmembrane_buried",
|
| 901 |
+
type=str,
|
| 902 |
+
default="",
|
| 903 |
+
help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
| 904 |
+
)
|
| 905 |
+
argparser.add_argument(
|
| 906 |
+
"--transmembrane_interface",
|
| 907 |
+
type=str,
|
| 908 |
+
default="",
|
| 909 |
+
help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
| 910 |
+
)
|
| 911 |
+
|
| 912 |
+
argparser.add_argument(
|
| 913 |
+
"--global_transmembrane_label",
|
| 914 |
+
type=int,
|
| 915 |
+
default=0,
|
| 916 |
+
help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
argparser.add_argument(
|
| 920 |
+
"--parse_atoms_with_zero_occupancy",
|
| 921 |
+
type=int,
|
| 922 |
+
default=0,
|
| 923 |
+
help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
argparser.add_argument(
|
| 927 |
+
"--pack_side_chains",
|
| 928 |
+
type=int,
|
| 929 |
+
default=0,
|
| 930 |
+
help="1 - to run side chain packer, 0 - do not run it",
|
| 931 |
+
)
|
| 932 |
+
|
| 933 |
+
argparser.add_argument(
|
| 934 |
+
"--checkpoint_path_sc",
|
| 935 |
+
type=str,
|
| 936 |
+
default="./model_params/ligandmpnn_sc_v_32_002_16.pt",
|
| 937 |
+
help="Path to model weights.",
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
argparser.add_argument(
|
| 941 |
+
"--number_of_packs_per_design",
|
| 942 |
+
type=int,
|
| 943 |
+
default=4,
|
| 944 |
+
help="Number of independent side chain packing samples to return per design",
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
argparser.add_argument(
|
| 948 |
+
"--sc_num_denoising_steps",
|
| 949 |
+
type=int,
|
| 950 |
+
default=3,
|
| 951 |
+
help="Number of denoising/recycling steps to make for side chain packing",
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
argparser.add_argument(
|
| 955 |
+
"--sc_num_samples",
|
| 956 |
+
type=int,
|
| 957 |
+
default=16,
|
| 958 |
+
help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.",
|
| 959 |
+
)
|
| 960 |
+
|
| 961 |
+
argparser.add_argument(
|
| 962 |
+
"--repack_everything",
|
| 963 |
+
type=int,
|
| 964 |
+
default=0,
|
| 965 |
+
help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues",
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
argparser.add_argument(
|
| 969 |
+
"--force_hetatm",
|
| 970 |
+
type=int,
|
| 971 |
+
default=0,
|
| 972 |
+
help="To force ligand atoms to be written as HETATM to PDB file after packing.",
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
argparser.add_argument(
|
| 976 |
+
"--packed_suffix",
|
| 977 |
+
type=str,
|
| 978 |
+
default="_packed",
|
| 979 |
+
help="Suffix for packed PDB paths",
|
| 980 |
+
)
|
| 981 |
+
|
| 982 |
+
argparser.add_argument(
|
| 983 |
+
"--pack_with_ligand_context",
|
| 984 |
+
type=int,
|
| 985 |
+
default=1,
|
| 986 |
+
help="1-pack side chains using ligand context, 0 - do not use it.",
|
| 987 |
+
)
|
| 988 |
+
|
| 989 |
+
args = argparser.parse_args()
|
| 990 |
+
main(args)
|
run_examples.sh
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
#1
|
| 4 |
+
python run.py \
|
| 5 |
+
--seed 111 \
|
| 6 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 7 |
+
--out_folder "./outputs/default"
|
| 8 |
+
#2
|
| 9 |
+
python run.py \
|
| 10 |
+
--seed 111 \
|
| 11 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 12 |
+
--temperature 0.05 \
|
| 13 |
+
--out_folder "./outputs/temperature"
|
| 14 |
+
|
| 15 |
+
#3
|
| 16 |
+
python run.py \
|
| 17 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 18 |
+
--out_folder "./outputs/random_seed"
|
| 19 |
+
|
| 20 |
+
#4
|
| 21 |
+
python run.py \
|
| 22 |
+
--seed 111 \
|
| 23 |
+
--verbose 0 \
|
| 24 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 25 |
+
--out_folder "./outputs/verbose"
|
| 26 |
+
|
| 27 |
+
#5
|
| 28 |
+
python run.py \
|
| 29 |
+
--seed 111 \
|
| 30 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 31 |
+
--out_folder "./outputs/save_stats" \
|
| 32 |
+
--save_stats 1
|
| 33 |
+
|
| 34 |
+
#6
|
| 35 |
+
python run.py \
|
| 36 |
+
--seed 111 \
|
| 37 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 38 |
+
--out_folder "./outputs/fix_residues" \
|
| 39 |
+
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
| 40 |
+
--bias_AA "A:10.0"
|
| 41 |
+
|
| 42 |
+
#7
|
| 43 |
+
python run.py \
|
| 44 |
+
--seed 111 \
|
| 45 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 46 |
+
--out_folder "./outputs/redesign_residues" \
|
| 47 |
+
--redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
|
| 48 |
+
--bias_AA "A:10.0"
|
| 49 |
+
|
| 50 |
+
#8
|
| 51 |
+
python run.py \
|
| 52 |
+
--seed 111 \
|
| 53 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 54 |
+
--out_folder "./outputs/batch_size" \
|
| 55 |
+
--batch_size 3 \
|
| 56 |
+
--number_of_batches 5
|
| 57 |
+
|
| 58 |
+
#9
|
| 59 |
+
python run.py \
|
| 60 |
+
--seed 111 \
|
| 61 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 62 |
+
--bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
|
| 63 |
+
--out_folder "./outputs/global_bias"
|
| 64 |
+
|
| 65 |
+
#10
|
| 66 |
+
python run.py \
|
| 67 |
+
--seed 111 \
|
| 68 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 69 |
+
--bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
|
| 70 |
+
--out_folder "./outputs/per_residue_bias"
|
| 71 |
+
|
| 72 |
+
#11
|
| 73 |
+
python run.py \
|
| 74 |
+
--seed 111 \
|
| 75 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 76 |
+
--omit_AA "CDFGHILMNPQRSTVWY" \
|
| 77 |
+
--out_folder "./outputs/global_omit"
|
| 78 |
+
|
| 79 |
+
#12
|
| 80 |
+
python run.py \
|
| 81 |
+
--seed 111 \
|
| 82 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 83 |
+
--omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
|
| 84 |
+
--out_folder "./outputs/per_residue_omit"
|
| 85 |
+
|
| 86 |
+
#13
|
| 87 |
+
python run.py \
|
| 88 |
+
--seed 111 \
|
| 89 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 90 |
+
--out_folder "./outputs/symmetry" \
|
| 91 |
+
--symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
|
| 92 |
+
--symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"
|
| 93 |
+
|
| 94 |
+
#14
|
| 95 |
+
python run.py \
|
| 96 |
+
--model_type "ligand_mpnn" \
|
| 97 |
+
--seed 111 \
|
| 98 |
+
--pdb_path "./inputs/4GYT.pdb" \
|
| 99 |
+
--out_folder "./outputs/homooligomer" \
|
| 100 |
+
--homo_oligomer 1 \
|
| 101 |
+
--number_of_batches 2
|
| 102 |
+
|
| 103 |
+
#15
|
| 104 |
+
python run.py \
|
| 105 |
+
--seed 111 \
|
| 106 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 107 |
+
--out_folder "./outputs/file_ending" \
|
| 108 |
+
--file_ending "_xyz"
|
| 109 |
+
|
| 110 |
+
#16
|
| 111 |
+
python run.py \
|
| 112 |
+
--seed 111 \
|
| 113 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 114 |
+
--out_folder "./outputs/zero_indexed" \
|
| 115 |
+
--zero_indexed 1 \
|
| 116 |
+
--number_of_batches 2
|
| 117 |
+
|
| 118 |
+
#17
|
| 119 |
+
python run.py \
|
| 120 |
+
--model_type "ligand_mpnn" \
|
| 121 |
+
--seed 111 \
|
| 122 |
+
--pdb_path "./inputs/4GYT.pdb" \
|
| 123 |
+
--out_folder "./outputs/chains_to_design" \
|
| 124 |
+
--chains_to_design "A,B"
|
| 125 |
+
|
| 126 |
+
#18
|
| 127 |
+
python run.py \
|
| 128 |
+
--model_type "ligand_mpnn" \
|
| 129 |
+
--seed 111 \
|
| 130 |
+
--pdb_path "./inputs/4GYT.pdb" \
|
| 131 |
+
--out_folder "./outputs/parse_these_chains_only" \
|
| 132 |
+
--parse_these_chains_only "A,B"
|
| 133 |
+
|
| 134 |
+
#19
|
| 135 |
+
python run.py \
|
| 136 |
+
--model_type "ligand_mpnn" \
|
| 137 |
+
--seed 111 \
|
| 138 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 139 |
+
--out_folder "./outputs/ligandmpnn_default"
|
| 140 |
+
|
| 141 |
+
#20
|
| 142 |
+
python run.py \
|
| 143 |
+
--checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
|
| 144 |
+
--model_type "ligand_mpnn" \
|
| 145 |
+
--seed 111 \
|
| 146 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 147 |
+
--out_folder "./outputs/ligandmpnn_v_32_005_25"
|
| 148 |
+
|
| 149 |
+
#21
|
| 150 |
+
python run.py \
|
| 151 |
+
--model_type "ligand_mpnn" \
|
| 152 |
+
--seed 111 \
|
| 153 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 154 |
+
--out_folder "./outputs/ligandmpnn_no_context" \
|
| 155 |
+
--ligand_mpnn_use_atom_context 0
|
| 156 |
+
|
| 157 |
+
#22
|
| 158 |
+
python run.py \
|
| 159 |
+
--model_type "ligand_mpnn" \
|
| 160 |
+
--seed 111 \
|
| 161 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 162 |
+
--out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
|
| 163 |
+
--ligand_mpnn_use_side_chain_context 1 \
|
| 164 |
+
--fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"
|
| 165 |
+
|
| 166 |
+
#23
|
| 167 |
+
python run.py \
|
| 168 |
+
--model_type "soluble_mpnn" \
|
| 169 |
+
--seed 111 \
|
| 170 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 171 |
+
--out_folder "./outputs/soluble_mpnn_default"
|
| 172 |
+
|
| 173 |
+
#24
|
| 174 |
+
python run.py \
|
| 175 |
+
--model_type "global_label_membrane_mpnn" \
|
| 176 |
+
--seed 111 \
|
| 177 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 178 |
+
--out_folder "./outputs/global_label_membrane_mpnn_0" \
|
| 179 |
+
--global_transmembrane_label 0
|
| 180 |
+
|
| 181 |
+
#25
|
| 182 |
+
python run.py \
|
| 183 |
+
--model_type "per_residue_label_membrane_mpnn" \
|
| 184 |
+
--seed 111 \
|
| 185 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 186 |
+
--out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
|
| 187 |
+
--transmembrane_buried "C1 C2 C3 C11" \
|
| 188 |
+
--transmembrane_interface "C4 C5 C6 C22"
|
| 189 |
+
|
| 190 |
+
#26
|
| 191 |
+
python run.py \
|
| 192 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 193 |
+
--out_folder "./outputs/fasta_seq_separation" \
|
| 194 |
+
--fasta_seq_separation ":"
|
| 195 |
+
|
| 196 |
+
#27
|
| 197 |
+
python run.py \
|
| 198 |
+
--pdb_path_multi "./inputs/pdb_ids.json" \
|
| 199 |
+
--out_folder "./outputs/pdb_path_multi" \
|
| 200 |
+
--seed 111
|
| 201 |
+
|
| 202 |
+
#28
|
| 203 |
+
python run.py \
|
| 204 |
+
--pdb_path_multi "./inputs/pdb_ids.json" \
|
| 205 |
+
--fixed_residues_multi "./inputs/fix_residues_multi.json" \
|
| 206 |
+
--out_folder "./outputs/fixed_residues_multi" \
|
| 207 |
+
--seed 111
|
| 208 |
+
|
| 209 |
+
#29
|
| 210 |
+
python run.py \
|
| 211 |
+
--pdb_path_multi "./inputs/pdb_ids.json" \
|
| 212 |
+
--redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
|
| 213 |
+
--out_folder "./outputs/redesigned_residues_multi" \
|
| 214 |
+
--seed 111
|
| 215 |
+
|
| 216 |
+
#30
|
| 217 |
+
python run.py \
|
| 218 |
+
--pdb_path_multi "./inputs/pdb_ids.json" \
|
| 219 |
+
--omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
|
| 220 |
+
--out_folder "./outputs/omit_AA_per_residue_multi" \
|
| 221 |
+
--seed 111
|
| 222 |
+
|
| 223 |
+
#31
|
| 224 |
+
python run.py \
|
| 225 |
+
--pdb_path_multi "./inputs/pdb_ids.json" \
|
| 226 |
+
--bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
|
| 227 |
+
--out_folder "./outputs/bias_AA_per_residue_multi" \
|
| 228 |
+
--seed 111
|
| 229 |
+
|
| 230 |
+
#32
|
| 231 |
+
python run.py \
|
| 232 |
+
--model_type "ligand_mpnn" \
|
| 233 |
+
--seed 111 \
|
| 234 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 235 |
+
--ligand_mpnn_cutoff_for_score "6.0" \
|
| 236 |
+
--out_folder "./outputs/ligand_mpnn_cutoff_for_score"
|
| 237 |
+
|
| 238 |
+
#33
|
| 239 |
+
python run.py \
|
| 240 |
+
--seed 111 \
|
| 241 |
+
--pdb_path "./inputs/2GFB.pdb" \
|
| 242 |
+
--out_folder "./outputs/insertion_code" \
|
| 243 |
+
--redesigned_residues "B82 B82A B82B B82C" \
|
| 244 |
+
--parse_these_chains_only "B"
|
sc_examples.sh
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
|
| 2 |
+
python run.py \
|
| 3 |
+
--model_type "ligand_mpnn" \
|
| 4 |
+
--seed 111 \
|
| 5 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 6 |
+
--out_folder "./outputs/sc_default_fast" \
|
| 7 |
+
--pack_side_chains 1 \
|
| 8 |
+
--number_of_packs_per_design 0 \
|
| 9 |
+
--pack_with_ligand_context 1
|
| 10 |
+
|
| 11 |
+
#2 design a new sequence and pack side chains (return 4 side chain packing samples)
|
| 12 |
+
python run.py \
|
| 13 |
+
--model_type "ligand_mpnn" \
|
| 14 |
+
--seed 111 \
|
| 15 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 16 |
+
--out_folder "./outputs/sc_default" \
|
| 17 |
+
--pack_side_chains 1 \
|
| 18 |
+
--number_of_packs_per_design 4 \
|
| 19 |
+
--pack_with_ligand_context 1
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
#3 fix specific residues for design and packing
|
| 23 |
+
python run.py \
|
| 24 |
+
--model_type "ligand_mpnn" \
|
| 25 |
+
--seed 111 \
|
| 26 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 27 |
+
--out_folder "./outputs/sc_fixed_residues" \
|
| 28 |
+
--pack_side_chains 1 \
|
| 29 |
+
--number_of_packs_per_design 4 \
|
| 30 |
+
--pack_with_ligand_context 1 \
|
| 31 |
+
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
| 32 |
+
--repack_everything 0
|
| 33 |
+
|
| 34 |
+
#4 fix specific residues for sequence design but repack everything
|
| 35 |
+
python run.py \
|
| 36 |
+
--model_type "ligand_mpnn" \
|
| 37 |
+
--seed 111 \
|
| 38 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 39 |
+
--out_folder "./outputs/sc_fixed_residues_full_repack" \
|
| 40 |
+
--pack_side_chains 1 \
|
| 41 |
+
--number_of_packs_per_design 4 \
|
| 42 |
+
--pack_with_ligand_context 1 \
|
| 43 |
+
--fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
|
| 44 |
+
--repack_everything 1
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
#5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
|
| 48 |
+
python run.py \
|
| 49 |
+
--model_type "ligand_mpnn" \
|
| 50 |
+
--seed 111 \
|
| 51 |
+
--pdb_path "./inputs/1BC8.pdb" \
|
| 52 |
+
--out_folder "./outputs/sc_no_context" \
|
| 53 |
+
--pack_side_chains 1 \
|
| 54 |
+
--number_of_packs_per_design 4 \
|
| 55 |
+
--pack_with_ligand_context 0
|
sc_utils.py
ADDED
|
@@ -0,0 +1,1158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.distributions as D
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from model_utils import (
|
| 8 |
+
DecLayer,
|
| 9 |
+
DecLayerJ,
|
| 10 |
+
EncLayer,
|
| 11 |
+
PositionalEncodings,
|
| 12 |
+
cat_neighbors_nodes,
|
| 13 |
+
gather_edges,
|
| 14 |
+
gather_nodes,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from openfold.data.data_transforms import atom37_to_torsion_angles, make_atom14_masks
|
| 18 |
+
from openfold.np.residue_constants import (
|
| 19 |
+
restype_atom14_mask,
|
| 20 |
+
restype_atom14_rigid_group_positions,
|
| 21 |
+
restype_atom14_to_rigid_group,
|
| 22 |
+
restype_rigid_group_default_frame,
|
| 23 |
+
)
|
| 24 |
+
from openfold.utils import feats
|
| 25 |
+
from openfold.utils.rigid_utils import Rigid
|
| 26 |
+
|
| 27 |
+
# Pi as a tensor; created on CPU, callers move it to the right device as needed.
torch_pi = torch.tensor(np.pi, device="cpu")


# For MPNN residue index i, the AF2 restype index it maps to.
_MPNN_TO_AF2_INDEX = [
    0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18, 20,
]

# 21 x 21 permutation matrix (one-hot rows) translating a one-hot sequence from
# the ProteinMPNN alphabet to the AlphaFold2 restype alphabet: row i carries a
# single 1 in column _MPNN_TO_AF2_INDEX[i].
map_mpnn_to_af2_seq = torch.eye(21, dtype=torch.long, device="cpu")[
    torch.tensor(_MPNN_TO_AF2_INDEX)
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def pack_side_chains(
    feature_dict,
    model_sc,
    num_denoising_steps,
    num_samples=10,
    repack_everything=True,
    num_context_atoms=16,
):
    """Sample side-chain torsions with the packing model and rebuild atom14 coordinates.

    Runs ``num_denoising_steps`` rounds of decode -> best-of-``num_samples``
    torsion sampling -> coordinate rebuilding, keeping non-repackable side
    chains (per ``repack_everything`` / the chain mask) at their true torsions.
    Results (coordinates, per-atom pseudo b-factors, mixture parameters) are
    written back into ``feature_dict``, which is also returned.

    Fixes vs. the previous version: the duplicated ``feature_dict["X"]``
    assignment was removed and the loop-invariant rigid-group constant tensors
    are built once instead of on every denoising step.
    """
    device = feature_dict["X"].device
    torsion_dict = make_torsion_features(feature_dict, repack_everything)
    feature_dict["X"] = torsion_dict["xyz14_noised"]
    feature_dict["X_m"] = torsion_dict["xyz14_m"]

    # No ligand context supplied: fall back to empty, fully-masked context atoms.
    if "Y" not in feature_dict:
        B, L = feature_dict["X"].shape[0], feature_dict["X"].shape[1]
        feature_dict["Y"] = torch.zeros([B, L, num_context_atoms, 3], device=device)
        feature_dict["Y_t"] = torch.zeros([B, L, num_context_atoms], device=device)
        feature_dict["Y_m"] = torch.zeros([B, L, num_context_atoms], device=device)

    h_V, h_E, E_idx = model_sc.encode(feature_dict)
    feature_dict["h_V"] = h_V
    feature_dict["h_E"] = h_E
    feature_dict["E_idx"] = E_idx

    # Loop-invariant rigid-group constants (hoisted out of the denoising loop).
    default_frames = torch.tensor(restype_rigid_group_default_frame, device=device)
    group_idx = torch.tensor(restype_atom14_to_rigid_group, device=device)
    atom_mask = torch.tensor(restype_atom14_mask, device=device)
    lit_positions = torch.tensor(restype_atom14_rigid_group_positions, device=device)

    for _ in range(num_denoising_steps):
        mean, concentration, mix_logits = model_sc.decode(feature_dict)
        mix = D.Categorical(logits=mix_logits)
        comp = D.VonMises(mean, concentration)
        pred_dist = D.MixtureSameFamily(mix, comp)
        # Draw several candidate torsion sets and keep the most likely one.
        predicted_samples = pred_dist.sample([num_samples])
        log_probs_of_samples = pred_dist.log_prob(predicted_samples)
        sample = torch.gather(
            predicted_samples, dim=0, index=torch.argmax(log_probs_of_samples, 0)[None,]
        )[0,]
        torsions_pred_unit = torch.cat(
            [torch.sin(sample[:, :, :, None]), torch.cos(sample[:, :, :, None])], -1
        )
        # Overwrite only the repackable chi angles (torsion indices 3:).
        torsion_dict["torsions_noised"][:, :, 3:] = torsions_pred_unit * torsion_dict[
            "mask_fix_sc"
        ] + torsion_dict["torsions_true"] * (1 - torsion_dict["mask_fix_sc"])
        pred_frames = feats.torsion_angles_to_frames(
            torsion_dict["rigids"],
            torsion_dict["torsions_noised"],
            torsion_dict["aatype"],
            default_frames,
        )
        xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos(
            pred_frames,
            torsion_dict["aatype"],
            default_frames,
            group_idx,
            atom_mask,
            lit_positions,
        )
        xyz14_noised = xyz14_noised * feature_dict["X_m"][:, :, :, None]
        # Feed the refined coordinates into the next denoising step.
        feature_dict["X"] = xyz14_noised

    S_af2 = torsion_dict["S_af2"]

    # Pseudo log-probability per chi angle; residues with fixed side chains get
    # a constant high confidence (2.0).
    log_prob = pred_dist.log_prob(sample) * torsion_dict["mask_fix_sc"][
        ..., 0
    ] + 2.0 * (1 - torsion_dict["mask_fix_sc"][..., 0])

    # Map each atom14 slot to its chi group (groups 4..7 -> 0..3; backbone
    # groups are clamped onto chi1) so torsion confidence broadcasts to atoms.
    tmp_types = group_idx.long()[S_af2]
    tmp_types[tmp_types < 4] = 4
    tmp_types -= 4
    atom_types_for_b_factor = torch.nn.functional.one_hot(tmp_types, 4)  # [B, L, 14, 4]

    uncertainty = log_prob[:, :, None, :] * atom_types_for_b_factor  # [B, L, 14, 4]
    b_factor_pred = uncertainty.sum(-1)  # [B, L, 14]
    feature_dict["b_factors"] = b_factor_pred
    feature_dict["mean"] = mean
    feature_dict["concentration"] = concentration
    feature_dict["mix_logits"] = mix_logits
    feature_dict["log_prob"] = log_prob
    feature_dict["sample"] = sample
    feature_dict["true_torsion_sin_cos"] = torsion_dict["torsions_true"]
    return feature_dict
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def make_torsion_features(feature_dict, repack_everything=True):
    """Build torsion-angle features and randomized atom14 starting coordinates.

    Re-indexes the MPNN sequence into the AF2 alphabet, computes backbone rigid
    frames and torsion angles, then replaces the chi angles of every repackable
    residue with uniformly random angles to initialize denoising.  Returns the
    openfold torsion dict augmented with packing-specific entries
    (``xyz14_noised``, ``xyz14_m``, ``rigids``, ``torsions_noised``,
    ``mask_fix_sc``, ``torsions_true``, ``S_af2``, ``mask_for_loss``).
    """
    mask = feature_dict["mask"]
    device = mask.device
    B, L = mask.shape

    # Scatter backbone atoms into atom37 layout: N, CA, C at 0..2, O at 4.
    xyz37 = torch.zeros([B, L, 37, 3], device=device, dtype=torch.float32)
    xyz37[:, :, :3] = feature_dict["X"][:, :, :3]
    xyz37[:, :, 4] = feature_dict["X"][:, :, 3]

    # Translate sequence indices from the MPNN alphabet to the AF2 alphabet.
    one_hot_mpnn = torch.nn.functional.one_hot(feature_dict["S"], 21).float()
    S_af2 = torch.argmax(one_hot_mpnn @ map_mpnn_to_af2_seq.to(device).float(), -1)

    masks14_37 = make_atom14_masks({"aatype": S_af2})
    torsion_dict = atom37_to_torsion_angles("")(
        {
            "aatype": S_af2,
            "all_atom_positions": xyz37,
            "all_atom_mask": masks14_37["atom37_atom_exists"],
        }
    )

    # Backbone rigid frames from N/CA/C.
    rigids = Rigid.make_transform_from_reference(
        n_xyz=xyz37[:, :, 0, :],
        ca_xyz=xyz37[:, :, 1, :],
        c_xyz=xyz37[:, :, 2, :],
        eps=1e-9,
    )

    if repack_everything:
        # Everything is repackable: no reference torsions to preserve.
        torsions_true = torch.zeros([B, L, 4, 2], device=device)
        mask_fix_sc = torch.ones([B, L, 1, 1], device=device)
    else:
        # Keep fixed residues at the chi angles observed in the input structure.
        torsion_dict_true = atom37_to_torsion_angles("")(
            {
                "aatype": S_af2,
                "all_atom_positions": feature_dict["xyz_37"],
                "all_atom_mask": masks14_37["atom37_atom_exists"],
            }
        )
        torsions_true = torch.clone(torsion_dict_true["torsion_angles_sin_cos"])[
            :, :, 3:
        ]
        mask_fix_sc = feature_dict["chain_mask"][:, :, None, None]

    # Random chi angles (torsion indices 3:) for the repackable residues only.
    random_angle = 2 * torch_pi * torch.rand([B, L, 4], device=device)
    random_sin_cos = torch.cat(
        [torch.sin(random_angle)[..., None], torch.cos(random_angle)[..., None]], -1
    )
    torsions_noised = torch.clone(torsion_dict["torsion_angles_sin_cos"])
    torsions_noised[:, :, 3:] = random_sin_cos * mask_fix_sc + torsions_true * (
        1 - mask_fix_sc
    )

    # Rebuild atom14 coordinates from the noised torsions.
    pred_frames = feats.torsion_angles_to_frames(
        rigids,
        torsions_noised,
        S_af2,
        torch.tensor(restype_rigid_group_default_frame, device=device),
    )
    xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos(
        pred_frames,
        S_af2,
        torch.tensor(restype_rigid_group_default_frame, device=device),
        torch.tensor(restype_atom14_to_rigid_group, device=device).long(),
        torch.tensor(restype_atom14_mask, device=device),
        torch.tensor(restype_atom14_rigid_group_positions, device=device),
    )

    # Zero out atoms that do not exist for the residue type or are unmasked.
    xyz14_m = masks14_37["atom14_atom_exists"] * mask[:, :, None]
    xyz14_noised = xyz14_noised * xyz14_m[:, :, :, None]

    torsion_dict["xyz14_m"] = xyz14_m
    torsion_dict["xyz14_noised"] = xyz14_noised
    torsion_dict["mask_for_loss"] = mask
    torsion_dict["rigids"] = rigids
    torsion_dict["torsions_noised"] = torsions_noised
    torsion_dict["mask_fix_sc"] = mask_fix_sc
    torsion_dict["torsions_true"] = torsions_true
    torsion_dict["S_af2"] = S_af2
    return torsion_dict
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Packer(nn.Module):
    """Side-chain packing network.

    Encodes backbone and ligand-context geometry with message-passing layers
    (``encode``), then predicts, per residue, a mixture-of-von-Mises
    distribution over the four chi torsion angles (``decode``).

    Fix vs. the previous version: ``self.dropout`` was first assigned the float
    dropout rate and then silently shadowed by an ``nn.Dropout`` module; the
    dead float assignment has been removed (the post-``__init__`` state is
    unchanged).  An unused local in ``decode`` was also dropped.
    """

    def __init__(
        self,
        edge_features=128,
        node_features=128,
        num_positional_embeddings=16,
        num_chain_embeddings=16,
        num_rbf=16,
        top_k=30,
        augment_eps=0.0,
        atom37_order=False,
        device=None,
        atom_context_num=16,
        lower_bound=0.0,
        upper_bound=20.0,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dropout=0.1,
        num_mix=3,
    ):
        super(Packer, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.num_positional_embeddings = num_positional_embeddings
        self.num_chain_embeddings = num_chain_embeddings
        self.num_rbf = num_rbf
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.atom37_order = atom37_order
        self.device = device
        self.atom_context_num = atom_context_num
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.hidden_dim = hidden_dim
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.softplus = nn.Softplus(beta=1, threshold=20)

        # Geometric featurizer shared by encode/decode.
        self.features = ProteinFeatures(
            edge_features=edge_features,
            node_features=node_features,
            num_positional_embeddings=num_positional_embeddings,
            num_chain_embeddings=num_chain_embeddings,
            num_rbf=num_rbf,
            top_k=top_k,
            augment_eps=augment_eps,
            atom37_order=atom37_order,
            device=device,
            atom_context_num=atom_context_num,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_v_sc = nn.Linear(node_features, hidden_dim, bias=True)
        self.linear_down = nn.Linear(2 * hidden_dim, hidden_dim, bias=True)
        # 4 chi angles x num_mix mixture components x (mean, concentration, logit).
        self.W_torsions = nn.Linear(hidden_dim, 4 * 3 * num_mix, bias=True)
        self.num_mix = num_mix

        self.dropout = nn.Dropout(dropout)

        # Encoder layers
        self.encoder_layers = nn.ModuleList(
            [
                EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                for _ in range(num_encoder_layers)
            ]
        )

        self.W_c = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_e_context = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.W_nodes_y = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_edges_y = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.context_encoder_layers = nn.ModuleList(
            [DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout) for _ in range(2)]
        )

        self.V_C = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.V_C_norm = nn.LayerNorm(hidden_dim)
        self.y_context_encoder_layers = nn.ModuleList(
            [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)]
        )

        self.h_V_C_dropout = nn.Dropout(dropout)

        # Decoder layers
        self.decoder_layers = nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, feature_dict):
        """Encode backbone + context geometry.

        Returns per-residue embeddings ``h_V``, edge embeddings ``h_E`` and
        neighbor indices ``E_idx`` for use by ``decode``.
        """
        mask = feature_dict["mask"]
        V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m = self.features.features_encode(
            feature_dict
        )

        h_E_context = self.W_e_context(E_context)
        h_V = self.W_v(V)
        h_E = self.W_e(E)
        # Edge mask: both endpoints of each neighbor edge must exist.
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Mix ligand/context-atom information into the residue embeddings.
        h_V_C = self.W_c(h_V)
        Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :]
        Y_nodes = self.W_nodes_y(Y_nodes)
        Y_edges = self.W_edges_y(Y_edges)
        for i in range(len(self.context_encoder_layers)):
            Y_nodes = self.y_context_encoder_layers[i](Y_nodes, Y_edges, Y_m, Y_m_edges)
            h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1)
            h_V_C = self.context_encoder_layers[i](h_V_C, h_E_context_cat, mask, Y_m)

        h_V_C = self.V_C(h_V_C)
        h_V = h_V + self.V_C_norm(self.h_V_C_dropout(h_V_C))

        return h_V, h_E, E_idx

    def decode(self, feature_dict):
        """Predict mixture-of-von-Mises parameters for the four chi angles.

        Returns ``(mean, concentration, mix_logits)``, each shaped
        ``[B, L, 4, num_mix]``.
        """
        h_V = feature_dict["h_V"]
        h_E = feature_dict["h_E"]
        E_idx = feature_dict["E_idx"]
        mask = feature_dict["mask"]
        V, F = self.features.features_decode(feature_dict)

        h_F = self.W_f(F)
        h_EF = torch.cat([h_E, h_F], -1)

        # Combine backbone embeddings with current side-chain features.
        h_V_sc = self.W_v_sc(V)
        h_V_combined = torch.cat([h_V, h_V_sc], -1)
        h_V = self.linear_down(h_V_combined)

        for layer in self.decoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_EF, E_idx)
            h_V = layer(h_V, h_EV, mask)

        torsions = self.W_torsions(h_V)
        torsions = torsions.reshape(h_V.shape[0], h_V.shape[1], 4, self.num_mix, 3)
        mean = torsions[:, :, :, :, 0].float()
        # Softplus keeps the von Mises concentration positive; +0.1 floor avoids
        # a degenerate (near-uniform) component.
        concentration = 0.1 + self.softplus(torsions[:, :, :, :, 1]).float()
        mix_logits = torsions[:, :, :, :, 2].float()
        return mean, concentration, mix_logits
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class ProteinFeatures(nn.Module):
|
| 394 |
+
def __init__(
|
| 395 |
+
self,
|
| 396 |
+
edge_features=128,
|
| 397 |
+
node_features=128,
|
| 398 |
+
num_positional_embeddings=16,
|
| 399 |
+
num_chain_embeddings=16,
|
| 400 |
+
num_rbf=16,
|
| 401 |
+
top_k=30,
|
| 402 |
+
augment_eps=0.0,
|
| 403 |
+
atom37_order=False,
|
| 404 |
+
device=None,
|
| 405 |
+
atom_context_num=16,
|
| 406 |
+
lower_bound=0.0,
|
| 407 |
+
upper_bound=20.0,
|
| 408 |
+
):
|
| 409 |
+
"""Extract protein features"""
|
| 410 |
+
super(ProteinFeatures, self).__init__()
|
| 411 |
+
self.edge_features = edge_features
|
| 412 |
+
self.node_features = node_features
|
| 413 |
+
self.num_positional_embeddings = num_positional_embeddings
|
| 414 |
+
self.num_chain_embeddings = num_chain_embeddings
|
| 415 |
+
self.num_rbf = num_rbf
|
| 416 |
+
self.top_k = top_k
|
| 417 |
+
self.augment_eps = augment_eps
|
| 418 |
+
self.atom37_order = atom37_order
|
| 419 |
+
self.device = device
|
| 420 |
+
self.atom_context_num = atom_context_num
|
| 421 |
+
self.lower_bound = lower_bound
|
| 422 |
+
self.upper_bound = upper_bound
|
| 423 |
+
|
| 424 |
+
# deal with oxygen index
|
| 425 |
+
# ------
|
| 426 |
+
self.N_idx = 0
|
| 427 |
+
self.CA_idx = 1
|
| 428 |
+
self.C_idx = 2
|
| 429 |
+
|
| 430 |
+
if atom37_order:
|
| 431 |
+
self.O_idx = 4
|
| 432 |
+
else:
|
| 433 |
+
self.O_idx = 3
|
| 434 |
+
# -------
|
| 435 |
+
self.positional_embeddings = PositionalEncodings(num_positional_embeddings)
|
| 436 |
+
|
| 437 |
+
# Features for the encoder
|
| 438 |
+
enc_node_in = 21 # alphabet for the sequence
|
| 439 |
+
enc_edge_in = (
|
| 440 |
+
num_positional_embeddings + num_rbf * 25
|
| 441 |
+
) # positional + distance features
|
| 442 |
+
|
| 443 |
+
self.enc_node_in = enc_node_in
|
| 444 |
+
self.enc_edge_in = enc_edge_in
|
| 445 |
+
|
| 446 |
+
self.enc_edge_embedding = nn.Linear(enc_edge_in, edge_features, bias=False)
|
| 447 |
+
self.enc_norm_edges = nn.LayerNorm(edge_features)
|
| 448 |
+
self.enc_node_embedding = nn.Linear(enc_node_in, node_features, bias=False)
|
| 449 |
+
self.enc_norm_nodes = nn.LayerNorm(node_features)
|
| 450 |
+
|
| 451 |
+
# Features for the decoder
|
| 452 |
+
dec_node_in = 14 * atom_context_num * num_rbf
|
| 453 |
+
dec_edge_in = num_rbf * 14 * 14 + 42
|
| 454 |
+
|
| 455 |
+
self.dec_node_in = dec_node_in
|
| 456 |
+
self.dec_edge_in = dec_edge_in
|
| 457 |
+
|
| 458 |
+
self.W_XY_project_down1 = nn.Linear(num_rbf + 120, num_rbf, bias=True)
|
| 459 |
+
self.dec_edge_embedding1 = nn.Linear(dec_edge_in, edge_features, bias=False)
|
| 460 |
+
self.dec_norm_edges1 = nn.LayerNorm(edge_features)
|
| 461 |
+
self.dec_node_embedding1 = nn.Linear(dec_node_in, node_features, bias=False)
|
| 462 |
+
self.dec_norm_nodes1 = nn.LayerNorm(node_features)
|
| 463 |
+
|
| 464 |
+
self.node_project_down = nn.Linear(
|
| 465 |
+
5 * num_rbf + 64 + 4, node_features, bias=True
|
| 466 |
+
)
|
| 467 |
+
self.norm_nodes = nn.LayerNorm(node_features)
|
| 468 |
+
|
| 469 |
+
self.type_linear = nn.Linear(147, 64)
|
| 470 |
+
|
| 471 |
+
self.y_nodes = nn.Linear(147, node_features, bias=False)
|
| 472 |
+
self.y_edges = nn.Linear(num_rbf, node_features, bias=False)
|
| 473 |
+
|
| 474 |
+
self.norm_y_edges = nn.LayerNorm(node_features)
|
| 475 |
+
self.norm_y_nodes = nn.LayerNorm(node_features)
|
| 476 |
+
|
| 477 |
+
self.periodic_table_features = torch.tensor(
|
| 478 |
+
[
|
| 479 |
+
[
|
| 480 |
+
0,
|
| 481 |
+
1,
|
| 482 |
+
2,
|
| 483 |
+
3,
|
| 484 |
+
4,
|
| 485 |
+
5,
|
| 486 |
+
6,
|
| 487 |
+
7,
|
| 488 |
+
8,
|
| 489 |
+
9,
|
| 490 |
+
10,
|
| 491 |
+
11,
|
| 492 |
+
12,
|
| 493 |
+
13,
|
| 494 |
+
14,
|
| 495 |
+
15,
|
| 496 |
+
16,
|
| 497 |
+
17,
|
| 498 |
+
18,
|
| 499 |
+
19,
|
| 500 |
+
20,
|
| 501 |
+
21,
|
| 502 |
+
22,
|
| 503 |
+
23,
|
| 504 |
+
24,
|
| 505 |
+
25,
|
| 506 |
+
26,
|
| 507 |
+
27,
|
| 508 |
+
28,
|
| 509 |
+
29,
|
| 510 |
+
30,
|
| 511 |
+
31,
|
| 512 |
+
32,
|
| 513 |
+
33,
|
| 514 |
+
34,
|
| 515 |
+
35,
|
| 516 |
+
36,
|
| 517 |
+
37,
|
| 518 |
+
38,
|
| 519 |
+
39,
|
| 520 |
+
40,
|
| 521 |
+
41,
|
| 522 |
+
42,
|
| 523 |
+
43,
|
| 524 |
+
44,
|
| 525 |
+
45,
|
| 526 |
+
46,
|
| 527 |
+
47,
|
| 528 |
+
48,
|
| 529 |
+
49,
|
| 530 |
+
50,
|
| 531 |
+
51,
|
| 532 |
+
52,
|
| 533 |
+
53,
|
| 534 |
+
54,
|
| 535 |
+
55,
|
| 536 |
+
56,
|
| 537 |
+
57,
|
| 538 |
+
58,
|
| 539 |
+
59,
|
| 540 |
+
60,
|
| 541 |
+
61,
|
| 542 |
+
62,
|
| 543 |
+
63,
|
| 544 |
+
64,
|
| 545 |
+
65,
|
| 546 |
+
66,
|
| 547 |
+
67,
|
| 548 |
+
68,
|
| 549 |
+
69,
|
| 550 |
+
70,
|
| 551 |
+
71,
|
| 552 |
+
72,
|
| 553 |
+
73,
|
| 554 |
+
74,
|
| 555 |
+
75,
|
| 556 |
+
76,
|
| 557 |
+
77,
|
| 558 |
+
78,
|
| 559 |
+
79,
|
| 560 |
+
80,
|
| 561 |
+
81,
|
| 562 |
+
82,
|
| 563 |
+
83,
|
| 564 |
+
84,
|
| 565 |
+
85,
|
| 566 |
+
86,
|
| 567 |
+
87,
|
| 568 |
+
88,
|
| 569 |
+
89,
|
| 570 |
+
90,
|
| 571 |
+
91,
|
| 572 |
+
92,
|
| 573 |
+
93,
|
| 574 |
+
94,
|
| 575 |
+
95,
|
| 576 |
+
96,
|
| 577 |
+
97,
|
| 578 |
+
98,
|
| 579 |
+
99,
|
| 580 |
+
100,
|
| 581 |
+
101,
|
| 582 |
+
102,
|
| 583 |
+
103,
|
| 584 |
+
104,
|
| 585 |
+
105,
|
| 586 |
+
106,
|
| 587 |
+
107,
|
| 588 |
+
108,
|
| 589 |
+
109,
|
| 590 |
+
110,
|
| 591 |
+
111,
|
| 592 |
+
112,
|
| 593 |
+
113,
|
| 594 |
+
114,
|
| 595 |
+
115,
|
| 596 |
+
116,
|
| 597 |
+
117,
|
| 598 |
+
118,
|
| 599 |
+
],
|
| 600 |
+
[
|
| 601 |
+
0,
|
| 602 |
+
1,
|
| 603 |
+
18,
|
| 604 |
+
1,
|
| 605 |
+
2,
|
| 606 |
+
13,
|
| 607 |
+
14,
|
| 608 |
+
15,
|
| 609 |
+
16,
|
| 610 |
+
17,
|
| 611 |
+
18,
|
| 612 |
+
1,
|
| 613 |
+
2,
|
| 614 |
+
13,
|
| 615 |
+
14,
|
| 616 |
+
15,
|
| 617 |
+
16,
|
| 618 |
+
17,
|
| 619 |
+
18,
|
| 620 |
+
1,
|
| 621 |
+
2,
|
| 622 |
+
3,
|
| 623 |
+
4,
|
| 624 |
+
5,
|
| 625 |
+
6,
|
| 626 |
+
7,
|
| 627 |
+
8,
|
| 628 |
+
9,
|
| 629 |
+
10,
|
| 630 |
+
11,
|
| 631 |
+
12,
|
| 632 |
+
13,
|
| 633 |
+
14,
|
| 634 |
+
15,
|
| 635 |
+
16,
|
| 636 |
+
17,
|
| 637 |
+
18,
|
| 638 |
+
1,
|
| 639 |
+
2,
|
| 640 |
+
3,
|
| 641 |
+
4,
|
| 642 |
+
5,
|
| 643 |
+
6,
|
| 644 |
+
7,
|
| 645 |
+
8,
|
| 646 |
+
9,
|
| 647 |
+
10,
|
| 648 |
+
11,
|
| 649 |
+
12,
|
| 650 |
+
13,
|
| 651 |
+
14,
|
| 652 |
+
15,
|
| 653 |
+
16,
|
| 654 |
+
17,
|
| 655 |
+
18,
|
| 656 |
+
1,
|
| 657 |
+
2,
|
| 658 |
+
3,
|
| 659 |
+
3,
|
| 660 |
+
3,
|
| 661 |
+
3,
|
| 662 |
+
3,
|
| 663 |
+
3,
|
| 664 |
+
3,
|
| 665 |
+
3,
|
| 666 |
+
3,
|
| 667 |
+
3,
|
| 668 |
+
3,
|
| 669 |
+
3,
|
| 670 |
+
3,
|
| 671 |
+
3,
|
| 672 |
+
3,
|
| 673 |
+
4,
|
| 674 |
+
5,
|
| 675 |
+
6,
|
| 676 |
+
7,
|
| 677 |
+
8,
|
| 678 |
+
9,
|
| 679 |
+
10,
|
| 680 |
+
11,
|
| 681 |
+
12,
|
| 682 |
+
13,
|
| 683 |
+
14,
|
| 684 |
+
15,
|
| 685 |
+
16,
|
| 686 |
+
17,
|
| 687 |
+
18,
|
| 688 |
+
1,
|
| 689 |
+
2,
|
| 690 |
+
3,
|
| 691 |
+
3,
|
| 692 |
+
3,
|
| 693 |
+
3,
|
| 694 |
+
3,
|
| 695 |
+
3,
|
| 696 |
+
3,
|
| 697 |
+
3,
|
| 698 |
+
3,
|
| 699 |
+
3,
|
| 700 |
+
3,
|
| 701 |
+
3,
|
| 702 |
+
3,
|
| 703 |
+
3,
|
| 704 |
+
3,
|
| 705 |
+
4,
|
| 706 |
+
5,
|
| 707 |
+
6,
|
| 708 |
+
7,
|
| 709 |
+
8,
|
| 710 |
+
9,
|
| 711 |
+
10,
|
| 712 |
+
11,
|
| 713 |
+
12,
|
| 714 |
+
13,
|
| 715 |
+
14,
|
| 716 |
+
15,
|
| 717 |
+
16,
|
| 718 |
+
17,
|
| 719 |
+
18,
|
| 720 |
+
],
|
| 721 |
+
[
|
| 722 |
+
0,
|
| 723 |
+
1,
|
| 724 |
+
1,
|
| 725 |
+
2,
|
| 726 |
+
2,
|
| 727 |
+
2,
|
| 728 |
+
2,
|
| 729 |
+
2,
|
| 730 |
+
2,
|
| 731 |
+
2,
|
| 732 |
+
2,
|
| 733 |
+
3,
|
| 734 |
+
3,
|
| 735 |
+
3,
|
| 736 |
+
3,
|
| 737 |
+
3,
|
| 738 |
+
3,
|
| 739 |
+
3,
|
| 740 |
+
3,
|
| 741 |
+
4,
|
| 742 |
+
4,
|
| 743 |
+
4,
|
| 744 |
+
4,
|
| 745 |
+
4,
|
| 746 |
+
4,
|
| 747 |
+
4,
|
| 748 |
+
4,
|
| 749 |
+
4,
|
| 750 |
+
4,
|
| 751 |
+
4,
|
| 752 |
+
4,
|
| 753 |
+
4,
|
| 754 |
+
4,
|
| 755 |
+
4,
|
| 756 |
+
4,
|
| 757 |
+
4,
|
| 758 |
+
4,
|
| 759 |
+
5,
|
| 760 |
+
5,
|
| 761 |
+
5,
|
| 762 |
+
5,
|
| 763 |
+
5,
|
| 764 |
+
5,
|
| 765 |
+
5,
|
| 766 |
+
5,
|
| 767 |
+
5,
|
| 768 |
+
5,
|
| 769 |
+
5,
|
| 770 |
+
5,
|
| 771 |
+
5,
|
| 772 |
+
5,
|
| 773 |
+
5,
|
| 774 |
+
5,
|
| 775 |
+
5,
|
| 776 |
+
5,
|
| 777 |
+
6,
|
| 778 |
+
6,
|
| 779 |
+
6,
|
| 780 |
+
6,
|
| 781 |
+
6,
|
| 782 |
+
6,
|
| 783 |
+
6,
|
| 784 |
+
6,
|
| 785 |
+
6,
|
| 786 |
+
6,
|
| 787 |
+
6,
|
| 788 |
+
6,
|
| 789 |
+
6,
|
| 790 |
+
6,
|
| 791 |
+
6,
|
| 792 |
+
6,
|
| 793 |
+
6,
|
| 794 |
+
6,
|
| 795 |
+
6,
|
| 796 |
+
6,
|
| 797 |
+
6,
|
| 798 |
+
6,
|
| 799 |
+
6,
|
| 800 |
+
6,
|
| 801 |
+
6,
|
| 802 |
+
6,
|
| 803 |
+
6,
|
| 804 |
+
6,
|
| 805 |
+
6,
|
| 806 |
+
6,
|
| 807 |
+
6,
|
| 808 |
+
6,
|
| 809 |
+
7,
|
| 810 |
+
7,
|
| 811 |
+
7,
|
| 812 |
+
7,
|
| 813 |
+
7,
|
| 814 |
+
7,
|
| 815 |
+
7,
|
| 816 |
+
7,
|
| 817 |
+
7,
|
| 818 |
+
7,
|
| 819 |
+
7,
|
| 820 |
+
7,
|
| 821 |
+
7,
|
| 822 |
+
7,
|
| 823 |
+
7,
|
| 824 |
+
7,
|
| 825 |
+
7,
|
| 826 |
+
7,
|
| 827 |
+
7,
|
| 828 |
+
7,
|
| 829 |
+
7,
|
| 830 |
+
7,
|
| 831 |
+
7,
|
| 832 |
+
7,
|
| 833 |
+
7,
|
| 834 |
+
7,
|
| 835 |
+
7,
|
| 836 |
+
7,
|
| 837 |
+
7,
|
| 838 |
+
7,
|
| 839 |
+
7,
|
| 840 |
+
7,
|
| 841 |
+
],
|
| 842 |
+
],
|
| 843 |
+
dtype=torch.long,
|
| 844 |
+
device=device,
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
def _dist(self, X, mask, eps=1e-6):
|
| 848 |
+
mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
|
| 849 |
+
dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
|
| 850 |
+
D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
|
| 851 |
+
D_max, _ = torch.max(D, -1, keepdim=True)
|
| 852 |
+
D_adjust = D + (1.0 - mask_2D) * D_max
|
| 853 |
+
sampled_top_k = self.top_k
|
| 854 |
+
D_neighbors, E_idx = torch.topk(
|
| 855 |
+
D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
|
| 856 |
+
)
|
| 857 |
+
return D_neighbors, E_idx
|
| 858 |
+
|
| 859 |
+
def _make_angle_features(self, A, B, C, Y):
|
| 860 |
+
v1 = A - B
|
| 861 |
+
v2 = C - B
|
| 862 |
+
e1 = torch.nn.functional.normalize(v1, dim=-1)
|
| 863 |
+
e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None]
|
| 864 |
+
u2 = v2 - e1 * e1_v2_dot
|
| 865 |
+
e2 = torch.nn.functional.normalize(u2, dim=-1)
|
| 866 |
+
e3 = torch.cross(e1, e2, dim=-1)
|
| 867 |
+
R_residue = torch.cat(
|
| 868 |
+
(e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
local_vectors = torch.einsum(
|
| 872 |
+
"blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :]
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8)
|
| 876 |
+
f1 = local_vectors[..., 0] / rxy
|
| 877 |
+
f2 = local_vectors[..., 1] / rxy
|
| 878 |
+
rxyz = torch.norm(local_vectors, dim=-1) + 1e-8
|
| 879 |
+
f3 = rxy / rxyz
|
| 880 |
+
f4 = local_vectors[..., 2] / rxyz
|
| 881 |
+
|
| 882 |
+
f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1)
|
| 883 |
+
return f
|
| 884 |
+
|
| 885 |
+
def _rbf(
|
| 886 |
+
self,
|
| 887 |
+
D,
|
| 888 |
+
D_mu_shape=[1, 1, 1, -1],
|
| 889 |
+
lower_bound=0.0,
|
| 890 |
+
upper_bound=20.0,
|
| 891 |
+
num_bins=16,
|
| 892 |
+
):
|
| 893 |
+
device = D.device
|
| 894 |
+
D_min, D_max, D_count = lower_bound, upper_bound, num_bins
|
| 895 |
+
D_mu = torch.linspace(D_min, D_max, D_count, device=device)
|
| 896 |
+
D_mu = D_mu.view(D_mu_shape)
|
| 897 |
+
D_sigma = (D_max - D_min) / D_count
|
| 898 |
+
D_expand = torch.unsqueeze(D, -1)
|
| 899 |
+
RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
|
| 900 |
+
return RBF
|
| 901 |
+
|
| 902 |
+
def _get_rbf(
    self,
    A,
    B,
    E_idx,
    D_mu_shape=[1, 1, 1, -1],
    lower_bound=2.0,
    upper_bound=22.0,
    num_bins=16,
):
    """RBF-encode A-B atom distances along the neighbor edges in E_idx.

    Args:
        A, B: [B, L, 3] atom coordinates.
        E_idx: [B, L, K] neighbor indices.

    Returns:
        [B, L, K, num_bins] RBF features, one vector per edge.
    """
    # All-pairs A->B distances with an epsilon under the sqrt: [B, L, L].
    pair_dist = torch.sqrt(
        torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
    )
    # Keep only distances along graph edges: [B, L, K].
    edge_dist = gather_edges(pair_dist[:, :, :, None], E_idx)[:, :, :, 0]
    return self._rbf(
        edge_dist,
        D_mu_shape=D_mu_shape,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
        num_bins=num_bins,
    )
|
| 926 |
+
|
| 927 |
+
def features_encode(self, features):
    """
    make protein graph and encode backbone

    Builds the k-NN residue graph from CA coordinates, RBF-encodes all
    backbone atom-pair distances along edges, and featurizes the ligand
    context atoms (distances to each backbone atom, element one-hots,
    angular features).

    Args (keys read from `features`):
        S: residue type integers. X: backbone atom coordinates.
        Y/Y_m/Y_t: context-atom coordinates / mask / element types.
        mask: residue validity mask. R_idx: residue indices.
        chain_labels: per-residue chain integers.
        # NOTE(review): shapes are assumed to be S [B, L], X [B, L, atoms, 3],
        # Y [B, L, M, 3] based on the indexing below — confirm against featurize().

    Returns:
        V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m — node/edge embeddings
        for the protein graph and the ligand-context graph.
    """
    S = features["S"]
    X = features["X"]
    Y = features["Y"]
    Y_m = features["Y_m"]
    Y_t = features["Y_t"]
    mask = features["mask"]
    R_idx = features["R_idx"]
    chain_labels = features["chain_labels"]

    # Gaussian coordinate noise is applied only during training.
    if self.training and self.augment_eps > 0:
        X = X + self.augment_eps * torch.randn_like(X)

    Ca = X[:, :, self.CA_idx, :]
    N = X[:, :, self.N_idx, :]
    C = X[:, :, self.C_idx, :]
    O = X[:, :, self.O_idx, :]

    # Virtual CB built from the N/CA/C frame with fixed coefficients.
    b = Ca - N
    c = C - Ca
    a = torch.cross(b, c, dim=-1)
    Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca  # shift from CA

    # k-NN graph over CA atoms.
    _, E_idx = self._dist(Ca, mask)

    backbone_coords_list = [N, Ca, C, O, Cb]

    # RBF features for all 5x5 backbone atom-pair combinations along edges.
    RBF_all = []
    for atom_1 in backbone_coords_list:
        for atom_2 in backbone_coords_list:
            RBF_all.append(
                self._get_rbf(
                    atom_1,
                    atom_2,
                    E_idx,
                    D_mu_shape=[1, 1, 1, -1],
                    lower_bound=self.lower_bound,
                    upper_bound=self.upper_bound,
                    num_bins=self.num_rbf,
                )
            )
    RBF_all = torch.cat(tuple(RBF_all), dim=-1)

    # Relative sequence-position offsets along edges.
    offset = R_idx[:, :, None] - R_idx[:, None, :]
    offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

    d_chains = (
        (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
    ).long()  # find self vs non-self interaction
    E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
    E_positional = self.positional_embeddings(offset.long(), E_chains)
    E = torch.cat((E_positional, RBF_all), -1)
    E = self.enc_edge_embedding(E)
    E = self.enc_norm_edges(E)

    # Node features: one-hot residue types, embedded and normalized.
    V = torch.nn.functional.one_hot(S, self.enc_node_in).float()
    V = self.enc_node_embedding(V)
    V = self.enc_norm_nodes(V)

    # Context-atom element features: atomic number plus periodic-table
    # group/period looked up from the precomputed table.
    Y_t = Y_t.long()
    Y_t_g = self.periodic_table_features[1][Y_t]  # group; 19 categories including 0
    Y_t_p = self.periodic_table_features[2][Y_t]  # period; 8 categories including 0

    Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19)  # [B, L, M, 19]
    Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8)  # [B, L, M, 8]
    Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120)  # [B, L, M, 120]

    Y_t_1hot_ = torch.cat(
        [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1
    )  # [B, L, M, 147]
    Y_t_1hot = self.type_linear(Y_t_1hot_.float())

    # RBF distance features from each backbone atom to every context atom.
    D_N_Y = torch.sqrt(
        torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_N_Y = self._rbf(
        D_N_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_Ca_Y = torch.sqrt(
        torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_Ca_Y = self._rbf(
        D_Ca_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_C_Y = torch.sqrt(
        torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_C_Y = self._rbf(
        D_C_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_O_Y = torch.sqrt(
        torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_O_Y = self._rbf(
        D_O_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_Cb_Y = torch.sqrt(
        torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_Cb_Y = self._rbf(
        D_Cb_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    # Angular features of context atoms in the local N/CA/C residue frame.
    f_angles = self._make_angle_features(N, Ca, C, Y)

    D_all = torch.cat(
        (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1
    )  # [B,L,M,5*num_bins+5]
    E_context = self.node_project_down(D_all)  # [B, L, M, node_features]
    E_context = self.norm_nodes(E_context)

    # Context-atom graph: pairwise context-atom distances (RBF, with the
    # default _rbf bounds) as edges, element one-hots as nodes.
    Y_edges = self._rbf(
        torch.sqrt(
            torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
        )
    )  # [B, L, M, M, num_bins]

    Y_edges = self.y_edges(Y_edges)
    Y_nodes = self.y_nodes(Y_t_1hot_.float())

    Y_edges = self.norm_y_edges(Y_edges)
    Y_nodes = self.norm_y_nodes(Y_nodes)

    return V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m
|
| 1078 |
+
|
| 1079 |
+
def features_decode(self, features):
    """
    Make features for decoding. Explicit side chain atom and other atom distances.

    RBF-encodes all 14x14 side-chain atom-pair distances along the existing
    residue graph edges, plus distances from each of the 14 residue atoms to
    the nearest `atom_context_num` context atoms (combined with context-atom
    element one-hots), and concatenates one-hot sequence features.

    Returns:
        V: per-residue node features from atom-context distances.
        F: per-edge features from side-chain distances + sequence one-hots.
    """

    S = features["S"]
    X = features["X"]
    X_m = features["X_m"]
    mask = features["mask"]
    E_idx = features["E_idx"]

    # Truncate context atoms to the model's context budget.
    Y = features["Y"][:, :, : self.atom_context_num]
    Y_m = features["Y_m"][:, :, : self.atom_context_num]
    Y_t = features["Y_t"][:, :, : self.atom_context_num]

    # Zero out atom masks of invalid residues.
    X_m = X_m * mask[:, :, None]
    device = S.device  # NOTE(review): `device` is assigned but unused here.

    B, L, _, _ = X.shape

    RBF_sidechain = []
    X_m_gathered = gather_nodes(X_m, E_idx)  # [B, L, K, 14]

    # RBF features for every (atom_i of residue, atom_j of neighbor) pair,
    # masked by both atoms' validity.
    for i in range(14):
        for j in range(14):
            rbf_features = self._get_rbf(
                X[:, :, i, :],
                X[:, :, j, :],
                E_idx,
                D_mu_shape=[1, 1, 1, -1],
                lower_bound=self.lower_bound,
                upper_bound=self.upper_bound,
                num_bins=self.num_rbf,
            )
            rbf_features = (
                rbf_features
                * X_m[:, :, i, None, None]
                * X_m_gathered[:, :, :, j, None]
            )
            RBF_sidechain.append(rbf_features)

    # Distances from each residue atom to each context atom, RBF-encoded
    # (note the 5-D broadcast shape for the RBF centers) and masked.
    D_XY = torch.sqrt(
        torch.sum((X[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
    )  # [B, L, 14, atom_context_num]
    XY_features = self._rbf(
        D_XY,
        D_mu_shape=[1, 1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )  # [B, L, 14, atom_context_num, num_rbf]
    XY_features = XY_features * X_m[:, :, :, None, None] * Y_m[:, :, None, :, None]

    # Attach context-atom element identity, then project back down.
    Y_t_1hot = torch.nn.functional.one_hot(
        Y_t.long(), 120
    ).float()  # [B, L, atom_context_num, 120]
    XY_Y_t = torch.cat(
        [XY_features, Y_t_1hot[:, :, None, :, :].repeat(1, 1, 14, 1, 1)], -1
    )  # [B, L, 14, atom_context_num, num_rbf+120]
    XY_Y_t = self.W_XY_project_down1(
        XY_Y_t
    )  # [B, L, 14, atom_context_num, num_rbf]
    XY_features = XY_Y_t.view([B, L, -1])

    V = self.dec_node_embedding1(XY_features)
    V = self.dec_norm_nodes1(V)

    # Sequence one-hots for both edge endpoints.
    S_1h = torch.nn.functional.one_hot(S, self.enc_node_in).float()
    S_1h_gathered = gather_nodes(S_1h, E_idx)  # [B, L, K, 21]
    S_features = torch.cat(
        [S_1h[:, :, None, :].repeat(1, 1, E_idx.shape[2], 1), S_1h_gathered], -1
    )  # [B, L, K, 42]

    F = torch.cat(
        tuple(RBF_sidechain), dim=-1
    )  # [B,L,atom_context_num,14*14*num_rbf]
    F = torch.cat([F, S_features], -1)
    F = self.dec_edge_embedding1(F)
    F = self.dec_norm_edges1(F)
    return V, F
|
score.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os.path
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from data_utils import (
|
| 11 |
+
element_dict_rev,
|
| 12 |
+
alphabet,
|
| 13 |
+
restype_int_to_str,
|
| 14 |
+
featurize,
|
| 15 |
+
parse_PDB,
|
| 16 |
+
)
|
| 17 |
+
from model_utils import ProteinMPNN
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def main(args) -> None:
    """
    Inference function

    Scores sequences with a chosen (Ligand/Protein)MPNN checkpoint: loads the
    model, parses each input PDB, builds fixed/redesigned/membrane masks,
    runs autoregressive or single-amino-acid scoring over several batches,
    and saves per-residue logits/probabilities and summary statistics to
    ``<out_folder>/<pdb_name><file_ending>.pt``.

    Args:
        args: parsed argparse namespace (see the argparser at the bottom of
            this file for the full list of options).
    """
    # Seed everything for reproducibility; a random seed is drawn when none
    # is given and is stored in the output dict.
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    folder_for_outputs = args.out_folder
    base_folder = folder_for_outputs
    if base_folder[-1] != "/":
        base_folder = base_folder + "/"
    if not os.path.exists(base_folder):
        os.makedirs(base_folder, exist_ok=True)
    # Pick the checkpoint matching the requested model flavor.
    if args.model_type == "protein_mpnn":
        checkpoint_path = args.checkpoint_protein_mpnn
    elif args.model_type == "ligand_mpnn":
        checkpoint_path = args.checkpoint_ligand_mpnn
    elif args.model_type == "per_residue_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
    elif args.model_type == "global_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_global_label_membrane_mpnn
    elif args.model_type == "soluble_mpnn":
        checkpoint_path = args.checkpoint_soluble_mpnn
    else:
        print("Choose one of the available models")
        sys.exit()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Only the ligand model uses atom context / side-chain context.
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
        k_neighbors = checkpoint["num_edges"]
    else:
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
        k_neighbors = checkpoint["num_edges"]

    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # Inputs can be a single PDB or a JSON listing many; the per-PDB
    # fixed/redesigned residue maps fall back to the single CLI strings.
    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]

    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
    else:
        fixed_residues = [item for item in args.fixed_residues.split()]
        fixed_residues_multi = {}
        for pdb in pdb_paths:
            fixed_residues_multi[pdb] = fixed_residues

    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
    else:
        redesigned_residues = [item for item in args.redesigned_residues.split()]
        redesigned_residues_multi = {}
        for pdb in pdb_paths:
            redesigned_residues_multi[pdb] = redesigned_residues

    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=args.parse_these_chains_only,
            parse_all_atoms=args.ligand_mpnn_use_side_chain_context,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = []
        for i, R_idx_item in enumerate(R_idx_list):
            tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            encoded_residues.append(tmp)
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(
            zip(list(range(len(encoded_residues))), encoded_residues)
        )

        # 1 where the residue is NOT in the given list (i.e. free to design).
        fixed_positions = torch.tensor(
            [int(item not in fixed_residues) for item in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residues) for item in encoded_residues],
            device=device,
        )

        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = [item for item in args.transmembrane_buried.split()]
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)

        if args.transmembrane_interface:
            interface_residues = [item for item in args.transmembrane_interface.split()]
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # Per-residue membrane label: 2 = buried, 1 = interface, 0 = neither.
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)

        if args.model_type == "global_label_membrane_mpnn":
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        if type(args.chains_to_design) == str:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )

        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        # NOTE: redesigned_residues takes precedence over fixed_residues.
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask

        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)

        # specify which residues are linked
        if args.symmetry_residues:
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = []
            for t_list in symmetry_residues_list_of_lists:
                tmp_list = []
                for t in t_list:
                    tmp_list.append(encoded_residue_dict[t])
                remapped_symmetry_residues.append(tmp_list)
        else:
            remapped_symmetry_residues = [[]]

        # Homo-oligomer mode ties equivalent positions across all chains.
        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            chain_letters_set = list(set(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            for res in residue_indices:
                tmp_list = []
                tmp_w_list = []  # NOTE(review): built but never used here.
                for chain in chain_letters_set:
                    name = chain + res
                    tmp_list.append(encoded_residue_dict[name])
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)

        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)

        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]

        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                if "Y" in list(protein_dict):
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    print("No ligand atoms parsed")
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            B, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["symmetry_residues"] = remapped_symmetry_residues

            # Collect scores over several stochastic decoding orders.
            logits_list = []
            probs_list = []
            log_probs_list = []
            decoding_order_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                if args.autoregressive_score:
                    score_dict = model.score(feature_dict, use_sequence=args.use_sequence)
                elif args.single_aa_score:
                    score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence)
                else:
                    print("Set either autoregressive_score or single_aa_score to True")
                    sys.exit()
                logits_list.append(score_dict["logits"])
                log_probs_list.append(score_dict["log_probs"])
                probs_list.append(torch.exp(score_dict["log_probs"]))
                decoding_order_list.append(score_dict["decoding_order"])
            log_probs_stack = torch.cat(log_probs_list, 0)
            logits_stack = torch.cat(logits_list, 0)
            probs_stack = torch.cat(probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)

            # Assemble and save all raw outputs plus per-residue summaries.
            output_stats_path = base_folder + name + args.file_ending + ".pt"
            out_dict = {}
            out_dict["logits"] = logits_stack.cpu().numpy()
            out_dict["probs"] = probs_stack.cpu().numpy()
            out_dict["log_probs"] = log_probs_stack.cpu().numpy()
            out_dict["decoding_order"] = decoding_order_stack.cpu().numpy()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy()
            out_dict["mask"] = feature_dict["mask"][0].cpu().numpy()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy() #this affects decoding order
            out_dict["seed"] = seed
            out_dict["alphabet"] = alphabet
            out_dict["residue_names"] = encoded_residue_dict_rev

            # Mean/std of amino-acid probabilities across batches, keyed by
            # residue name (e.g. "A12") and amino-acid letter.
            mean_probs = np.mean(out_dict["probs"], 0)
            std_probs = np.std(out_dict["probs"], 0)
            sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]]
            mean_dict = {}
            std_dict = {}
            for residue in range(L):
                mean_dict_ = dict(zip(alphabet, mean_probs[residue]))
                mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_
                std_dict_ = dict(zip(alphabet, std_probs[residue]))
                std_dict[encoded_residue_dict_rev[residue]] = std_dict_

            out_dict["sequence"] = sequence
            out_dict["mean_of_probs"] = mean_dict
            out_dict["std_of_probs"] = std_dict
            torch.save(out_dict, output_stats_path)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":
|
| 334 |
+
argparser = argparse.ArgumentParser(
|
| 335 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
argparser.add_argument(
|
| 339 |
+
"--model_type",
|
| 340 |
+
type=str,
|
| 341 |
+
default="protein_mpnn",
|
| 342 |
+
help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
|
| 343 |
+
)
|
| 344 |
+
# protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
|
| 345 |
+
# ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
|
| 346 |
+
# per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
|
| 347 |
+
# global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
|
| 348 |
+
# soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
|
| 349 |
+
argparser.add_argument(
|
| 350 |
+
"--checkpoint_protein_mpnn",
|
| 351 |
+
type=str,
|
| 352 |
+
default="./model_params/proteinmpnn_v_48_020.pt",
|
| 353 |
+
help="Path to model weights.",
|
| 354 |
+
)
|
| 355 |
+
argparser.add_argument(
|
| 356 |
+
"--checkpoint_ligand_mpnn",
|
| 357 |
+
type=str,
|
| 358 |
+
default="./model_params/ligandmpnn_v_32_010_25.pt",
|
| 359 |
+
help="Path to model weights.",
|
| 360 |
+
)
|
| 361 |
+
argparser.add_argument(
|
| 362 |
+
"--checkpoint_per_residue_label_membrane_mpnn",
|
| 363 |
+
type=str,
|
| 364 |
+
default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
|
| 365 |
+
help="Path to model weights.",
|
| 366 |
+
)
|
| 367 |
+
argparser.add_argument(
|
| 368 |
+
"--checkpoint_global_label_membrane_mpnn",
|
| 369 |
+
type=str,
|
| 370 |
+
default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
|
| 371 |
+
help="Path to model weights.",
|
| 372 |
+
)
|
| 373 |
+
argparser.add_argument(
|
| 374 |
+
"--checkpoint_soluble_mpnn",
|
| 375 |
+
type=str,
|
| 376 |
+
default="./model_params/solublempnn_v_48_020.pt",
|
| 377 |
+
help="Path to model weights.",
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")
|
| 381 |
+
|
| 382 |
+
argparser.add_argument(
|
| 383 |
+
"--pdb_path", type=str, default="", help="Path to the input PDB."
|
| 384 |
+
)
|
| 385 |
+
argparser.add_argument(
|
| 386 |
+
"--pdb_path_multi",
|
| 387 |
+
type=str,
|
| 388 |
+
default="",
|
| 389 |
+
help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
argparser.add_argument(
|
| 393 |
+
"--fixed_residues",
|
| 394 |
+
type=str,
|
| 395 |
+
default="",
|
| 396 |
+
help="Provide fixed residues, A12 A13 A14 B2 B25",
|
| 397 |
+
)
|
| 398 |
+
argparser.add_argument(
|
| 399 |
+
"--fixed_residues_multi",
|
| 400 |
+
type=str,
|
| 401 |
+
default="",
|
| 402 |
+
help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
argparser.add_argument(
|
| 406 |
+
"--redesigned_residues",
|
| 407 |
+
type=str,
|
| 408 |
+
default="",
|
| 409 |
+
help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
|
| 410 |
+
)
|
| 411 |
+
argparser.add_argument(
|
| 412 |
+
"--redesigned_residues_multi",
|
| 413 |
+
type=str,
|
| 414 |
+
default="",
|
| 415 |
+
help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
argparser.add_argument(
|
| 419 |
+
"--symmetry_residues",
|
| 420 |
+
type=str,
|
| 421 |
+
default="",
|
| 422 |
+
help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
argparser.add_argument(
|
| 426 |
+
"--homo_oligomer",
|
| 427 |
+
type=int,
|
| 428 |
+
default=0,
|
| 429 |
+
help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
argparser.add_argument(
|
| 433 |
+
"--out_folder",
|
| 434 |
+
type=str,
|
| 435 |
+
help="Path to a folder to output scores, e.g. /home/out/",
|
| 436 |
+
)
|
| 437 |
+
argparser.add_argument(
|
| 438 |
+
"--file_ending", type=str, default="", help="adding_string_to_the_end"
|
| 439 |
+
)
|
| 440 |
+
argparser.add_argument(
|
| 441 |
+
"--zero_indexed",
|
| 442 |
+
type=str,
|
| 443 |
+
default=0,
|
| 444 |
+
help="1 - to start output PDB numbering with 0",
|
| 445 |
+
)
|
| 446 |
+
argparser.add_argument(
|
| 447 |
+
"--seed",
|
| 448 |
+
type=int,
|
| 449 |
+
default=0,
|
| 450 |
+
help="Set seed for torch, numpy, and python random.",
|
| 451 |
+
)
|
| 452 |
+
argparser.add_argument(
|
| 453 |
+
"--batch_size",
|
| 454 |
+
type=int,
|
| 455 |
+
default=1,
|
| 456 |
+
help="Number of sequence to generate per one pass.",
|
| 457 |
+
)
|
| 458 |
+
argparser.add_argument(
|
| 459 |
+
"--number_of_batches",
|
| 460 |
+
type=int,
|
| 461 |
+
default=1,
|
| 462 |
+
help="Number of times to design sequence using a chosen batch size.",
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
argparser.add_argument(
|
| 466 |
+
"--ligand_mpnn_use_atom_context",
|
| 467 |
+
type=int,
|
| 468 |
+
default=1,
|
| 469 |
+
help="1 - use atom context, 0 - do not use atom context.",
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
argparser.add_argument(
|
| 473 |
+
"--ligand_mpnn_use_side_chain_context",
|
| 474 |
+
type=int,
|
| 475 |
+
default=0,
|
| 476 |
+
help="Flag to use side chain atoms as ligand context for the fixed residues",
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
argparser.add_argument(
|
| 480 |
+
"--ligand_mpnn_cutoff_for_score",
|
| 481 |
+
type=float,
|
| 482 |
+
default=8.0,
|
| 483 |
+
help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
argparser.add_argument(
|
| 487 |
+
"--chains_to_design",
|
| 488 |
+
type=str,
|
| 489 |
+
default=None,
|
| 490 |
+
help="Specify which chains to redesign, all others will be kept fixed.",
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
argparser.add_argument(
|
| 494 |
+
"--parse_these_chains_only",
|
| 495 |
+
type=str,
|
| 496 |
+
default="",
|
| 497 |
+
help="Provide chains letters for parsing backbones, 'ABCF'",
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
argparser.add_argument(
|
| 501 |
+
"--transmembrane_buried",
|
| 502 |
+
type=str,
|
| 503 |
+
default="",
|
| 504 |
+
help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
| 505 |
+
)
|
| 506 |
+
argparser.add_argument(
|
| 507 |
+
"--transmembrane_interface",
|
| 508 |
+
type=str,
|
| 509 |
+
default="",
|
| 510 |
+
help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
argparser.add_argument(
|
| 514 |
+
"--global_transmembrane_label",
|
| 515 |
+
type=int,
|
| 516 |
+
default=0,
|
| 517 |
+
help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
argparser.add_argument(
|
| 521 |
+
"--parse_atoms_with_zero_occupancy",
|
| 522 |
+
type=int,
|
| 523 |
+
default=0,
|
| 524 |
+
help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
argparser.add_argument(
|
| 528 |
+
"--use_sequence",
|
| 529 |
+
type=int,
|
| 530 |
+
default=1,
|
| 531 |
+
help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only",
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
argparser.add_argument(
|
| 535 |
+
"--autoregressive_score",
|
| 536 |
+
type=int,
|
| 537 |
+
default=0,
|
| 538 |
+
help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False",
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
argparser.add_argument(
|
| 542 |
+
"--single_aa_score",
|
| 543 |
+
type=int,
|
| 544 |
+
default=1,
|
| 545 |
+
help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False",
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
args = argparser.parse_args()
|
| 549 |
+
main(args)
|
space_utils/download_weights.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
|
| 3 |
+
def download_ligandmpnn_weights():
    """Download the LigandMPNN model weights into ./model_params/.

    Fetches ``ligandmpnn_v_32_030_25.pt`` from the IPD public file server
    with ``wget`` (assumed to be on PATH — this runs inside the Space's
    Linux container).

    Returns:
        int: 0 on success (kept for backward compatibility with callers
        that check the return value).

    Raises:
        subprocess.CalledProcessError: if the wget command exits non-zero
        (``check=True``).
    """
    url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt"
    dest = "./model_params/ligandmpnn_v_32_030_25.pt"
    # Pass an argument list with shell=False (the default): avoids shell
    # string parsing/quoting issues and command-injection risk, unlike the
    # previous f-string + shell=True form.
    subprocess.run(["wget", url, "-O", dest], check=True)
    return 0
|