gabboud committed on
Commit
d95502a
·
1 Parent(s): 361e13e

initial commit from source repo

Browse files
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from space_utils.download_weights import download_ligandmpnn_weights

# Fetch the LigandMPNN weights up front so inference below can run.
download_ligandmpnn_weights()

with gr.Blocks(title="RFD3 Test") as demo:
    out_dir = "./output/test"
    # Run inference with an argument list and shell=False (implicit): avoids
    # shell injection / word-splitting; check=True raises on a non-zero exit.
    subprocess.run(
        [
            "python",
            "run.py",
            "--pdb_path",
            "./inputs/1BC8.pdb",
            "--out_folder",
            out_dir,
        ],
        check=True,
        text=True,
    )

    # Capture the produced file listing (same output as `ls <out_dir>`).
    res = subprocess.run(
        ["ls", out_dir], check=True, text=True, capture_output=True
    )

    gr.Markdown("### Command Output")
    gr.Textbox(value=res.stdout, lines=20)


if __name__ == "__main__":
    demo.launch()
data_utils.py ADDED
@@ -0,0 +1,988 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import print_function

import numpy as np
import torch
import torch.utils
from prody import *

# Silence ProDy's console logging for library use.
confProDy(verbosity="none")

# One-letter -> three-letter amino-acid codes; "X" maps to "UNK" (unknown).
restype_1to3 = {
    "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
    "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
    "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
    "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    "X": "UNK",
}
# One-letter amino acid -> integer class id (alphabetical; "X" = unknown = 20).
restype_str_to_int = {
    "A": 0, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6,
    "I": 7, "K": 8, "L": 9, "M": 10, "N": 11, "P": 12, "Q": 13,
    "R": 14, "S": 15, "T": 16, "V": 17, "W": 18, "Y": 19, "X": 20,
}
# Inverse of `restype_str_to_int`.
restype_int_to_str = {
    0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H",
    7: "I", 8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q",
    14: "R", 15: "S", 16: "T", 17: "V", 18: "W", 19: "Y", 20: "X",
}
# The 21-letter model alphabet, in integer-encoding order.
alphabet = list(restype_str_to_int)

# Chemical element symbols in atomic-number order (index 0 -> Z=1).
# NOTE(review): "Mb" (position of Z=42) looks like a typo for "Mo"
# (molybdenum), and "Uut"/"Uup"/"Uus"/"Uuo" are pre-2016 placeholder names —
# kept as-is because trained models may depend on this exact encoding.
element_list = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo",
]
element_list = [item.upper() for item in element_list]
# element_dict = dict(zip(element_list, range(1,len(element_list))))
# Atomic number -> element symbol.
# NOTE(review): range(1, len) has len-1 entries, so zip drops the final
# element ("UUO") from this mapping — confirm this off-by-one is intended.
element_dict_rev = dict(zip(range(1, len(element_list)), element_list))
204
+
205
+
206
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
    """
    Compute masked sequence recovery (fraction of correctly predicted residues).

    S : true sequence shape=[batch, length]
    S_pred : predicted sequence shape=[batch, length]
    mask : mask to compute average over the region shape=[batch, length]

    Returns the per-batch averaged recovery, shape=[batch].
    """
    correct = torch.eq(S, S_pred)
    masked_hits = torch.sum(correct * mask, dim=-1)
    coverage = torch.sum(mask, dim=-1)
    return masked_hits / coverage
217
+
218
+
219
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
    """
    Compute masked categorical cross-entropy against the true sequence.

    S : true sequence shape=[batch, length]
    log_probs : predicted log-probabilities shape=[batch, length, 21]
    mask : mask to compute average over the region shape=[batch, length]

    Returns:
        average_loss : averaged categorical cross entropy (CCE), shape=[batch]
        loss_per_residue : per-position CCE, shape=[batch, length]
    """
    one_hot = torch.nn.functional.one_hot(S, 21)
    # Pick out the log-probability of the true class at each position.
    loss_per_residue = -torch.sum(one_hot * log_probs, dim=-1)
    denom = torch.sum(mask, dim=-1) + 1e-8  # guard against empty masks
    average_loss = torch.sum(loss_per_residue * mask, dim=-1) / denom
    return average_loss, loss_per_residue
234
+
235
+
236
def write_full_PDB(
    save_path: str,
    X: np.ndarray,
    X_m: np.ndarray,
    b_factors: np.ndarray,
    R_idx: np.ndarray,
    chain_letters: np.ndarray,
    S: np.ndarray,
    other_atoms=None,
    icodes=None,
    force_hetatm=False,
):
    """
    Write a designed structure (plus optional hetero atoms) to a PDB file.

    save_path : path where the PDB will be written to
    X : protein atom xyz coordinates shape=[length, 14, 3]
    X_m : protein atom mask shape=[length, 14]
    b_factors: shape=[length, 14]
    R_idx: protein residue indices shape=[length]
    chain_letters: protein chain letters shape=[length]
    S : integer-encoded protein amino acid sequence shape=[length]
    other_atoms: other atoms parsed by prody (written alongside the protein)
    icodes: a list of insertion codes for the PDB; e.g. antibody loops
    force_hetatm: if True, copy the hetatm flags so `other_atoms` stay HETATM
    """

    # Local copies of the residue code tables (shadow the module-level ones).
    restype_1to3 = {
        "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
        "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
        "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
        "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
        "X": "UNK",
    }
    restype_INTtoSTR = {
        0: "A", 1: "C", 2: "D", 3: "E", 4: "F", 5: "G", 6: "H",
        7: "I", 8: "K", 9: "L", 10: "M", 11: "N", 12: "P", 13: "Q",
        14: "R", 15: "S", 16: "T", 17: "V", 18: "W", 19: "Y", 20: "X",
    }
    # Residue name -> fixed 14-slot atom-name layout ("" marks unused slots).
    restype_name_to_atom14_names = {
        "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
        "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
        "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
        "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
        "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
        "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
        "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
        "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
        "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
        "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
        "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
        "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
        "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
        "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
        "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
        "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
        "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
        "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2"],
        "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
        "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
        "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
    }

    # Integer sequence -> three-letter residue names.
    S_str = [restype_1to3[AA] for AA in [restype_INTtoSTR[AA] for AA in S]]

    # Flatten the per-residue 14-slot arrays into per-atom records,
    # keeping only atoms present in the mask.
    X_list = []
    b_factor_list = []
    atom_name_list = []
    element_name_list = []
    residue_name_list = []
    residue_number_list = []
    chain_id_list = []
    icodes_list = []
    for i, AA in enumerate(S_str):
        sel = X_m[i].astype(np.int32) == 1  # atoms actually present
        total = np.sum(sel)
        tmp = np.array(restype_name_to_atom14_names[AA])[sel]
        X_list.append(X[i][sel])
        b_factor_list.append(b_factors[i][sel])
        atom_name_list.append(tmp)
        # Element symbol taken as the first character of the atom name
        # (valid for the C/N/O/S atoms in the atom14 layout).
        element_name_list += [AA[:1] for AA in list(tmp)]
        residue_name_list += total * [AA]
        residue_number_list += total * [R_idx[i]]
        chain_id_list += total * [chain_letters[i]]
        icodes_list += total * [icodes[i]]

    X_stack = np.concatenate(X_list, 0)
    b_factor_stack = np.concatenate(b_factor_list, 0)
    atom_name_stack = np.concatenate(atom_name_list, 0)

    # NOTE(review): the bare name `prody` works only if prody's star import
    # re-exports the module itself — confirm, or add an explicit `import prody`.
    protein = prody.AtomGroup()
    protein.setCoords(X_stack)
    protein.setBetas(b_factor_stack)
    protein.setNames(atom_name_stack)
    protein.setResnames(residue_name_list)
    protein.setElements(element_name_list)
    protein.setOccupancies(np.ones([X_stack.shape[0]]))
    protein.setResnums(residue_number_list)
    protein.setChids(chain_id_list)
    protein.setIcodes(icodes_list)

    if other_atoms:
        # Rebuild the hetero atoms as a fresh AtomGroup so the combined
        # structure can be written in one call.
        other_atoms_g = prody.AtomGroup()
        other_atoms_g.setCoords(other_atoms.getCoords())
        other_atoms_g.setNames(other_atoms.getNames())
        other_atoms_g.setResnames(other_atoms.getResnames())
        other_atoms_g.setElements(other_atoms.getElements())
        other_atoms_g.setOccupancies(other_atoms.getOccupancies())
        other_atoms_g.setResnums(other_atoms.getResnums())
        other_atoms_g.setChids(other_atoms.getChids())
        if force_hetatm:
            other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
        writePDB(save_path, protein + other_atoms_g)
    else:
        writePDB(save_path, protein)
487
+
488
+
489
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
    """
    Extract per-residue coordinates for one atom type, aligned to CA_dict order.

    protein_atoms: prody atom group
    CA_dict: mapping between chain_residue_idx_icodes and integers
    atom_name: atom to be parsed; e.g. CA

    Returns:
        atom_coords_: [len(CA_dict), 3] float32 coordinates (zeros when absent)
        atom_coords_m: [len(CA_dict)] int32 mask (1 where the atom was found)
    """
    atom_atoms = protein_atoms.select(f"name {atom_name}")

    atom_coords_ = np.zeros([len(CA_dict), 3], np.float32)
    atom_coords_m = np.zeros([len(CA_dict)], np.int32)
    # prody `select` returns None when nothing matches (was `!= None`, and
    # the original split this into two identical checks).
    if atom_atoms is not None:
        atom_coords = atom_atoms.getCoords()
        atom_resnums = atom_atoms.getResnums()
        atom_chain_ids = atom_atoms.getChids()
        atom_icodes = atom_atoms.getIcodes()
        for i in range(len(atom_resnums)):
            code = atom_chain_ids[i] + "_" + str(atom_resnums[i]) + "_" + atom_icodes[i]
            # Direct dict membership (the original rebuilt a key list per
            # iteration — accidental O(n^2)).
            if code in CA_dict:
                atom_coords_[CA_dict[code], :] = atom_coords[i]
                atom_coords_m[CA_dict[code]] = 1
    return atom_coords_, atom_coords_m
512
+
513
+
514
def parse_PDB(
    input_path: str,
    device: str = "cpu",
    chains=None,
    parse_all_atoms: bool = False,
    parse_atoms_with_zero_occupancy: bool = False,
):
    """
    Parse a PDB file into the tensor dictionary consumed by `featurize`.

    input_path : path for the input PDB
    device: device for the torch.Tensor
    chains: a list specifying which chains need to be parsed; e.g. ["A", "B"];
        None or [] parses every chain (default was a mutable `[]` literal)
    parse_all_atoms: if False parse only N,CA,C,O otherwise all 37 atoms
    parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed

    Returns:
        output_dict: tensors X, mask, Y, Y_t, Y_m, R_idx, chain_labels,
            chain_letters, mask_c, chain_list, S, xyz_37, xyz_37_m
        backbone, other_atoms, water_atoms: prody selections
        CA_icodes: per-residue insertion codes
    """
    chains = chains or []

    # Element symbols in atomic-number order; kept identical to the
    # module-level table so atom-type codes stay consistent.
    element_list = [
        "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
        "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
        "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
        "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
        "Nb", "Mb", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
        "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
        "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
        "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
        "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
        "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
        "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
        "Rg", "Cn", "Uut", "Fl", "Uup", "Lv", "Uus", "Uuo",
    ]
    element_list = [item.upper() for item in element_list]
    # NOTE: zip(range(1, len), list) pairs only the first len-1 symbols, so
    # the final entry has no code; unknown elements fall back to 0 below.
    element_dict = dict(zip(element_list, range(1, len(element_list))))

    restype_3to1 = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    }
    restype_STRtoINT = {
        "A": 0, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6,
        "I": 7, "K": 8, "L": 9, "M": 10, "N": 11, "P": 12, "Q": 13,
        "R": 14, "S": 15, "T": 16, "V": 17, "W": 18, "Y": 19, "X": 20,
    }

    # Atom name -> slot in the 37-atom representation.
    atom_order = {
        "N": 0, "CA": 1, "C": 2, "CB": 3, "O": 4, "CG": 5, "CG1": 6,
        "CG2": 7, "OG": 8, "OG1": 9, "SG": 10, "CD": 11, "CD1": 12,
        "CD2": 13, "ND1": 14, "ND2": 15, "OD1": 16, "OD2": 17, "SD": 18,
        "CE": 19, "CE1": 20, "CE2": 21, "CE3": 22, "NE": 23, "NE1": 24,
        "NE2": 25, "OE1": 26, "OE2": 27, "CH2": 28, "NH1": 29, "NH2": 30,
        "OH": 31, "CZ": 32, "CZ2": 33, "CZ3": 34, "NZ": 35, "OXT": 36,
    }

    if not parse_all_atoms:
        atom_types = ["N", "CA", "C", "O"]
    else:
        atom_types = [
            "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1",
            "SG", "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD",
            "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2",
            "CH2", "NH1", "NH2", "OH", "CZ", "CZ2", "CZ3", "NZ",
        ]

    atoms = parsePDB(input_path)
    if not parse_atoms_with_zero_occupancy:
        atoms = atoms.select("occupancy > 0")
    if chains:
        # Build a prody selection string like "chain A or chain B".
        str_out = ""
        for item in chains:
            str_out += " chain " + item + " or"
        atoms = atoms.select(str_out[1:-3])

    protein_atoms = atoms.select("protein")
    backbone = protein_atoms.select("backbone")
    other_atoms = atoms.select("not protein and not water")
    water_atoms = atoms.select("water")

    CA_atoms = protein_atoms.select("name CA")
    CA_resnums = CA_atoms.getResnums()
    CA_chain_ids = CA_atoms.getChids()
    CA_icodes = CA_atoms.getIcodes()

    # Map "<chain>_<resnum>_<icode>" -> residue row index.
    CA_dict = {}
    for i in range(len(CA_resnums)):
        code = CA_chain_ids[i] + "_" + str(CA_resnums[i]) + "_" + CA_icodes[i]
        CA_dict[code] = i

    xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
    xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
    for atom_name in atom_types:
        xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
        xyz_37[:, atom_order[atom_name], :] = xyz
        xyz_37_m[:, atom_order[atom_name]] = xyz_m

    N = xyz_37[:, atom_order["N"], :]
    CA = xyz_37[:, atom_order["CA"], :]
    C = xyz_37[:, atom_order["C"], :]
    O = xyz_37[:, atom_order["O"], :]

    N_m = xyz_37_m[:, atom_order["N"]]
    CA_m = xyz_37_m[:, atom_order["CA"]]
    C_m = xyz_37_m[:, atom_order["C"]]
    O_m = xyz_37_m[:, atom_order["O"]]

    mask = N_m * CA_m * C_m * O_m  # must all 4 atoms exist

    # Virtual CB placed from ideal backbone geometry.
    b = CA - N
    c = C - CA
    a = np.cross(b, c, axis=-1)
    CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA

    chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
    R_idx = np.array(CA_resnums, dtype=np.int32)
    S = CA_atoms.getResnames()
    # Non-standard residues become "X" (was `in list(restype_3to1)` — O(n)).
    S = [restype_3to1[AA] if AA in restype_3to1 else "X" for AA in list(S)]
    S = np.array([restype_STRtoINT[AA] for AA in list(S)], np.int32)
    X = np.concatenate([N[:, None], CA[:, None], C[:, None], O[:, None]], 1)

    try:
        Y = np.array(other_atoms.getCoords(), dtype=np.float32)
        Y_t = list(other_atoms.getElements())
        # Membership tested against the dict (the original tested the list,
        # which would KeyError for the one symbol zip dropped from the dict).
        Y_t = np.array(
            [
                element_dict[y_t.upper()] if y_t.upper() in element_dict else 0
                for y_t in Y_t
            ],
            dtype=np.int32,
        )
        Y_m = (Y_t != 1) * (Y_t != 0)  # drop hydrogens and unknown elements

        Y = Y[Y_m, :]
        Y_t = Y_t[Y_m]
        Y_m = Y_m[Y_m]
    except AttributeError:
        # `other_atoms` is None when the structure has no ligand/hetero atoms
        # (was a bare `except:`, which also hid unrelated errors).
        Y = np.zeros([1, 3], np.float32)
        Y_t = np.zeros([1], np.int32)
        Y_m = np.zeros([1], np.int32)

    output_dict = {}
    output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
    output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
    output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
    output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
    output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)

    output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
    output_dict["chain_labels"] = torch.tensor(
        chain_labels, device=device, dtype=torch.int32
    )

    output_dict["chain_letters"] = CA_chain_ids

    # One boolean mask per chain, in sorted chain-letter order.
    mask_c = []
    chain_list = sorted(set(output_dict["chain_letters"]))
    for chain in chain_list:
        mask_c.append(
            torch.tensor(
                [chain == item for item in output_dict["chain_letters"]],
                device=device,
                dtype=bool,
            )
        )

    output_dict["mask_c"] = mask_c
    output_dict["chain_list"] = chain_list

    output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)

    output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
    output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)

    return output_dict, backbone, other_atoms, CA_icodes, water_atoms
888
+
889
+
890
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
    """
    For each residue CB atom, select the `number_of_ligand_atoms` closest
    context atoms.

    CB: [A, 3] residue (virtual) CB coordinates
    mask: [A] residue validity mask
    Y, Y_t, Y_m: [B, 3], [B], [B] context-atom coordinates, types, mask
    number_of_ligand_atoms: fixed output width per residue

    Returns per-residue (Y, Y_t, Y_m) of the fixed width (zero-padded when
    fewer context atoms exist) and the distance to the closest context atom.
    """
    device = CB.device
    num_res = CB.shape[0]

    # Pairwise squared distances; invalid residue/atom pairs pushed to 1000.
    pair_mask = mask[:, None] * Y_m[None, :]
    sq_dist = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1)
    sq_dist = sq_dist * pair_mask + (1 - pair_mask) * 1000.0

    nn_idx = torch.argsort(sq_dist, -1)[:, :number_of_ligand_atoms]
    D_AB_closest = torch.sqrt(torch.gather(sq_dist, 1, nn_idx)[:, 0])

    gathered_Y = torch.gather(
        Y[None].repeat(num_res, 1, 1), 1, nn_idx[:, :, None].repeat(1, 1, 3)
    )
    gathered_Y_t = torch.gather(Y_t[None].repeat(num_res, 1), 1, nn_idx)
    gathered_Y_m = torch.gather(Y_m[None].repeat(num_res, 1), 1, nn_idx)

    # Zero-pad out to the requested width when fewer context atoms exist.
    Y_out = torch.zeros(
        [num_res, number_of_ligand_atoms, 3], dtype=torch.float32, device=device
    )
    Y_t_out = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    Y_m_out = torch.zeros(
        [num_res, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    width = gathered_Y.shape[1]
    Y_out[:, :width] = gathered_Y
    Y_t_out[:, :width] = gathered_Y_t
    Y_m_out[:, :width] = gathered_Y_m

    return Y_out, Y_t_out, Y_m_out, D_AB_closest
924
+
925
+
926
def featurize(
    input_dict,
    cutoff_for_score=8.0,
    use_atom_context=True,
    number_of_ligand_atoms=16,
    model_type="protein_mpnn",
):
    """
    Turn a parsed-structure dict into batched model features.

    Adds a leading batch dimension of 1 to every tensor, renumbers residue
    indices so repeated numbers (insertion codes) become unique, and builds
    model-type-specific extras (ligand context for "ligand_mpnn", membrane
    labels for the membrane variants).
    """
    feats = {}

    if model_type == "ligand_mpnn":
        mask = input_dict["mask"]
        coords = input_dict["X"]
        N = coords[:, 0, :]
        CA = coords[:, 1, :]
        C = coords[:, 2, :]
        # Virtual CB from ideal backbone geometry.
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, axis=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
            CB,
            mask,
            input_dict["Y"],
            input_dict["Y_t"],
            input_dict["Y_m"],
            number_of_ligand_atoms,
        )
        # Residues whose closest context atom is within the score cutoff.
        feats["mask_XY"] = ((D_XY < cutoff_for_score) * mask * Y_m[:, 0]).unsqueeze(0)
        if "side_chain_mask" in input_dict:
            feats["side_chain_mask"] = input_dict["side_chain_mask"].unsqueeze(0)
        feats["Y"] = Y.unsqueeze(0)
        feats["Y_t"] = Y_t.unsqueeze(0)
        # Zeroing Y_m hides all atom context from the model when disabled.
        feats["Y_m"] = Y_m.unsqueeze(0) if use_atom_context else 0.0 * Y_m.unsqueeze(0)
    elif model_type in ("per_residue_label_membrane_mpnn", "global_label_membrane_mpnn"):
        feats["membrane_per_residue_labels"] = input_dict[
            "membrane_per_residue_labels"
        ].unsqueeze(0)

    # Renumber residues: each repeat of the previous index bumps the offset
    # (the offset is cumulative and never reset).
    renumbered = []
    bump = 0
    previous = -100000
    for r in list(input_dict["R_idx"]):
        if previous == r:
            bump += 1
        renumbered.append(r + bump)
        previous = r
    feats["R_idx"] = torch.tensor(renumbered, device=r.device).unsqueeze(0)
    feats["R_idx_original"] = input_dict["R_idx"].unsqueeze(0)
    feats["chain_labels"] = input_dict["chain_labels"].unsqueeze(0)
    feats["S"] = input_dict["S"].unsqueeze(0)
    feats["chain_mask"] = input_dict["chain_mask"].unsqueeze(0)
    feats["mask"] = input_dict["mask"].unsqueeze(0)

    feats["X"] = input_dict["X"].unsqueeze(0)

    if "xyz_37" in input_dict:
        feats["xyz_37"] = input_dict["xyz_37"].unsqueeze(0)
        feats["xyz_37_m"] = input_dict["xyz_37_m"].unsqueeze(0)

    return feats
inputs/1BC8.pdb ADDED
The diff for this file is too large to render. See raw diff
 
model_utils.py ADDED
@@ -0,0 +1,1772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+
3
+ import itertools
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+
10
+ class ProteinMPNN(torch.nn.Module):
11
    def __init__(
        self,
        num_letters=21,
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        vocab=21,
        k_neighbors=48,
        augment_eps=0.0,
        dropout=0.0,
        device=None,
        atom_context_num=0,
        model_type="protein_mpnn",
        ligand_mpnn_use_side_chain_context=False,
    ):
        """
        Build the message-passing network for the chosen model variant.

        num_letters: output alphabet size (20 amino acids + unknown)
        node_features / edge_features / hidden_dim: feature widths
        num_encoder_layers / num_decoder_layers: MPNN depth
        vocab: sequence-embedding vocabulary size
        k_neighbors: number of nearest residue neighbors in the graph
        augment_eps: scale of noise used by the feature modules
        dropout: dropout rate used throughout
        device: forwarded to the ligand feature module only
        atom_context_num: number of ligand context atoms (ligand_mpnn only)
        model_type: "protein_mpnn", "ligand_mpnn", "soluble_mpnn",
            "per_residue_label_membrane_mpnn", or "global_label_membrane_mpnn"
        ligand_mpnn_use_side_chain_context: include side-chain atoms as context
        """
        super(ProteinMPNN, self).__init__()

        self.model_type = model_type
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        if self.model_type == "ligand_mpnn":
            # Ligand-aware featurization plus extra context-encoder layers
            # that mix ligand-atom information into residue representations.
            self.features = ProteinFeaturesLigand(
                node_features,
                edge_features,
                top_k=k_neighbors,
                augment_eps=augment_eps,
                device=device,
                atom_context_num=atom_context_num,
                use_side_chains=ligand_mpnn_use_side_chain_context,
            )
            self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True)
            self.W_c = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)

            self.W_nodes_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.W_edges_y = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)

            self.V_C = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
            self.V_C_norm = torch.nn.LayerNorm(hidden_dim)

            self.context_encoder_layers = torch.nn.ModuleList(
                [
                    DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                    for _ in range(2)
                ]
            )

            self.y_context_encoder_layers = torch.nn.ModuleList(
                [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)]
            )
        elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn":
            # Backbone-only featurization; soluble_mpnn shares the same graph.
            self.features = ProteinFeatures(
                node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps
            )
        elif (
            self.model_type == "per_residue_label_membrane_mpnn"
            or self.model_type == "global_label_membrane_mpnn"
        ):
            # Membrane variants add a 3-class label channel to the features.
            self.W_v = torch.nn.Linear(node_features, hidden_dim, bias=True)
            self.features = ProteinFeaturesMembrane(
                node_features,
                edge_features,
                top_k=k_neighbors,
                augment_eps=augment_eps,
                num_classes=3,
            )
        else:
            # Unknown model type: report and abort rather than continue
            # with an unconfigured feature module.
            print("Choose --model_type flag from currently available models")
            sys.exit()

        self.W_e = torch.nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_s = torch.nn.Embedding(vocab, hidden_dim)

        self.dropout = torch.nn.Dropout(dropout)

        # Encoder layers
        self.encoder_layers = torch.nn.ModuleList(
            [
                EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                for _ in range(num_encoder_layers)
            ]
        )

        # Decoder layers
        self.decoder_layers = torch.nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        self.W_out = torch.nn.Linear(hidden_dim, num_letters, bias=True)

        # Xavier-initialize every weight matrix (vectors/biases keep defaults).
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)
110
+
111
    def encode(self, feature_dict):
        """Encode the input structure into graph node/edge embeddings.

        Dispatches on ``self.model_type``: the ligand variant additionally runs
        a ligand-atom context encoder and mixes its output back into the
        residue nodes; membrane variants project real node features.

        Args:
            feature_dict: dict of input tensors (see inline comments). Only
                "S" and "mask" are read directly here; everything else is
                consumed by ``self.features()``.

        Returns:
            h_V: [B,L,hidden] encoded node features.
            h_E: [B,L,K,hidden] encoded edge features.
            E_idx: [B,L,K] k-NN neighbor indices of the residue graph.
        """
        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
        # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O
        S_true = feature_dict["S"]  # [B,L] - integer protein sequence encoded using "restype_STRtoINT"
        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
        mask = feature_dict["mask"]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters

        B, L = S_true.shape
        device = S_true.device

        if self.model_type == "ligand_mpnn":
            V, E, E_idx, Y_nodes, Y_edges, Y_m = self.features(feature_dict)
            # nodes start at zero; all structural signal enters through the edges
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
            h_E = self.W_e(E)
            h_E_context = self.W_v(V)

            # neighbor validity mask: both endpoints of an edge must be valid
            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

            # ligand-context branch: encode the ligand atom graph, then
            # inject its summary into a copy of the residue nodes
            h_V_C = self.W_c(h_V)
            Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :]
            Y_nodes = self.W_nodes_y(Y_nodes)
            Y_edges = self.W_edges_y(Y_edges)
            for i in range(len(self.context_encoder_layers)):
                Y_nodes = self.y_context_encoder_layers[i](
                    Y_nodes, Y_edges, Y_m, Y_m_edges
                )
                h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1)
                h_V_C = self.context_encoder_layers[i](
                    h_V_C, h_E_context_cat, mask, Y_m
                )

            # residual update of residue nodes with the ligand context
            h_V_C = self.V_C(h_V_C)
            h_V = h_V + self.V_C_norm(self.dropout(h_V_C))
        elif self.model_type == "protein_mpnn" or self.model_type == "soluble_mpnn":
            E, E_idx = self.features(feature_dict)
            h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=device)
            h_E = self.W_e(E)

            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)
        elif (
            self.model_type == "per_residue_label_membrane_mpnn"
            or self.model_type == "global_label_membrane_mpnn"
        ):
            # membrane variants supply real node features (labels) -> project them
            V, E, E_idx = self.features(feature_dict)
            h_V = self.W_v(V)
            h_E = self.W_e(E)

            mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
            mask_attend = mask.unsqueeze(-1) * mask_attend
            for layer in self.encoder_layers:
                h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        return h_V, h_E, E_idx
179
+
180
    def sample(self, feature_dict):
        """Autoregressively sample sequences for the encoded structure.

        Residues are decoded in a randomized order (fixed positions first);
        with symmetry constraints, tied positions are decoded as a single
        step sharing one sampled amino acid. Fixed positions
        (chain_mask == 0) always keep their input identity.

        Returns:
            dict with "S" [B_decoder,L] sampled sequences,
            "sampling_probs" [B_decoder,L,20], "log_probs" [B_decoder,L,21],
            and "decoding_order" [B_decoder,L].
        """
        # xyz_37 = feature_dict["xyz_37"] #[B,L,37,3] - xyz coordinates for all atoms if needed
        # xyz_37_m = feature_dict["xyz_37_m"] #[B,L,37] - mask for all coords
        # Y = feature_dict["Y"] #[B,L,num_context_atoms,3] - for ligandMPNN coords
        # Y_t = feature_dict["Y_t"] #[B,L,num_context_atoms] - element type
        # Y_m = feature_dict["Y_m"] #[B,L,num_context_atoms] - mask
        # X = feature_dict["X"] #[B,L,4,3] - backbone xyz coordinates for N,CA,C,O
        B_decoder = feature_dict["batch_size"]
        S_true = feature_dict["S"]  # [B,L] - integer protein sequence encoded using "restype_STRtoINT"
        # R_idx = feature_dict["R_idx"] #[B,L] - primary sequence residue index
        mask = feature_dict["mask"]  # [B,L] - mask for missing regions - should be removed! all ones most of the time
        chain_mask = feature_dict["chain_mask"]  # [B,L] - 0.0 fixed; 1.0 will be designed
        bias = feature_dict["bias"]  # [B,L,21] - amino acid bias per position
        # chain_labels = feature_dict["chain_labels"] #[B,L] - integer labels for chain letters
        randn = feature_dict["randn"]  # [B,L] - random numbers for decoding order; only the first entry is used since decoding within a batch needs to match for symmetry
        temperature = feature_dict["temperature"]  # float - sampling temperature; prob = softmax(logits/temperature)
        symmetry_list_of_lists = feature_dict["symmetry_residues"]  # e.g. [[0, 1, 14], [10,11,14,15], [20, 21]]
        symmetry_weights_list_of_lists = feature_dict["symmetry_weights"]  # e.g. [[1.0, 1.0, 1.0], [-2.0,1.1,0.2,1.1], [2.3, 1.1]]

        B, L = S_true.shape
        device = S_true.device

        h_V, h_E, E_idx = self.encode(feature_dict)

        chain_mask = mask * chain_mask  # update chain_M to include missing regions
        # fixed positions get small sort keys -> decoded first.
        # NOTE(review): the per-step gathers below assume B == 1 (upstream
        # usage); confirm before calling with a true batch dimension.
        decoding_order = torch.argsort(
            (chain_mask + 0.0001) * (torch.abs(randn))
        )
        if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
            # --- no symmetry: standard per-position autoregressive decoding ---
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            # order_mask_backward[b, q, p] == 1 iff p is decoded before q
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend          # attend to already-decoded neighbors
            mask_fw = mask_1D * (1.0 - mask_attend)  # encoder-only context otherwise

            # repeat for decoding
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E.repeat(B_decoder, 1, 1, 1)
            chain_mask = chain_mask.repeat(B_decoder, 1)
            mask = mask.repeat(B_decoder, 1)
            bias = bias.repeat(B_decoder, 1, 1)

            all_probs = torch.zeros(
                (B_decoder, L, 20), device=device, dtype=torch.float32
            )
            all_log_probs = torch.zeros(
                (B_decoder, L, 21), device=device, dtype=torch.float32
            )
            h_S = torch.zeros_like(h_V, device=device)
            # 20 is the "X"/unknown token used to initialize undecoded positions
            S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device)
            # per-decoder-layer cache of node states, updated one position at a time
            h_V_stack = [h_V] + [
                torch.zeros_like(h_V, device=device)
                for _ in range(len(self.decoder_layers))
            ]

            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
            h_EXV_encoder_fw = mask_fw * h_EXV_encoder

            for t_ in range(L):
                t = decoding_order[:, t_]  # [B]
                chain_mask_t = torch.gather(chain_mask, 1, t[:, None])[:, 0]  # [B]
                mask_t = torch.gather(mask, 1, t[:, None])[:, 0]  # [B]
                bias_t = torch.gather(bias, 1, t[:, None, None].repeat(1, 1, 21))[
                    :, 0, :
                ]  # [B,21]

                E_idx_t = torch.gather(
                    E_idx, 1, t[:, None, None].repeat(1, 1, E_idx.shape[-1])
                )
                h_E_t = torch.gather(
                    h_E,
                    1,
                    t[:, None, None, None].repeat(1, 1, h_E.shape[-2], h_E.shape[-1]),
                )
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = torch.gather(
                    h_EXV_encoder_fw,
                    1,
                    t[:, None, None, None].repeat(
                        1, 1, h_EXV_encoder_fw.shape[-2], h_EXV_encoder_fw.shape[-1]
                    ),
                )

                mask_bw_t = torch.gather(
                    mask_bw,
                    1,
                    t[:, None, None, None].repeat(
                        1, 1, mask_bw.shape[-2], mask_bw.shape[-1]
                    ),
                )

                # run every decoder layer for this single position, caching states
                for l, layer in enumerate(self.decoder_layers):
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = torch.gather(
                        h_V_stack[l],
                        1,
                        t[:, None, None].repeat(1, 1, h_V_stack[l].shape[-1]),
                    )
                    h_ESV_t = mask_bw_t * h_ESV_decoder_t + h_EXV_encoder_t
                    h_V_stack[l + 1].scatter_(
                        1,
                        t[:, None, None].repeat(1, 1, h_V.shape[-1]),
                        layer(h_V_t, h_ESV_t, mask_V=mask_t),
                    )

                h_V_t = torch.gather(
                    h_V_stack[-1],
                    1,
                    t[:, None, None].repeat(1, 1, h_V_stack[-1].shape[-1]),
                )[:, 0]
                logits = self.W_out(h_V_t)  # [B,21]
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # [B,21]

                probs = torch.nn.functional.softmax(
                    (logits + bias_t) / temperature, dim=-1
                )  # [B,21]
                probs_sample = probs[:, :20] / torch.sum(
                    probs[:, :20], dim=-1, keepdim=True
                )  # hard omit X #[B,20]
                S_t = torch.multinomial(probs_sample, 1)[:, 0]  # [B]

                all_probs.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, 20),
                    (chain_mask_t[:, None, None] * probs_sample[:, None, :]).float(),
                )
                all_log_probs.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, 21),
                    (chain_mask_t[:, None, None] * log_probs[:, None, :]).float(),
                )
                # keep the native residue at fixed positions
                S_true_t = torch.gather(S_true, 1, t[:, None])[:, 0]
                S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long()
                h_S.scatter_(
                    1,
                    t[:, None, None].repeat(1, 1, h_S.shape[-1]),
                    self.W_s(S_t)[:, None, :],
                )
                S.scatter_(1, t[:, None], S_t[:, None])

            output_dict = {
                "S": S,
                "sampling_probs": all_probs,
                "log_probs": all_log_probs,
                "decoding_order": decoding_order,
            }
        else:
            # --- symmetric design: tied positions share one sampled residue ---
            # weights for symmetric design
            symmetry_weights = torch.ones([L], device=device, dtype=torch.float32)
            for i1, item_list in enumerate(symmetry_list_of_lists):
                for i2, item in enumerate(item_list):
                    symmetry_weights[item] = symmetry_weights_list_of_lists[i1][i2]

            # group the decoding order so all members of a symmetry class decode together
            new_decoding_order = []
            for t_dec in list(decoding_order[0,].cpu().data.numpy()):
                if t_dec not in list(itertools.chain(*new_decoding_order)):
                    list_a = [item for item in symmetry_list_of_lists if t_dec in item]
                    if list_a:
                        new_decoding_order.append(list_a[0])
                    else:
                        new_decoding_order.append([t_dec])

            decoding_order = torch.tensor(
                list(itertools.chain(*new_decoding_order)), device=device
            )[None,].repeat(B, 1)

            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)

            # repeat for decoding
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E.repeat(B_decoder, 1, 1, 1)
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
            mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
            chain_mask = chain_mask.repeat(B_decoder, 1)
            mask = mask.repeat(B_decoder, 1)
            bias = bias.repeat(B_decoder, 1, 1)

            all_probs = torch.zeros(
                (B_decoder, L, 20), device=device, dtype=torch.float32
            )
            all_log_probs = torch.zeros(
                (B_decoder, L, 21), device=device, dtype=torch.float32
            )
            h_S = torch.zeros_like(h_V, device=device)
            S = 20 * torch.ones((B_decoder, L), dtype=torch.int64, device=device)
            h_V_stack = [h_V] + [
                torch.zeros_like(h_V, device=device)
                for _ in range(len(self.decoder_layers))
            ]

            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
            h_EXV_encoder_fw = mask_fw * h_EXV_encoder

            for t_list in new_decoding_order:
                total_logits = 0.0
                for t in t_list:
                    chain_mask_t = chain_mask[:, t]  # [B]
                    mask_t = mask[:, t]  # [B]
                    bias_t = bias[:, t]  # [B, 21]

                    E_idx_t = E_idx[:, t : t + 1]
                    h_E_t = h_E[:, t : t + 1]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t : t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(
                            h_V_stack[l], h_ES_t, E_idx_t
                        )
                        h_V_t = h_V_stack[l][:, t : t + 1]
                        h_ESV_t = (
                            mask_bw[:, t : t + 1] * h_ESV_decoder_t + h_EXV_encoder_t
                        )
                        h_V_stack[l + 1][:, t : t + 1, :] = layer(
                            h_V_t, h_ESV_t, mask_V=mask_t[:, None]
                        )

                    h_V_t = h_V_stack[-1][:, t]
                    logits = self.W_out(h_V_t)  # [B,21]
                    log_probs = torch.nn.functional.log_softmax(
                        logits, dim=-1
                    )  # [B,21]
                    all_log_probs[:, t] = (
                        chain_mask_t[:, None] * log_probs
                    ).float()  # [B,21]
                    # accumulate weighted logits across the tied set
                    total_logits += symmetry_weights[t] * logits

                # NOTE(review): bias_t here is the bias of the LAST member of the
                # tied set (loop variable leak) -- confirm against upstream intent
                probs = torch.nn.functional.softmax(
                    (total_logits + bias_t) / temperature, dim=-1
                )  # [B,21]
                probs_sample = probs[:, :20] / torch.sum(
                    probs[:, :20], dim=-1, keepdim=True
                )  # hard omit X #[B,20]
                S_t = torch.multinomial(probs_sample, 1)[:, 0]  # [B]
                for t in t_list:
                    chain_mask_t = chain_mask[:, t]  # [B]
                    all_probs[:, t] = (
                        chain_mask_t[:, None] * probs_sample
                    ).float()  # [B,20]
                    S_true_t = S_true[:, t]  # [B]
                    # NOTE(review): S_t is re-assigned every iteration, so later
                    # tied positions see earlier positions' fixed identities mixed
                    # in -- verify this matches the intended tying semantics
                    S_t = (S_t * chain_mask_t + S_true_t * (1.0 - chain_mask_t)).long()
                    h_S[:, t] = self.W_s(S_t)
                    S[:, t] = S_t

            output_dict = {
                "S": S,
                "sampling_probs": all_probs,
                "log_probs": all_log_probs,
                "decoding_order": decoding_order.repeat(B_decoder, 1),
            }
        return output_dict
470
+
471
    def single_aa_score(self, feature_dict, use_sequence: bool):
        """Score every position independently with one decoder pass per residue.

        For each position ``idx`` the decoding order is rigged so that either
        only ``idx`` is hidden (use_sequence=True: condition on the rest of the
        sequence) or only ``idx`` is revealed last with everything else hidden
        (use_sequence=False: backbone-only scoring).

        Args:
            feature_dict: input features ("S", "mask", "chain_mask", "randn",
                "batch_size", plus whatever encode() needs).
            use_sequence: False -> use backbone info only.

        Returns:
            dict with "S" [B_decoder,L], per-position "log_probs" and "logits"
            [B_decoder,L,21], and "decoding_order" [B_decoder,L,L] (one order
            per scored position).
        """
        B_decoder = feature_dict["batch_size"]
        S_true_enc = feature_dict["S"]  # [B,L] sequence to score
        mask_enc = feature_dict["mask"]  # [B,L] validity mask
        chain_mask_enc = feature_dict["chain_mask"]  # [B,L] (only its length is used here)
        randn = feature_dict["randn"]  # random keys for the decoding order
        B, L = S_true_enc.shape
        device = S_true_enc.device

        # encode once; per-position loop reuses clones of the encoder output
        h_V_enc, h_E_enc, E_idx_enc = self.encode(feature_dict)
        log_probs_out = torch.zeros([B_decoder, L, 21], device=device).float()
        logits_out = torch.zeros([B_decoder, L, 21], device=device).float()
        decoding_order_out = torch.zeros([B_decoder, L, L], device=device).float()

        for idx in range(L):
            h_V = torch.clone(h_V_enc)
            E_idx = torch.clone(E_idx_enc)
            mask = torch.clone(mask_enc)
            S_true = torch.clone(S_true_enc)
            if not use_sequence:
                # only idx gets a large key -> everything else decodes first as hidden
                order_mask = torch.zeros(chain_mask_enc.shape[1], device=device).float()
                order_mask[idx] = 1.
            else:
                # idx gets a small key -> decoded first, i.e. conditioned on all others
                order_mask = torch.ones(chain_mask_enc.shape[1], device=device).float()
                order_mask[idx] = 0.
            decoding_order = torch.argsort(
                (order_mask + 0.0001) * (torch.abs(randn))
            )  # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)
            S_true = S_true.repeat(B_decoder, 1)
            h_V = h_V.repeat(B_decoder, 1, 1)
            h_E = h_E_enc.repeat(B_decoder, 1, 1, 1)
            mask = mask.repeat(B_decoder, 1)

            h_S = self.W_s(S_true)
            h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

            # Build encoder embeddings
            h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
            h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information, unmasked see.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # keep only the scored position's row from this pass
            log_probs_out[:,idx,:] = log_probs[:,idx,:]
            logits_out[:,idx,:] = logits[:,idx,:]
            decoding_order_out[:,idx,:] = decoding_order

        output_dict = {
            "S": S_true,
            "log_probs": log_probs_out,
            "logits": logits_out,
            "decoding_order": decoding_order_out,
        }
        return output_dict
558
+
559
+
560
    def score(self, feature_dict, use_sequence: bool):
        """Score the input sequence in a single (teacher-forced) decoder pass.

        Builds a causal attention mask from a random decoding order
        (optionally grouped by symmetry classes) and runs the decoder once
        over the whole sequence.

        Args:
            feature_dict: input features ("S", "mask", "chain_mask", "randn",
                "batch_size", "symmetry_residues", plus encode() inputs).
            use_sequence: if False, the decoder sees encoder context only
                (backbone-only scoring).

        Returns:
            dict with "S" [B_decoder,L], "log_probs"/"logits"
            [B_decoder,L,21], and the "decoding_order" used.
        """
        B_decoder = feature_dict["batch_size"]
        S_true = feature_dict["S"]  # [B,L] integer sequence to score
        mask = feature_dict["mask"]  # [B,L] validity mask
        chain_mask = feature_dict["chain_mask"]  # [B,L] designed-position mask
        randn = feature_dict["randn"]  # random keys for the decoding order
        symmetry_list_of_lists = feature_dict["symmetry_residues"]
        B, L = S_true.shape
        device = S_true.device

        h_V, h_E, E_idx = self.encode(feature_dict)

        chain_mask = mask * chain_mask  # update chain_M to include missing regions
        decoding_order = torch.argsort(
            (chain_mask + 0.0001) * (torch.abs(randn))
        )  # [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
        if len(symmetry_list_of_lists[0]) == 0 and len(symmetry_list_of_lists) == 1:
            # no symmetry: causal mask straight from the random decoding order
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)
        else:
            # symmetry: regroup the order so tied residues decode together
            new_decoding_order = []
            for t_dec in list(decoding_order[0,].cpu().data.numpy()):
                if t_dec not in list(itertools.chain(*new_decoding_order)):
                    list_a = [item for item in symmetry_list_of_lists if t_dec in item]
                    if list_a:
                        new_decoding_order.append(list_a[0])
                    else:
                        new_decoding_order.append([t_dec])

            decoding_order = torch.tensor(
                list(itertools.chain(*new_decoding_order)), device=device
            )[None,].repeat(B, 1)

            permutation_matrix_reverse = torch.nn.functional.one_hot(
                decoding_order, num_classes=L
            ).float()
            order_mask_backward = torch.einsum(
                "ij, biq, bjp->bqp",
                (1 - torch.triu(torch.ones(L, L, device=device))),
                permutation_matrix_reverse,
                permutation_matrix_reverse,
            )
            mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
            mask_1D = mask.view([B, L, 1, 1])
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1.0 - mask_attend)

            # this branch computed masks at batch B; expand them for decoding
            E_idx = E_idx.repeat(B_decoder, 1, 1)
            mask_fw = mask_fw.repeat(B_decoder, 1, 1, 1)
            mask_bw = mask_bw.repeat(B_decoder, 1, 1, 1)
            decoding_order = decoding_order.repeat(B_decoder, 1)

        S_true = S_true.repeat(B_decoder, 1)
        h_V = h_V.repeat(B_decoder, 1, 1)
        h_E = h_E.repeat(B_decoder, 1, 1, 1)
        mask = mask.repeat(B_decoder, 1)

        h_S = self.W_s(S_true)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        if not use_sequence:
            # backbone-only: decoder never sees sequence embeddings
            for layer in self.decoder_layers:
                h_V = layer(h_V, h_EXV_encoder_fw, mask)
        else:
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information, unmasked see.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

        logits = self.W_out(h_V)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        output_dict = {
            "S": S_true,
            "log_probs": log_probs,
            "logits": logits,
            "decoding_order": decoding_order,
        }
        return output_dict
667
+
668
+
669
+ class ProteinFeaturesLigand(torch.nn.Module):
670
    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=30,
        augment_eps=0.0,
        device=None,
        atom_context_num=16,
        use_side_chains=False,
    ):
        """Extract protein features (backbone graph + ligand atom context).

        Args:
            edge_features: edge embedding dimension.
            node_features: node embedding dimension.
                NOTE(review): the caller in ProteinMPNN.__init__ passes
                (node_features, edge_features) positionally, i.e. swapped
                relative to this signature; harmless only while the two
                dimensions are equal -- confirm before changing either.
            num_positional_embeddings: size of the residue-offset encoding.
            num_rbf: number of radial-basis bins per distance map.
            top_k: number of k-NN neighbors per residue.
            augment_eps: std of Gaussian noise added to coordinates (training).
            device: device for the constant lookup tables.
            atom_context_num: number of ligand context atoms kept per residue.
            use_side_chains: include fixed residues' side-chain atoms as
                additional ligand-like context.
        """
        super(ProteinFeaturesLigand, self).__init__()

        self.use_side_chains = use_side_chains

        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # 25 backbone atom-pair distance maps (N, Ca, C, O, virtual Cb), each RBF-expanded
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

        # ligand-node input: 5 backbone->atom RBFs + 64-dim element type + 4 angle feats
        self.node_project_down = torch.nn.Linear(
            5 * num_rbf + 64 + 4, node_features, bias=True
        )
        self.norm_nodes = torch.nn.LayerNorm(node_features)

        # 147 = 119 (element one-hot) + 19 (group) + 9 (period), see forward()
        self.type_linear = torch.nn.Linear(147, 64)

        self.y_nodes = torch.nn.Linear(147, node_features, bias=False)
        self.y_edges = torch.nn.Linear(num_rbf, node_features, bias=False)

        self.norm_y_edges = torch.nn.LayerNorm(node_features)
        self.norm_y_nodes = torch.nn.LayerNorm(node_features)

        self.atom_context_num = atom_context_num

        # the last 32 atoms in the 37 atom representation
        # (element numbers: C=6, N=7, O=8, S=16)
        self.side_chain_atom_types = torch.tensor(
            [
                6, 6, 6, 8, 8, 16, 6, 6, 6, 7, 7, 8, 8, 16, 6, 6,
                6, 6, 7, 7, 7, 8, 8, 6, 7, 7, 8, 6, 6, 6, 7, 8,
            ],
            device=device,
        )

        # row 0: atomic number (0..118; 0 = padding/unknown)
        # row 1: periodic-table group (1..18; lanthanides/actinides mapped to 3)
        # row 2: periodic-table period (1..7; 0 = padding)
        self.periodic_table_features = torch.tensor(
            [
                [
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                    20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
                    40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
                    60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
                    80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99,
                    100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118,
                ],
                [
                    0, 1, 18,
                    1, 2, 13, 14, 15, 16, 17, 18,
                    1, 2, 13, 14, 15, 16, 17, 18,
                    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                    1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                    1, 2,
                    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                    4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                    1, 2,
                    3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
                    4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
                ],
                [
                    0,
                    1, 1,
                    2, 2, 2, 2, 2, 2, 2, 2,
                    3, 3, 3, 3, 3, 3, 3, 3,
                    4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
                    5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
                    6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                    6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
                    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
                    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
                ],
            ],
            dtype=torch.long,
            device=device,
        )
1122
+
1123
+ def _make_angle_features(self, A, B, C, Y):
1124
+ v1 = A - B
1125
+ v2 = C - B
1126
+ e1 = torch.nn.functional.normalize(v1, dim=-1)
1127
+ e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None]
1128
+ u2 = v2 - e1 * e1_v2_dot
1129
+ e2 = torch.nn.functional.normalize(u2, dim=-1)
1130
+ e3 = torch.cross(e1, e2, dim=-1)
1131
+ R_residue = torch.cat(
1132
+ (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1
1133
+ )
1134
+
1135
+ local_vectors = torch.einsum(
1136
+ "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :]
1137
+ )
1138
+
1139
+ rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8)
1140
+ f1 = local_vectors[..., 0] / rxy
1141
+ f2 = local_vectors[..., 1] / rxy
1142
+ rxyz = torch.norm(local_vectors, dim=-1) + 1e-8
1143
+ f3 = rxy / rxyz
1144
+ f4 = local_vectors[..., 2] / rxyz
1145
+
1146
+ f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1)
1147
+ return f
1148
+
1149
+ def _dist(self, X, mask, eps=1e-6):
1150
+ mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
1151
+ dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
1152
+ D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
1153
+ D_max, _ = torch.max(D, -1, keepdim=True)
1154
+ D_adjust = D + (1.0 - mask_2D) * D_max
1155
+ D_neighbors, E_idx = torch.topk(
1156
+ D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
1157
+ )
1158
+ return D_neighbors, E_idx
1159
+
1160
+ def _rbf(self, D):
1161
+ device = D.device
1162
+ D_min, D_max, D_count = 2.0, 22.0, self.num_rbf
1163
+ D_mu = torch.linspace(D_min, D_max, D_count, device=device)
1164
+ D_mu = D_mu.view([1, 1, 1, -1])
1165
+ D_sigma = (D_max - D_min) / D_count
1166
+ D_expand = torch.unsqueeze(D, -1)
1167
+ RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
1168
+ return RBF
1169
+
1170
+ def _get_rbf(self, A, B, E_idx):
1171
+ D_A_B = torch.sqrt(
1172
+ torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
1173
+ ) # [B, L, L]
1174
+ D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[
1175
+ :, :, :, 0
1176
+ ] # [B,L,K]
1177
+ RBF_A_B = self._rbf(D_A_B_neighbors)
1178
+ return RBF_A_B
1179
+
1180
+ def forward(self, input_features):
1181
+ Y = input_features["Y"]
1182
+ Y_m = input_features["Y_m"]
1183
+ Y_t = input_features["Y_t"]
1184
+ X = input_features["X"]
1185
+ mask = input_features["mask"]
1186
+ R_idx = input_features["R_idx"]
1187
+ chain_labels = input_features["chain_labels"]
1188
+
1189
+ if self.augment_eps > 0:
1190
+ X = X + self.augment_eps * torch.randn_like(X)
1191
+ Y = Y + self.augment_eps * torch.randn_like(Y)
1192
+
1193
+ B, L, _, _ = X.shape
1194
+
1195
+ Ca = X[:, :, 1, :]
1196
+ N = X[:, :, 0, :]
1197
+ C = X[:, :, 2, :]
1198
+ O = X[:, :, 3, :]
1199
+
1200
+ b = Ca - N
1201
+ c = C - Ca
1202
+ a = torch.cross(b, c, dim=-1)
1203
+ Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca # shift from CA
1204
+
1205
+ D_neighbors, E_idx = self._dist(Ca, mask)
1206
+
1207
+ RBF_all = []
1208
+ RBF_all.append(self._rbf(D_neighbors)) # Ca-Ca
1209
+ RBF_all.append(self._get_rbf(N, N, E_idx)) # N-N
1210
+ RBF_all.append(self._get_rbf(C, C, E_idx)) # C-C
1211
+ RBF_all.append(self._get_rbf(O, O, E_idx)) # O-O
1212
+ RBF_all.append(self._get_rbf(Cb, Cb, E_idx)) # Cb-Cb
1213
+ RBF_all.append(self._get_rbf(Ca, N, E_idx)) # Ca-N
1214
+ RBF_all.append(self._get_rbf(Ca, C, E_idx)) # Ca-C
1215
+ RBF_all.append(self._get_rbf(Ca, O, E_idx)) # Ca-O
1216
+ RBF_all.append(self._get_rbf(Ca, Cb, E_idx)) # Ca-Cb
1217
+ RBF_all.append(self._get_rbf(N, C, E_idx)) # N-C
1218
+ RBF_all.append(self._get_rbf(N, O, E_idx)) # N-O
1219
+ RBF_all.append(self._get_rbf(N, Cb, E_idx)) # N-Cb
1220
+ RBF_all.append(self._get_rbf(Cb, C, E_idx)) # Cb-C
1221
+ RBF_all.append(self._get_rbf(Cb, O, E_idx)) # Cb-O
1222
+ RBF_all.append(self._get_rbf(O, C, E_idx)) # O-C
1223
+ RBF_all.append(self._get_rbf(N, Ca, E_idx)) # N-Ca
1224
+ RBF_all.append(self._get_rbf(C, Ca, E_idx)) # C-Ca
1225
+ RBF_all.append(self._get_rbf(O, Ca, E_idx)) # O-Ca
1226
+ RBF_all.append(self._get_rbf(Cb, Ca, E_idx)) # Cb-Ca
1227
+ RBF_all.append(self._get_rbf(C, N, E_idx)) # C-N
1228
+ RBF_all.append(self._get_rbf(O, N, E_idx)) # O-N
1229
+ RBF_all.append(self._get_rbf(Cb, N, E_idx)) # Cb-N
1230
+ RBF_all.append(self._get_rbf(C, Cb, E_idx)) # C-Cb
1231
+ RBF_all.append(self._get_rbf(O, Cb, E_idx)) # O-Cb
1232
+ RBF_all.append(self._get_rbf(C, O, E_idx)) # C-O
1233
+ RBF_all = torch.cat(tuple(RBF_all), dim=-1)
1234
+
1235
+ offset = R_idx[:, :, None] - R_idx[:, None, :]
1236
+ offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0] # [B, L, K]
1237
+
1238
+ d_chains = (
1239
+ (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
1240
+ ).long() # find self vs non-self interaction
1241
+ E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
1242
+ E_positional = self.embeddings(offset.long(), E_chains)
1243
+ E = torch.cat((E_positional, RBF_all), -1)
1244
+ E = self.edge_embedding(E)
1245
+ E = self.norm_edges(E)
1246
+
1247
+ if self.use_side_chains:
1248
+ xyz_37 = input_features["xyz_37"]
1249
+ xyz_37_m = input_features["xyz_37_m"]
1250
+ E_idx_sub = E_idx[:, :, :16] # [B, L, 15]
1251
+ mask_residues = input_features["chain_mask"]
1252
+ xyz_37_m = xyz_37_m * (1 - mask_residues[:, :, None])
1253
+ R_m = gather_nodes(xyz_37_m[:, :, 5:], E_idx_sub)
1254
+
1255
+ X_sidechain = xyz_37[:, :, 5:, :].view(B, L, -1)
1256
+ R = gather_nodes(X_sidechain, E_idx_sub).view(
1257
+ B, L, E_idx_sub.shape[2], -1, 3
1258
+ )
1259
+ R_t = self.side_chain_atom_types[None, None, None, :].repeat(
1260
+ B, L, E_idx_sub.shape[2], 1
1261
+ )
1262
+
1263
+ # Side chain atom context
1264
+ R = R.view(B, L, -1, 3) # coordinates
1265
+ R_m = R_m.view(B, L, -1) # mask
1266
+ R_t = R_t.view(B, L, -1) # atom types
1267
+
1268
+ # Ligand atom context
1269
+ Y = torch.cat((R, Y), 2) # [B, L, atoms, 3]
1270
+ Y_m = torch.cat((R_m, Y_m), 2) # [B, L, atoms]
1271
+ Y_t = torch.cat((R_t, Y_t), 2) # [B, L, atoms]
1272
+
1273
+ Cb_Y_distances = torch.sum((Cb[:, :, None, :] - Y) ** 2, -1)
1274
+ mask_Y = mask[:, :, None] * Y_m
1275
+ Cb_Y_distances_adjusted = Cb_Y_distances * mask_Y + (1.0 - mask_Y) * 10000.0
1276
+ _, E_idx_Y = torch.topk(
1277
+ Cb_Y_distances_adjusted, self.atom_context_num, dim=-1, largest=False
1278
+ )
1279
+
1280
+ Y = torch.gather(Y, 2, E_idx_Y[:, :, :, None].repeat(1, 1, 1, 3))
1281
+ Y_t = torch.gather(Y_t, 2, E_idx_Y)
1282
+ Y_m = torch.gather(Y_m, 2, E_idx_Y)
1283
+
1284
+ Y_t = Y_t.long()
1285
+ Y_t_g = self.periodic_table_features[1][Y_t] # group; 19 categories including 0
1286
+ Y_t_p = self.periodic_table_features[2][Y_t] # period; 8 categories including 0
1287
+
1288
+ Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19) # [B, L, M, 19]
1289
+ Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8) # [B, L, M, 8]
1290
+ Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120) # [B, L, M, 120]
1291
+
1292
+ Y_t_1hot_ = torch.cat(
1293
+ [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1
1294
+ ) # [B, L, M, 147]
1295
+ Y_t_1hot = self.type_linear(Y_t_1hot_.float())
1296
+
1297
+ D_N_Y = self._rbf(
1298
+ torch.sqrt(torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6)
1299
+ ) # [B, L, M, num_bins]
1300
+ D_Ca_Y = self._rbf(
1301
+ torch.sqrt(torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6)
1302
+ )
1303
+ D_C_Y = self._rbf(torch.sqrt(torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6))
1304
+ D_O_Y = self._rbf(torch.sqrt(torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6))
1305
+ D_Cb_Y = self._rbf(
1306
+ torch.sqrt(torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6)
1307
+ )
1308
+
1309
+ f_angles = self._make_angle_features(N, Ca, C, Y) # [B, L, M, 4]
1310
+
1311
+ D_all = torch.cat(
1312
+ (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1
1313
+ ) # [B,L,M,5*num_bins+5]
1314
+ V = self.node_project_down(D_all) # [B, L, M, node_features]
1315
+ V = self.norm_nodes(V)
1316
+
1317
+ Y_edges = self._rbf(
1318
+ torch.sqrt(
1319
+ torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
1320
+ )
1321
+ ) # [B, L, M, M, num_bins]
1322
+
1323
+ Y_edges = self.y_edges(Y_edges)
1324
+ Y_nodes = self.y_nodes(Y_t_1hot_.float())
1325
+
1326
+ Y_edges = self.norm_y_edges(Y_edges)
1327
+ Y_nodes = self.norm_y_nodes(Y_nodes)
1328
+
1329
+ return V, E, E_idx, Y_nodes, Y_edges, Y_m
1330
+
1331
+
1332
class ProteinFeatures(torch.nn.Module):
    """Featurize a protein backbone into k-NN graph edge features.

    Edges carry relative-position encodings plus 25 RBF-encoded inter-atom
    distance tracks (all ordered pairs over N, Ca, C, O and a virtual Cb).
    """

    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=48,
        augment_eps=0.0,
    ):
        """Extract protein features"""
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        # 25 distance tracks feed the edge MLP: Ca-Ca plus 24 ordered pairs.
        self.embeddings = PositionalEncodings(num_positional_embeddings)
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

    def _dist(self, X, mask, eps=1e-6):
        """Top-k nearest neighbors by Ca distance; masked residues excluded."""
        pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
        diff = X.unsqueeze(1) - X.unsqueeze(2)
        D = pair_mask * torch.sqrt((diff**2).sum(3) + eps)
        # Push masked pairs beyond the max observed distance so they never
        # win the top-k selection.
        D_max, _ = D.max(-1, keepdim=True)
        D_adjust = D + (1.0 - pair_mask) * D_max
        k = np.minimum(self.top_k, X.shape[1])
        return torch.topk(D_adjust, k, dim=-1, largest=False)

    def _rbf(self, D):
        """Lift distances onto num_rbf Gaussian bases spanning [2, 22] A."""
        lo, hi, count = 2.0, 22.0, self.num_rbf
        centers = torch.linspace(lo, hi, count, device=D.device).view(1, 1, 1, -1)
        width = (hi - lo) / count
        z = (D.unsqueeze(-1) - centers) / width
        return torch.exp(-(z**2))

    def _get_rbf(self, A, B, E_idx):
        """RBF-encode A_i -> B_j distances, gathered to the K neighbors."""
        D_A_B = torch.sqrt(
            ((A[:, :, None, :] - B[:, None, :, :]) ** 2).sum(-1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[..., None], E_idx)[..., 0]  # [B,L,K]
        return self._rbf(D_A_B_neighbors)

    def forward(self, input_features):
        """Return (E, E_idx): edge features and neighbor indices."""
        X = input_features["X"]
        mask = input_features["mask"]
        R_idx = input_features["R_idx"]
        chain_labels = input_features["chain_labels"]

        if self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]
        # Virtual Cb from ideal backbone geometry (fixed empirical constants).
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        D_neighbors, E_idx = self._dist(Ca, mask)

        # Ca-Ca uses the masked neighbor distances directly; the remaining 24
        # ordered atom pairs follow in the exact layout edge_embedding expects.
        atoms = {"N": N, "Ca": Ca, "C": C, "O": O, "Cb": Cb}
        pair_order = [
            ("N", "N"), ("C", "C"), ("O", "O"), ("Cb", "Cb"),
            ("Ca", "N"), ("Ca", "C"), ("Ca", "O"), ("Ca", "Cb"),
            ("N", "C"), ("N", "O"), ("N", "Cb"), ("Cb", "C"),
            ("Cb", "O"), ("O", "C"), ("N", "Ca"), ("C", "Ca"),
            ("O", "Ca"), ("Cb", "Ca"), ("C", "N"), ("O", "N"),
            ("Cb", "N"), ("C", "Cb"), ("O", "Cb"), ("C", "O"),
        ]
        rbf_feats = [self._rbf(D_neighbors)]  # Ca-Ca
        rbf_feats.extend(
            self._get_rbf(atoms[a1], atoms[a2], E_idx) for a1, a2 in pair_order
        )
        RBF_all = torch.cat(rbf_feats, dim=-1)

        # Relative residue index, gathered per neighbor.
        offset = R_idx[:, :, None] - R_idx[:, None, :]
        offset = gather_edges(offset[..., None], E_idx)[..., 0]  # [B, L, K]

        # 1 where both residues share a chain, 0 across chains.
        same_chain = (
            (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
        ).long()
        E_chains = gather_edges(same_chain[..., None], E_idx)[..., 0]
        E_positional = self.embeddings(offset.long(), E_chains)

        E = torch.cat((E_positional, RBF_all), -1)
        E = self.norm_edges(self.edge_embedding(E))
        return E, E_idx
1448
+
1449
+
1450
class ProteinFeaturesMembrane(torch.nn.Module):
    """Backbone graph featurizer with per-residue membrane-class node features.

    Edge features match ProteinFeatures (positional encodings + 25 RBF
    distance tracks); node features embed a one-hot membrane label.
    """

    def __init__(
        self,
        edge_features,
        node_features,
        num_positional_embeddings=16,
        num_rbf=16,
        top_k=48,
        augment_eps=0.0,
        num_classes=3,
    ):
        """Extract protein features"""
        super(ProteinFeaturesMembrane, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.num_classes = num_classes

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = torch.nn.Linear(edge_in, edge_features, bias=False)
        self.norm_edges = torch.nn.LayerNorm(edge_features)

        # Node features: linear embedding of the one-hot membrane class.
        self.node_embedding = torch.nn.Linear(
            self.num_classes, node_features, bias=False
        )
        self.norm_nodes = torch.nn.LayerNorm(node_features)

    def _dist(self, X, mask, eps=1e-6):
        """Top-k nearest neighbors by Ca distance; masked residues excluded."""
        pair_mask = mask.unsqueeze(1) * mask.unsqueeze(2)
        diff = X.unsqueeze(1) - X.unsqueeze(2)
        D = pair_mask * torch.sqrt((diff**2).sum(3) + eps)
        # Masked pairs are shifted past the max distance so top-k skips them.
        D_max, _ = D.max(-1, keepdim=True)
        D_adjust = D + (1.0 - pair_mask) * D_max
        k = np.minimum(self.top_k, X.shape[1])
        return torch.topk(D_adjust, k, dim=-1, largest=False)

    def _rbf(self, D):
        """Lift distances onto num_rbf Gaussian bases spanning [2, 22] A."""
        lo, hi, count = 2.0, 22.0, self.num_rbf
        centers = torch.linspace(lo, hi, count, device=D.device).view(1, 1, 1, -1)
        width = (hi - lo) / count
        z = (D.unsqueeze(-1) - centers) / width
        return torch.exp(-(z**2))

    def _get_rbf(self, A, B, E_idx):
        """RBF-encode A_i -> B_j distances, gathered to the K neighbors."""
        D_A_B = torch.sqrt(
            ((A[:, :, None, :] - B[:, None, :, :]) ** 2).sum(-1) + 1e-6
        )  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[..., None], E_idx)[..., 0]  # [B,L,K]
        return self._rbf(D_A_B_neighbors)

    def forward(self, input_features):
        """Return (V, E, E_idx): node features, edge features, neighbors."""
        X = input_features["X"]
        mask = input_features["mask"]
        R_idx = input_features["R_idx"]
        chain_labels = input_features["chain_labels"]
        membrane_per_residue_labels = input_features["membrane_per_residue_labels"]

        if self.augment_eps > 0:
            X = X + self.augment_eps * torch.randn_like(X)

        N = X[:, :, 0, :]
        Ca = X[:, :, 1, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]
        # Virtual Cb from ideal backbone geometry (fixed empirical constants).
        b = Ca - N
        c = C - Ca
        a = torch.cross(b, c, dim=-1)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

        D_neighbors, E_idx = self._dist(Ca, mask)

        # Ca-Ca uses the masked neighbor distances; the remaining 24 ordered
        # pairs follow in the exact layout edge_embedding expects.
        atoms = {"N": N, "Ca": Ca, "C": C, "O": O, "Cb": Cb}
        pair_order = [
            ("N", "N"), ("C", "C"), ("O", "O"), ("Cb", "Cb"),
            ("Ca", "N"), ("Ca", "C"), ("Ca", "O"), ("Ca", "Cb"),
            ("N", "C"), ("N", "O"), ("N", "Cb"), ("Cb", "C"),
            ("Cb", "O"), ("O", "C"), ("N", "Ca"), ("C", "Ca"),
            ("O", "Ca"), ("Cb", "Ca"), ("C", "N"), ("O", "N"),
            ("Cb", "N"), ("C", "Cb"), ("O", "Cb"), ("C", "O"),
        ]
        rbf_feats = [self._rbf(D_neighbors)]  # Ca-Ca
        rbf_feats.extend(
            self._get_rbf(atoms[a1], atoms[a2], E_idx) for a1, a2 in pair_order
        )
        RBF_all = torch.cat(rbf_feats, dim=-1)

        offset = R_idx[:, :, None] - R_idx[:, None, :]
        offset = gather_edges(offset[..., None], E_idx)[..., 0]  # [B, L, K]

        # 1 for same-chain pairs, 0 across chains.
        same_chain = (
            (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
        ).long()
        E_chains = gather_edges(same_chain[..., None], E_idx)[..., 0]
        E_positional = self.embeddings(offset.long(), E_chains)

        E = torch.cat((E_positional, RBF_all), -1)
        E = self.norm_edges(self.edge_embedding(E))

        # Per-residue membrane class -> embedded node features.
        labels_1hot = torch.nn.functional.one_hot(
            membrane_per_residue_labels, self.num_classes
        ).float()
        V = self.norm_nodes(self.node_embedding(labels_1hot))

        return V, E, E_idx
1580
+
1581
+
1582
class DecLayerJ(torch.nn.Module):
    """Decoder message-passing layer with an extra leading neighbor axis.

    Same computation as DecLayer, but h_V carries one more dimension, so the
    node state is broadcast across h_E's second-to-last axis.
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayerJ, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)

        # Three-layer message MLP over concatenated [h_V_i, h_E_ij].
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Update node states h_V from edge inputs h_E (one decoder step)."""
        # Broadcast each node state over its neighbor slots and build
        # per-edge messages (note the 4-way expand: the only difference
        # from DecLayer).
        expanded = h_V.unsqueeze(-2).expand(-1, -1, -1, h_E.size(-2), -1)
        stacked = torch.cat([expanded, h_E], -1)
        msg = self.W3(self.act(self.W2(self.act(self.W1(stacked)))))
        if mask_attend is not None:
            msg = msg * mask_attend.unsqueeze(-1)

        # Scaled-sum aggregation, residual, and norm.
        h_V = self.norm1(h_V + self.dropout1(msg.sum(dim=-2) / self.scale))

        # Position-wise feed-forward with second residual + norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        if mask_V is not None:
            h_V = h_V * mask_V.unsqueeze(-1)
        return h_V
1623
+
1624
+
1625
class PositionWiseFeedForward(torch.nn.Module):
    """Two-layer GELU MLP applied independently at every position."""

    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.W_in = torch.nn.Linear(num_hidden, num_ff, bias=True)
        self.W_out = torch.nn.Linear(num_ff, num_hidden, bias=True)
        self.act = torch.nn.GELU()

    def forward(self, h_V):
        """Project up to num_ff, apply GELU, project back to num_hidden."""
        return self.W_out(self.act(self.W_in(h_V)))
1636
+
1637
+
1638
class PositionalEncodings(torch.nn.Module):
    """Embed relative sequence offsets, clipped to +/- max_relative_feature.

    Pairs with mask == 0 (e.g. different chains) are routed to a dedicated
    extra bucket before the linear projection.
    """

    def __init__(self, num_embeddings, max_relative_feature=32):
        super(PositionalEncodings, self).__init__()
        self.num_embeddings = num_embeddings
        self.max_relative_feature = max_relative_feature
        # Buckets: 2m+1 clipped offsets plus one "masked pair" bucket.
        self.linear = torch.nn.Linear(2 * max_relative_feature + 1 + 1, num_embeddings)

    def forward(self, offset, mask):
        """Map integer offsets (with mask) to learned embeddings."""
        m = self.max_relative_feature
        clipped = torch.clip(offset + m, 0, 2 * m)
        # mask==1 keeps the clipped bucket; mask==0 selects bucket 2m+1.
        bucket = clipped * mask + (1 - mask) * (2 * m + 1)
        one_hot = torch.nn.functional.one_hot(bucket, 2 * m + 1 + 1)
        return self.linear(one_hot.float())
1652
+
1653
+
1654
class DecLayer(torch.nn.Module):
    """Decoder message-passing layer: node update from precomputed edges.

    Per-edge messages are an MLP over [h_V_i, h_E_ij], summed over neighbors
    and scaled, followed by a position-wise feed-forward block.
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)

        # Three-layer message MLP over concatenated [h_V_i, h_E_ij].
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """Update node states h_V from edge inputs h_E (one decoder step)."""
        # Broadcast each node state across its K neighbor slots.
        expanded = h_V.unsqueeze(-2).expand(-1, -1, h_E.size(-2), -1)
        stacked = torch.cat([expanded, h_E], -1)
        msg = self.W3(self.act(self.W2(self.act(self.W1(stacked)))))
        if mask_attend is not None:
            msg = msg * mask_attend.unsqueeze(-1)

        # Scaled-sum aggregation, residual, and norm.
        h_V = self.norm1(h_V + self.dropout1(msg.sum(dim=-2) / self.scale))

        # Position-wise feed-forward with second residual + norm.
        h_V = self.norm2(h_V + self.dropout2(self.dense(h_V)))

        if mask_V is not None:
            h_V = h_V * mask_V.unsqueeze(-1)
        return h_V
1693
+
1694
+
1695
class EncLayer(torch.nn.Module):
    """Encoder layer updating both node states (h_V) and edge states (h_E).

    The node update aggregates MLP messages over [h_V_i, h_E_ij, h_V_j];
    the edge update then re-reads the freshly updated h_V, so statement
    order in forward() is significant.
    """

    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale  # divisor applied to the summed neighbor messages
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)
        self.norm1 = torch.nn.LayerNorm(num_hidden)
        self.norm2 = torch.nn.LayerNorm(num_hidden)
        self.norm3 = torch.nn.LayerNorm(num_hidden)

        # W1-W3: message MLP for the node update; W11-W13: for the edge update.
        self.W1 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W11 = torch.nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = torch.nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer"""

        # --- Node update: messages over [h_V_i, h_E_ij, h_V_j]. ---
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)
        h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
        # Scaled-sum aggregation over neighbors, then residual + norm.
        dh = torch.sum(h_message, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(dh))

        # Position-wise feed-forward with second residual + norm.
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V

        # --- Edge update: rebuilds the neighborhood features from the
        # freshly updated h_V, then applies a residual + norm to h_E. ---
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = h_V.unsqueeze(-2).expand(-1, -1, h_EV.size(-2), -1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)
        h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))
        h_E = self.norm3(h_E + self.dropout3(h_message))
        return h_V, h_E
1741
+
1742
+
1743
+ # The following gather functions
1744
def gather_edges(edges, neighbor_idx):
    """Select neighbor columns from a dense edge tensor.

    edges: [B, N, N, C]; neighbor_idx: [B, N, K] -> returns [B, N, K, C].
    """
    idx = neighbor_idx[..., None].expand(-1, -1, -1, edges.size(-1))
    return torch.gather(edges, 2, idx)
1749
+
1750
+
1751
def gather_nodes(nodes, neighbor_idx):
    """Gather per-neighbor node features.

    nodes: [B, N, C]; neighbor_idx: [B, N, K] -> returns [B, N, K, C].
    """
    B, N, K = neighbor_idx.shape
    # Flatten the neighbor axis into the node axis so torch.gather can
    # index along dim 1, then restore the [B, N, K, C] layout.
    flat = neighbor_idx.reshape(B, N * K, 1).expand(-1, -1, nodes.size(2))
    gathered = torch.gather(nodes, 1, flat)
    return gathered.view(B, N, K, -1)
1760
+
1761
+
1762
def gather_nodes_t(nodes, neighbor_idx):
    """Gather node features for a single neighbor list.

    nodes: [B, N, C]; neighbor_idx: [B, K] -> returns [B, K, C].
    """
    idx = neighbor_idx[:, :, None].expand(-1, -1, nodes.size(2))
    return torch.gather(nodes, 1, idx)
1767
+
1768
+
1769
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """Concatenate edge features with the gathered neighbor node features."""
    gathered = gather_nodes(h_nodes, E_idx)
    return torch.cat([h_neighbors, gathered], -1)
openfold/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #from . import model
2
+ #from . import utils
3
+ #from . import np
4
+ #from . import resources
5
+
6
+ #__all__ = ["model", "utils", "np", "data", "resources"]
openfold/config.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import ml_collections as mlc
3
+
4
+
5
def set_inf(c, inf):
    """Recursively overwrite every key named "inf" in a nested config.

    Descends into ml_collections ConfigDict children; any other value whose
    key is exactly "inf" is replaced with the given number.
    """
    for key, value in c.items():
        if isinstance(value, mlc.ConfigDict):
            set_inf(value, inf)
        elif key == "inf":
            c[key] = inf
11
+
12
+
13
def enforce_config_constraints(config):
    """Validate cross-option constraints on a model config.

    Raises ValueError when two mutually exclusive boolean settings are both
    enabled (currently template averaging vs. template offloading).
    """

    def lookup(dotted):
        # Resolve "a.b.c" by successive subscripting from the config root.
        node = config
        for part in dotted.split('.'):
            node = node[part]
        return node

    mutually_exclusive_bools = [
        (
            "model.template.average_templates",
            "model.template.offload_templates"
        )
    ]

    for s1, s2 in mutually_exclusive_bools:
        if lookup(s1) and lookup(s2):
            raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
34
+
35
+
36
def model_config(name, train=False, low_prec=False):
    """Build a config for a named AlphaFold model preset.

    Starts from a deep copy of the module-level base `config` and applies
    per-preset overrides (transcribed from the AF2 supplementary tables),
    then training/low-precision adjustments. Raises ValueError on an
    unknown name. Returns the resulting ConfigDict.
    """
    c = copy.deepcopy(config)
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    elif name == "finetuning":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.train.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
    elif name == "finetuning_ptm":
        # Finetuning overrides plus the pTM prediction head.
        c.data.train.max_extra_msa = 5120
        c.data.train.crop_size = 384
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_1":
        # AF2 Suppl. Table 5, Model 1.1.1
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
    elif name == "model_2":
        # AF2 Suppl. Table 5, Model 1.1.2
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
    elif name == "model_3":
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_4":
        # AF2 Suppl. Table 5, Model 1.2.2
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == "model_5":
        # AF2 Suppl. Table 5, Model 1.2.3
        c.model.template.enabled = False
    # The *_ptm variants mirror the corresponding base model and additionally
    # enable the pTM head with loss weight 0.1.
    elif name == "model_1_ptm":
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_2_ptm":
        c.data.common.reduce_max_clusters_by_max_templates = True
        c.data.common.use_templates = True
        c.data.common.use_template_torsion_angles = True
        c.model.template.enabled = True
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_3_ptm":
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_4_ptm":
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    elif name == "model_5_ptm":
        c.model.template.enabled = False
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1
    else:
        raise ValueError("Invalid model name")

    if train:
        # Training disables inference-only memory optimizations.
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
        c.globals.use_lma = False
        c.globals.offload_inference = False
        c.model.template.average_templates = False
        c.model.template.offload_templates = False
    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c
134
+
135
+
136
# Shared ml_collections FieldReferences: changing one of these updates every
# entry in the config below that references it.
# NOTE(review): names follow AF2 conventions (c_z = pair channels, c_m = MSA
# channels, c_t = template, c_e = extra-MSA, c_s = single) — verify against
# the model code before relying on these descriptions.
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
tm_enabled = mlc.FieldReference(False, field_type=bool)
eps = mlc.FieldReference(1e-8, field_type=float)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
tune_chunk_size = mlc.FieldReference(True, field_type=bool)

# Placeholder strings naming dynamic dimensions in the feature-shape spec
# ("data.common.feat" below).
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
155
+ config = mlc.ConfigDict(
156
+ {
157
+ "data": {
158
+ "common": {
159
+ "feat": {
160
+ "aatype": [NUM_RES],
161
+ "all_atom_mask": [NUM_RES, None],
162
+ "all_atom_positions": [NUM_RES, None, None],
163
+ "alt_chi_angles": [NUM_RES, None],
164
+ "atom14_alt_gt_exists": [NUM_RES, None],
165
+ "atom14_alt_gt_positions": [NUM_RES, None, None],
166
+ "atom14_atom_exists": [NUM_RES, None],
167
+ "atom14_atom_is_ambiguous": [NUM_RES, None],
168
+ "atom14_gt_exists": [NUM_RES, None],
169
+ "atom14_gt_positions": [NUM_RES, None, None],
170
+ "atom37_atom_exists": [NUM_RES, None],
171
+ "backbone_rigid_mask": [NUM_RES],
172
+ "backbone_rigid_tensor": [NUM_RES, None, None],
173
+ "bert_mask": [NUM_MSA_SEQ, NUM_RES],
174
+ "chi_angles_sin_cos": [NUM_RES, None, None],
175
+ "chi_mask": [NUM_RES, None],
176
+ "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
177
+ "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
178
+ "extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
179
+ "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
180
+ "extra_msa_row_mask": [NUM_EXTRA_SEQ],
181
+ "is_distillation": [],
182
+ "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
183
+ "msa_mask": [NUM_MSA_SEQ, NUM_RES],
184
+ "msa_row_mask": [NUM_MSA_SEQ],
185
+ "no_recycling_iters": [],
186
+ "pseudo_beta": [NUM_RES, None],
187
+ "pseudo_beta_mask": [NUM_RES],
188
+ "residue_index": [NUM_RES],
189
+ "residx_atom14_to_atom37": [NUM_RES, None],
190
+ "residx_atom37_to_atom14": [NUM_RES, None],
191
+ "resolution": [],
192
+ "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
193
+ "rigidgroups_group_exists": [NUM_RES, None],
194
+ "rigidgroups_group_is_ambiguous": [NUM_RES, None],
195
+ "rigidgroups_gt_exists": [NUM_RES, None],
196
+ "rigidgroups_gt_frames": [NUM_RES, None, None, None],
197
+ "seq_length": [],
198
+ "seq_mask": [NUM_RES],
199
+ "target_feat": [NUM_RES, None],
200
+ "template_aatype": [NUM_TEMPLATES, NUM_RES],
201
+ "template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
202
+ "template_all_atom_positions": [
203
+ NUM_TEMPLATES, NUM_RES, None, None,
204
+ ],
205
+ "template_alt_torsion_angles_sin_cos": [
206
+ NUM_TEMPLATES, NUM_RES, None, None,
207
+ ],
208
+ "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
209
+ "template_backbone_rigid_tensor": [
210
+ NUM_TEMPLATES, NUM_RES, None, None,
211
+ ],
212
+ "template_mask": [NUM_TEMPLATES],
213
+ "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
214
+ "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
215
+ "template_sum_probs": [NUM_TEMPLATES, None],
216
+ "template_torsion_angles_mask": [
217
+ NUM_TEMPLATES, NUM_RES, None,
218
+ ],
219
+ "template_torsion_angles_sin_cos": [
220
+ NUM_TEMPLATES, NUM_RES, None, None,
221
+ ],
222
+ "true_msa": [NUM_MSA_SEQ, NUM_RES],
223
+ "use_clamped_fape": [],
224
+ },
225
+ "masked_msa": {
226
+ "profile_prob": 0.1,
227
+ "same_prob": 0.1,
228
+ "uniform_prob": 0.1,
229
+ },
230
+ "max_recycling_iters": 3,
231
+ "msa_cluster_features": True,
232
+ "reduce_msa_clusters_by_max_templates": False,
233
+ "resample_msa_in_recycling": True,
234
+ "template_features": [
235
+ "template_all_atom_positions",
236
+ "template_sum_probs",
237
+ "template_aatype",
238
+ "template_all_atom_mask",
239
+ ],
240
+ "unsupervised_features": [
241
+ "aatype",
242
+ "residue_index",
243
+ "msa",
244
+ "num_alignments",
245
+ "seq_length",
246
+ "between_segment_residues",
247
+ "deletion_matrix",
248
+ "no_recycling_iters",
249
+ ],
250
+ "use_templates": templates_enabled,
251
+ "use_template_torsion_angles": embed_template_torsion_angles,
252
+ },
253
+ "supervised": {
254
+ "clamp_prob": 0.9,
255
+ "supervised_features": [
256
+ "all_atom_mask",
257
+ "all_atom_positions",
258
+ "resolution",
259
+ "use_clamped_fape",
260
+ "is_distillation",
261
+ ],
262
+ },
263
+ "predict": {
264
+ "fixed_size": True,
265
+ "subsample_templates": False, # We want top templates.
266
+ "masked_msa_replace_fraction": 0.15,
267
+ "max_msa_clusters": 512,
268
+ "max_extra_msa": 1024,
269
+ "max_template_hits": 4,
270
+ "max_templates": 4,
271
+ "crop": False,
272
+ "crop_size": None,
273
+ "supervised": False,
274
+ "uniform_recycling": False,
275
+ },
276
+ "eval": {
277
+ "fixed_size": True,
278
+ "subsample_templates": False, # We want top templates.
279
+ "masked_msa_replace_fraction": 0.15,
280
+ "max_msa_clusters": 128,
281
+ "max_extra_msa": 1024,
282
+ "max_template_hits": 4,
283
+ "max_templates": 4,
284
+ "crop": False,
285
+ "crop_size": None,
286
+ "supervised": True,
287
+ "uniform_recycling": False,
288
+ },
289
+ "train": {
290
+ "fixed_size": True,
291
+ "subsample_templates": True,
292
+ "masked_msa_replace_fraction": 0.15,
293
+ "max_msa_clusters": 128,
294
+ "max_extra_msa": 1024,
295
+ "max_template_hits": 4,
296
+ "max_templates": 4,
297
+ "shuffle_top_k_prefiltered": 20,
298
+ "crop": True,
299
+ "crop_size": 256,
300
+ "supervised": True,
301
+ "clamp_prob": 0.9,
302
+ "max_distillation_msa_clusters": 1000,
303
+ "uniform_recycling": True,
304
+ "distillation_prob": 0.75,
305
+ },
306
+ "data_module": {
307
+ "use_small_bfd": False,
308
+ "data_loaders": {
309
+ "batch_size": 1,
310
+ "num_workers": 16,
311
+ },
312
+ },
313
+ },
314
+ # Recurring FieldReferences that can be changed globally here
315
+ "globals": {
316
+ "blocks_per_ckpt": blocks_per_ckpt,
317
+ "chunk_size": chunk_size,
318
+ "use_lma": False,
319
+ "offload_inference": False,
320
+ "c_z": c_z,
321
+ "c_m": c_m,
322
+ "c_t": c_t,
323
+ "c_e": c_e,
324
+ "c_s": c_s,
325
+ "eps": eps,
326
+ },
327
+ "model": {
328
+ "_mask_trans": False,
329
+ "input_embedder": {
330
+ "tf_dim": 22,
331
+ "msa_dim": 49,
332
+ "c_z": c_z,
333
+ "c_m": c_m,
334
+ "relpos_k": 32,
335
+ },
336
+ "recycling_embedder": {
337
+ "c_z": c_z,
338
+ "c_m": c_m,
339
+ "min_bin": 3.25,
340
+ "max_bin": 20.75,
341
+ "no_bins": 15,
342
+ "inf": 1e8,
343
+ },
344
+ "template": {
345
+ "distogram": {
346
+ "min_bin": 3.25,
347
+ "max_bin": 50.75,
348
+ "no_bins": 39,
349
+ },
350
+ "template_angle_embedder": {
351
+ # DISCREPANCY: c_in is supposed to be 51.
352
+ "c_in": 57,
353
+ "c_out": c_m,
354
+ },
355
+ "template_pair_embedder": {
356
+ "c_in": 88,
357
+ "c_out": c_t,
358
+ },
359
+ "template_pair_stack": {
360
+ "c_t": c_t,
361
+ # DISCREPANCY: c_hidden_tri_att here is given in the supplement
362
+ # as 64. In the code, it's 16.
363
+ "c_hidden_tri_att": 16,
364
+ "c_hidden_tri_mul": 64,
365
+ "no_blocks": 2,
366
+ "no_heads": 4,
367
+ "pair_transition_n": 2,
368
+ "dropout_rate": 0.25,
369
+ "blocks_per_ckpt": blocks_per_ckpt,
370
+ "tune_chunk_size": tune_chunk_size,
371
+ "inf": 1e9,
372
+ },
373
+ "template_pointwise_attention": {
374
+ "c_t": c_t,
375
+ "c_z": c_z,
376
+ # DISCREPANCY: c_hidden here is given in the supplement as 64.
377
+ # It's actually 16.
378
+ "c_hidden": 16,
379
+ "no_heads": 4,
380
+ "inf": 1e5, # 1e9,
381
+ },
382
+ "inf": 1e5, # 1e9,
383
+ "eps": eps, # 1e-6,
384
+ "enabled": templates_enabled,
385
+ "embed_angles": embed_template_torsion_angles,
386
+ "use_unit_vector": False,
387
+ # Approximate template computation, saving memory.
388
+ # In our experiments, results are equivalent to or better than
389
+ # the stock implementation. Should be enabled for all new
390
+ # training runs.
391
+ "average_templates": False,
392
+ # Offload template embeddings to CPU memory. Vastly reduced
393
+ # memory consumption at the cost of a modest increase in
394
+ # runtime. Useful for inference on very long sequences.
395
+ # Mutually exclusive with average_templates.
396
+ "offload_templates": False,
397
+ },
398
+ "extra_msa": {
399
+ "extra_msa_embedder": {
400
+ "c_in": 25,
401
+ "c_out": c_e,
402
+ },
403
+ "extra_msa_stack": {
404
+ "c_m": c_e,
405
+ "c_z": c_z,
406
+ "c_hidden_msa_att": 8,
407
+ "c_hidden_opm": 32,
408
+ "c_hidden_mul": 128,
409
+ "c_hidden_pair_att": 32,
410
+ "no_heads_msa": 8,
411
+ "no_heads_pair": 4,
412
+ "no_blocks": 4,
413
+ "transition_n": 4,
414
+ "msa_dropout": 0.15,
415
+ "pair_dropout": 0.25,
416
+ "clear_cache_between_blocks": False,
417
+ "tune_chunk_size": tune_chunk_size,
418
+ "inf": 1e9,
419
+ "eps": eps, # 1e-10,
420
+ "ckpt": blocks_per_ckpt is not None,
421
+ },
422
+ "enabled": True,
423
+ },
424
+ "evoformer_stack": {
425
+ "c_m": c_m,
426
+ "c_z": c_z,
427
+ "c_hidden_msa_att": 32,
428
+ "c_hidden_opm": 32,
429
+ "c_hidden_mul": 128,
430
+ "c_hidden_pair_att": 32,
431
+ "c_s": c_s,
432
+ "no_heads_msa": 8,
433
+ "no_heads_pair": 4,
434
+ "no_blocks": 48,
435
+ "transition_n": 4,
436
+ "msa_dropout": 0.15,
437
+ "pair_dropout": 0.25,
438
+ "blocks_per_ckpt": blocks_per_ckpt,
439
+ "clear_cache_between_blocks": False,
440
+ "tune_chunk_size": tune_chunk_size,
441
+ "inf": 1e9,
442
+ "eps": eps, # 1e-10,
443
+ },
444
+ "structure_module": {
445
+ "c_s": c_s,
446
+ "c_z": c_z,
447
+ "c_ipa": 16,
448
+ "c_resnet": 128,
449
+ "no_heads_ipa": 12,
450
+ "no_qk_points": 4,
451
+ "no_v_points": 8,
452
+ "dropout_rate": 0.1,
453
+ "no_blocks": 8,
454
+ "no_transition_layers": 1,
455
+ "no_resnet_blocks": 2,
456
+ "no_angles": 7,
457
+ "trans_scale_factor": 10,
458
+ "epsilon": eps, # 1e-12,
459
+ "inf": 1e5,
460
+ },
461
+ "heads": {
462
+ "lddt": {
463
+ "no_bins": 50,
464
+ "c_in": c_s,
465
+ "c_hidden": 128,
466
+ },
467
+ "distogram": {
468
+ "c_z": c_z,
469
+ "no_bins": aux_distogram_bins,
470
+ },
471
+ "tm": {
472
+ "c_z": c_z,
473
+ "no_bins": aux_distogram_bins,
474
+ "enabled": tm_enabled,
475
+ },
476
+ "masked_msa": {
477
+ "c_m": c_m,
478
+ "c_out": 23,
479
+ },
480
+ "experimentally_resolved": {
481
+ "c_s": c_s,
482
+ "c_out": 37,
483
+ },
484
+ },
485
+ },
486
+ "relax": {
487
+ "max_iterations": 0, # no max
488
+ "tolerance": 2.39,
489
+ "stiffness": 10.0,
490
+ "max_outer_iterations": 20,
491
+ "exclude_residues": [],
492
+ },
493
+ "loss": {
494
+ "distogram": {
495
+ "min_bin": 2.3125,
496
+ "max_bin": 21.6875,
497
+ "no_bins": 64,
498
+ "eps": eps, # 1e-6,
499
+ "weight": 0.3,
500
+ },
501
+ "experimentally_resolved": {
502
+ "eps": eps, # 1e-8,
503
+ "min_resolution": 0.1,
504
+ "max_resolution": 3.0,
505
+ "weight": 0.0,
506
+ },
507
+ "fape": {
508
+ "backbone": {
509
+ "clamp_distance": 10.0,
510
+ "loss_unit_distance": 10.0,
511
+ "weight": 0.5,
512
+ },
513
+ "sidechain": {
514
+ "clamp_distance": 10.0,
515
+ "length_scale": 10.0,
516
+ "weight": 0.5,
517
+ },
518
+ "eps": 1e-4,
519
+ "weight": 1.0,
520
+ },
521
+ "lddt": {
522
+ "min_resolution": 0.1,
523
+ "max_resolution": 3.0,
524
+ "cutoff": 15.0,
525
+ "no_bins": 50,
526
+ "eps": eps, # 1e-10,
527
+ "weight": 0.01,
528
+ },
529
+ "masked_msa": {
530
+ "eps": eps, # 1e-8,
531
+ "weight": 2.0,
532
+ },
533
+ "supervised_chi": {
534
+ "chi_weight": 0.5,
535
+ "angle_norm_weight": 0.01,
536
+ "eps": eps, # 1e-6,
537
+ "weight": 1.0,
538
+ },
539
+ "violation": {
540
+ "violation_tolerance_factor": 12.0,
541
+ "clash_overlap_tolerance": 1.5,
542
+ "eps": eps, # 1e-6,
543
+ "weight": 0.0,
544
+ },
545
+ "tm": {
546
+ "max_bin": 31,
547
+ "no_bins": 64,
548
+ "min_resolution": 0.1,
549
+ "max_resolution": 3.0,
550
+ "eps": eps, # 1e-8,
551
+ "weight": 0.,
552
+ "enabled": tm_enabled,
553
+ },
554
+ "eps": eps,
555
+ },
556
+ "ema": {"decay": 0.999},
557
+ }
558
+ )
openfold/data/__init__.py ADDED
File without changes
openfold/data/data_modules.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from functools import partial
3
+ import json
4
+ import logging
5
+ import os
6
+ import pickle
7
+ from typing import Optional, Sequence, List, Any
8
+
9
+ import ml_collections as mlc
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ import torch
13
+ from torch.utils.data import RandomSampler
14
+
15
+ from openfold.data import (
16
+ data_pipeline,
17
+ feature_pipeline,
18
+ mmcif_parsing,
19
+ templates,
20
+ )
21
+ from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
22
+
23
+
24
class OpenFoldSingleDataset(torch.utils.data.Dataset):
    """A map-style dataset over one directory of structures and alignments.

    Each item is identified by a chain name ({PDB_ID}_{CHAIN_ID} or simply
    {PDB_ID}). __getitem__ locates the matching structure file
    (.cif/.core/.pdb in train/eval mode, .fasta in predict mode), runs the
    data pipeline on it and, unless _output_raw is set, the feature
    pipeline as well.
    """
    def __init__(self,
        data_dir: str,
        alignment_dir: str,
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        mapping_path: Optional[str] = None,
        mode: str = "train",
        alignment_index: Optional[Any] = None,
        _output_raw: bool = False,
        _structure_index: Optional[Any] = None,
    ):
        """
        Args:
            data_dir:
                A path to a directory containing mmCIF files (in train
                mode) or FASTA files (in inference mode).
            alignment_dir:
                A path to a directory containing only data in the format
                output by an AlignmentRunner
                (defined in openfold.features.alignment_runner).
                I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                files.
            template_mmcif_dir:
                Path to a directory containing template mmCIF files.
            config:
                A dataset config object. See openfold.config
            kalign_binary_path:
                Path to kalign binary.
            max_template_hits:
                An upper bound on how many templates are considered. During
                training, the templates ultimately used are subsampled
                from this total quantity.
            template_release_dates_cache_path:
                Path to the output of scripts/generate_mmcif_cache.
            obsolete_pdbs_file_path:
                Path to the file containing replacements for obsolete PDBs.
            shuffle_top_k_prefiltered:
                Whether to uniformly shuffle the top k template hits before
                parsing max_template_hits of them. Can be used to
                approximate DeepMind's training-time template subsampling
                scheme much more performantly.
            treat_pdb_as_distillation:
                Whether to assume that .pdb files in the data_dir are from
                the self-distillation set (and should be subjected to
                special distillation set preprocessing steps).
            mode:
                "train", "val", or "predict"
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
        self._structure_index = _structure_index

        # Extensions probed (in order) when no structure index is supplied
        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if mode not in valid_modes:
            raise ValueError(f'mode must be one of {valid_modes}')

        if template_release_dates_cache_path is None:
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        # Establish the fixed ordering of chain ids served by this dataset
        if alignment_index is not None:
            self._chain_ids = list(alignment_index.keys())
        elif mapping_path is None:
            self._chain_ids = list(os.listdir(alignment_dir))
        else:
            with open(mapping_path, "r") as f:
                self._chain_ids = [l.strip() for l in f.readlines()]

        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        if not self._output_raw:
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
        """Parse one mmCIF file and run the data pipeline on it."""
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if mmcif_object.mmcif_object is None:
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            alignment_index=alignment_index
        )

        return data

    def chain_id_to_idx(self, chain_id):
        """Map a chain name to its integer dataset index."""
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        """Map an integer dataset index back to its chain name."""
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

        alignment_index = None
        if self.alignment_index is not None:
            # With an alignment index, alignments live in one flat directory
            alignment_dir = self.alignment_dir
            alignment_index = self.alignment_index[name]

        if self.mode == 'train' or self.mode == 'eval':
            spl = name.rsplit('_', 1)
            if len(spl) == 2:
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
            structure_index_entry = None
            if self._structure_index is not None:
                structure_index_entry = self._structure_index[name]
                assert len(structure_index_entry["files"]) == 1
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                ext = None
                for e in self.supported_exts:
                    if os.path.exists(path + e):
                        ext = e
                        break

                if ext is None:
                    raise ValueError("Invalid file type")

            path += ext
            if ext == ".cif":
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir, alignment_index,
                )
            elif ext == ".core":
                data = self.data_pipeline.process_core(
                    path, alignment_dir, alignment_index,
                )
            elif ext == ".pdb":
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    alignment_index=alignment_index,
                    # BUG FIX: the original indexed self._structure_index[name]
                    # unconditionally, raising TypeError whenever no structure
                    # index was provided. Reuse the guarded lookup from above.
                    _structure_index=structure_index_entry,
                )
            else:
                raise ValueError("Extension branch missing")
        else:
            # BUG FIX: the FASTA path was previously built relative to the
            # CWD; anchor it at data_dir, which the docstring states holds
            # the FASTA files in inference mode (layout: one subdirectory
            # per chain name — TODO confirm against the data preparation
            # scripts).
            path = os.path.join(self.data_dir, name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                alignment_index=alignment_index,
            )

        if self._output_raw:
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode
        )

        # Tag every recycling copy of the sample with its dataset index
        feats["batch_idx"] = torch.tensor(
            [idx for _ in range(feats["aatype"].shape[-1])],
            dtype=torch.int64,
            device=feats["aatype"].device,
        )

        return feats

    def __len__(self):
        return len(self._chain_ids)
237
+
238
+
239
def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
) -> bool:
    """Return True iff the chain passes the deterministic (hard) training
    filters: resolution no worse than max_resolution, and no single amino
    acid making up more than max_single_aa_prop of the sequence.
    """
    # Resolution filter (skipped when the cache has no resolution entry)
    resolution = chain_data_cache_entry.get("resolution", None)
    if resolution is not None and resolution > max_resolution:
        return False

    # Low-complexity filter: reject sequences dominated by one residue type
    seq = chain_data_cache_entry["seq"]
    tallies = {}
    for residue in seq:
        tallies[residue] = tallies.get(residue, 0) + 1

    dominant_prop = max(tallies.values()) / len(seq)
    return dominant_prop <= max_single_aa_prop
260
+
261
+
262
def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    # BUG FIX: the return annotation previously claimed List[float], but the
    # function returns the scalar product of the individual probabilities.
    """Compute the probability with which a chain is kept by the stochastic
    training filters.

    Two factors are multiplied together:
      * 1 / cluster_size (when a positive cluster size is cached), which
        down-weights chains from large sequence clusters, and
      * a length term, chain_length / 512 clamped to [256/512, 1],
        which down-weights short chains.

    Returns:
        The keep probability in (0, 1].
    """
    probabilities = []

    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if cluster_size is not None and cluster_size > 0:
        probabilities.append(1 / cluster_size)

    chain_length = len(chain_data_cache_entry["seq"])
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

    # Product of all factors. With at most a handful of factors in (0, 1],
    # float underflow is not a practical concern.
    out = 1
    for p in probabilities:
        out *= p

    return out
281
+
282
+
283
class OpenFoldDataset(torch.utils.data.Dataset):
    """
    Implements the stochastic filters applied during AlphaFold's training.
    Because samples are selected from constituent datasets randomly, the
    length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
    and filtered once at initialization.
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[int],
        epoch_len: int,
        chain_data_cache_paths: List[str],
        generator: torch.Generator = None,
        _roll_at_init: bool = True,
    ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        # One chain-data cache (resolution, sequence, cluster size, ...)
        # per constituent dataset
        self.chain_data_caches = []
        for cache_path in chain_data_cache_paths:
            with open(cache_path, "r") as fp:
                self.chain_data_caches.append(json.load(fp))

        def shuffled_idx_stream(dataset_len):
            # Endlessly yield dataset indices, uniformly reshuffling each
            # time the full index set has been exhausted
            while True:
                uniform = [1. for _ in range(dataset_len)]
                order = torch.multinomial(
                    torch.tensor(uniform),
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                for shuffled_idx in order:
                    yield shuffled_idx

        def filtered_sample_stream(dataset_idx):
            # Endlessly yield indices that survive both the deterministic
            # and the stochastic training filters
            cache_budget = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = shuffled_idx_stream(len(dataset))
            chain_data_cache = self.chain_data_caches[dataset_idx]
            while True:
                keep_weights = []
                candidates = []
                for _ in range(cache_budget):
                    candidate = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate)
                    entry = chain_data_cache[chain_id]
                    if not deterministic_train_filter(entry):
                        continue

                    # Keep with probability p, reject with 1 - p
                    p = get_stochastic_train_filter_prob(entry)
                    keep_weights.append([1. - p, p])
                    candidates.append(candidate)

                decisions = torch.multinomial(
                    torch.tensor(keep_weights),
                    num_samples=1,
                    generator=self.generator,
                )
                decisions = decisions.squeeze()

                kept = [i for i, keep in zip(candidates, decisions) if keep]

                for datapoint_idx in kept:
                    yield datapoint_idx

        self._samples = [
            filtered_sample_stream(i) for i in range(len(self.datasets))
        ]

        if _roll_at_init:
            self.reroll()

    def __getitem__(self, idx):
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
        """Resample which (dataset, datapoint) pair backs each epoch slot."""
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices:
            sample_stream = self._samples[dataset_idx]
            self.datapoints.append((dataset_idx, next(sample_stream)))
379
+
380
+
381
class OpenFoldBatchCollator:
    """Collate per-sample feature dicts into one batch by stacking every
    leaf tensor along a new leading (batch) dimension."""
    def __call__(self, prots):
        return dict_multimap(
            lambda tensors: torch.stack(tensors, dim=0), prots
        )
385
+
386
+
387
class OpenFoldDataLoader(torch.utils.data.DataLoader):
    """
    DataLoader that augments each emitted batch with stochastic batch-level
    properties ("use_clamped_fape" and "no_recycling_iters") and trims the
    trailing recycling dimension of every feature to the sampled number of
    recycling iterations.
    """
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage

        if generator is None:
            generator = torch.Generator()

        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        """Build one categorical distribution per batch-level property."""
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters
        if stage_cfg.supervised:
            # FAPE clamping is toggled per batch with probability clamp_prob
            clamp_prob = self.config.supervised.clamp_prob
            keyed_probs.append(
                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
            )

        if stage_cfg.uniform_recycling:
            # Sample the number of recycling iterations uniformly
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            # Otherwise, always run the maximum number of iterations
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.

        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        keys, probs = zip(*keyed_probs)
        # Right-pad shorter distributions with zero-probability bins so
        # they stack into a single rectangular tensor
        max_len = max(len(p) for p in probs)
        padding = [[0.] * (max_len - len(p)) for p in probs]

        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        # One draw per property, shared by the whole batch
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1,  # 1 per row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample,
                device=aatype.device,
                requires_grad=False
            )
            orig_shape = sample_tensor.shape
            # Broadcast the sampled scalar over batch and recycling dims
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if key == "no_recycling_iters":
                no_recycling = sample

        # Trim every feature's trailing recycling dimension to the number
        # of iterations actually sampled
        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)
478
+
479
+
480
class OpenFoldDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning data module for OpenFold.

    Wires together the training (plus optional distillation), validation
    and prediction datasets and wraps them in OpenFoldDataLoaders that
    inject batch-level stochastic properties.
    """
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000,
        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if self.train_data_dir is None and self.predict_data_dir is None:
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        # Training mode iff a training data directory was provided
        self.training_mode = self.train_data_dir is not None

        if self.training_mode and train_alignment_dir is None:
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif not self.training_mode and predict_alignment_dir is None:
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )
        elif val_data_dir is not None and val_alignment_dir is None:
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = None
        if _distillation_structure_index_path is not None:
            with open(_distillation_structure_index_path, "r") as fp:
                self._distillation_structure_index = json.load(fp)

        self.alignment_index = None
        if alignment_index_path is not None:
            with open(alignment_index_path, "r") as fp:
                self.alignment_index = json.load(fp)

        self.distillation_alignment_index = None
        if distillation_alignment_index_path is not None:
            with open(distillation_alignment_index_path, "r") as fp:
                self.distillation_alignment_index = json.load(fp)

    def setup(self):
        """Instantiate the datasets for the current mode."""
        # Most of the arguments are the same for the three datasets
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if self.training_mode:
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            distillation_dataset = None
            if self.distillation_data_dir is not None:
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )

            if distillation_dataset is not None:
                datasets = [train_dataset, distillation_dataset]
                # Sample distillation data with probability d_prob
                # (a redundant duplicate assignment of d_prob was removed)
                d_prob = self.config.train.distillation_prob
                probabilities = [1. - d_prob, d_prob]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                    self.distillation_chain_data_cache_path,
                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                ]

            # BUG FIX: previously, generator was only bound when batch_seed
            # was provided, causing an UnboundLocalError just below when it
            # was not. Default to None (unseeded sampling).
            generator = None
            if self.batch_seed is not None:
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)

            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
                generator=generator,
                _roll_at_init=False,
            )

            if self.val_data_dir is not None:
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        """Build an OpenFoldDataLoader for the given stage."""
        generator = torch.Generator()
        if self.batch_seed is not None:
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if stage == "train":
            dataset = self.train_dataset
            # Resample the epoch's (dataset, datapoint) assignments
            dataset.reroll()
        elif stage == "eval":
            dataset = self.eval_dataset
        elif stage == "predict":
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train")

    def val_dataloader(self):
        if self.eval_dataset is not None:
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict")
701
+
702
+
703
class DummyDataset(torch.utils.data.Dataset):
    """Debugging dataset that endlessly serves deep copies of one pickled
    batch, so each consumer can mutate its copy safely."""
    def __init__(self, batch_path):
        # Load the single batch once at construction time
        with open(batch_path, "rb") as handle:
            self.batch = pickle.load(handle)

    def __getitem__(self, idx):
        # Every index returns an independent copy of the same batch
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Arbitrary fixed epoch length
        return 1000
713
+
714
+
715
class DummyDataLoader(pl.LightningDataModule):
    """Lightning data module that trains on copies of a single pickled
    batch; useful for overfitting/debugging runs."""
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)
openfold/data/data_pipeline.py ADDED
@@ -0,0 +1,826 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import datetime
18
+ from multiprocessing import cpu_count
19
+ from typing import Mapping, Optional, Sequence, Any
20
+
21
+ import numpy as np
22
+
23
+ from openfold.data import templates, parsers, mmcif_parsing
24
+ from openfold.data.tools import jackhmmer, hhblits, hhsearch
25
+ from openfold.data.tools.utils import to_date
26
+ from openfold.np import residue_constants, protein
27
+
28
+
29
+ FeatureDict = Mapping[str, np.ndarray]
30
+
31
def empty_template_feats(n_res) -> FeatureDict:
    """Return zero-template placeholder features for a length-n_res query."""
    specs = {
        "template_aatype": ((0, n_res), np.int64),
        "template_all_atom_positions": ((0, n_res, 37, 3), np.float32),
        "template_sum_probs": ((0, 1), np.float32),
        "template_all_atom_mask": ((0, n_res, 37), np.float32),
    }
    return {
        name: np.zeros(shape, dtype=dtype)
        for name, (shape, dtype) in specs.items()
    }
39
+
40
+
41
def make_template_features(
    input_sequence: str,
    # FIX: annotated Mapping, not Sequence — the body calls hits.values()
    hits: Mapping[str, Any],
    template_featurizer: Any,
    query_pdb_code: Optional[str] = None,
    query_release_date: Optional[str] = None,
) -> FeatureDict:
    """Featurize template hits for a query sequence.

    Args:
        input_sequence: Query amino-acid sequence.
        hits: Mapping from alignment-output name to a list of template hits.
        template_featurizer: Object exposing get_templates(), or None.
        query_pdb_code: Optional PDB code of the query.
        query_release_date: Optional release date forwarded to the featurizer.

    Returns:
        Template feature dict; empty template features when there are no
        hits or no featurizer.
    """
    # Flatten the per-file hit lists into one list
    hits_cat = sum(hits.values(), [])
    if len(hits_cat) == 0 or template_featurizer is None:
        template_features = empty_template_feats(len(input_sequence))
    else:
        templates_result = template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=query_pdb_code,
            query_release_date=query_release_date,
            hits=hits_cat,
        )
        template_features = templates_result.features

        # The template featurizer doesn't format empty template features
        # properly. This is a quick fix.
        if template_features["template_aatype"].shape[0] == 0:
            template_features = empty_template_feats(len(input_sequence))

    return template_features
66
+
67
+
68
def unify_template_features(
    template_feature_list: Sequence[FeatureDict]
) -> FeatureDict:
    """Merge per-chain template features into one multi-chain feature dict.

    Sequence-indexed features are zero-padded out to the total concatenated
    sequence length and placed in the slice belonging to their chain; the
    template dimension is then concatenated across chains.
    """
    out_dicts = []
    # Per-chain sequence lengths, in input order
    seq_lens = [fd["template_aatype"].shape[1] for fd in template_feature_list]
    for i, fd in enumerate(template_feature_list):
        out_dict = {}
        n_templates, n_res = fd["template_aatype"].shape[:2]
        for k,v in fd.items():
            # Features indexed by residue along dim 1 must be padded/placed
            seq_keys = [
                "template_aatype",
                "template_all_atom_positions",
                "template_all_atom_mask",
            ]
            if(k in seq_keys):
                new_shape = list(v.shape)
                assert(new_shape[1] == n_res)
                new_shape[1] = sum(seq_lens)
                new_array = np.zeros(new_shape, dtype=v.dtype)

                if(k == "template_aatype"):
                    # Positions belonging to other chains read as gap ('-')
                    new_array[..., residue_constants.HHBLITS_AA_TO_ID['-']] = 1

                # Write this chain's values into its slice of the long axis
                offset = sum(seq_lens[:i])
                new_array[:, offset:offset + seq_lens[i]] = v
                out_dict[k] = new_array
            else:
                out_dict[k] = v

        # Track which chain each template row originated from
        chain_indices = np.array(n_templates * [i])
        out_dict["template_chain_index"] = chain_indices

        # Chains contributing no templates are dropped entirely
        if(n_templates != 0):
            out_dicts.append(out_dict)

    # NOTE(review): if every chain has zero templates, out_dicts is empty and
    # out_dicts[0] raises IndexError — callers appear to guarantee at least
    # one template; confirm before using with all-empty inputs.
    out_dict = {
        k: np.concatenate([od[k] for od in out_dicts]) for k in out_dicts[0]
    }

    return out_dict
108
+
109
+
110
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    one_hot = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    return {
        "aatype": one_hot,
        "between_segment_residues": np.zeros((num_res,), dtype=np.int32),
        "domain_name": np.array(
            [description.encode("utf-8")], dtype=np.object_
        ),
        "residue_index": np.arange(num_res, dtype=np.int32),
        "seq_length": np.full((num_res,), num_res, dtype=np.int32),
        "sequence": np.array([sequence.encode("utf-8")], dtype=np.object_),
    }
130
+
131
+
132
def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Assemble sequence + structure features for one chain of an mmCIF."""
    seqres = mmcif_object.chain_to_seqres[chain_id]
    desc = "_".join([mmcif_object.file_id, chain_id])

    feats = dict(
        make_sequence_features(
            sequence=seqres,
            description=desc,
            num_res=len(seqres),
        )
    )

    # Atom coordinates and per-atom presence mask for this chain
    coords, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    feats["all_atom_positions"] = coords
    feats["all_atom_mask"] = mask

    feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
    )
    # Experimental structures are never distillation data
    feats["is_distillation"] = np.array(0., dtype=np.float32)

    return feats
166
+
167
+
168
def _aatype_to_str_sequence(aatype):
    """Decode an integer aatype array into a one-letter residue string."""
    return ''.join(residue_constants.restypes_with_x[idx] for idx in aatype)
173
+
174
+
175
def make_protein_features(
    protein_object: protein.Protein,
    description: str,
    _is_distillation: bool = False,
) -> FeatureDict:
    """Sequence + atom features for a parsed protein.Protein object.

    Args:
        protein_object: Parsed protein structure.
        description: Name recorded in the "domain_name" feature.
        _is_distillation: Value written to the "is_distillation" feature.
    """
    aatype = protein_object.aatype
    feats = dict(
        make_sequence_features(
            sequence=_aatype_to_str_sequence(aatype),
            description=description,
            num_res=len(aatype),
        )
    )

    feats["all_atom_positions"] = (
        protein_object.atom_positions.astype(np.float32)
    )
    feats["all_atom_mask"] = protein_object.atom_mask.astype(np.float32)

    # PDB/ProteinNet entries carry no resolution info here
    feats["resolution"] = np.array([0.]).astype(np.float32)
    feats["is_distillation"] = np.array(
        1. if _is_distillation else 0.
    ).astype(np.float32)

    return feats
203
+
204
+
205
def make_pdb_features(
    protein_object: protein.Protein,
    description: str,
    is_distillation: bool = True,
    confidence_threshold: float = 50.,
) -> FeatureDict:
    """Features for a protein from a PDB file (distillation set by default).

    Args:
        protein_object: Parsed protein structure.
        description: Name recorded in the "domain_name" feature.
        is_distillation: Whether this entry belongs to the distillation set.
            Controls both the "is_distillation" feature and confidence
            masking below.
        confidence_threshold: Minimum per-atom B-factor (pLDDT for predicted
            structures) for a residue to be kept when is_distillation=True.
    """
    # BUG FIX: previously _is_distillation was hard-coded to True, so the
    # "is_distillation" feature ignored the is_distillation argument.
    pdb_feats = make_protein_features(
        protein_object, description, _is_distillation=is_distillation
    )

    if is_distillation:
        # Mask out residues in which no atom clears the confidence threshold
        high_confidence = protein_object.b_factors > confidence_threshold
        high_confidence = np.any(high_confidence, axis=-1)
        pdb_feats["all_atom_mask"] *= high_confidence[..., None]

    return pdb_feats
221
+
222
+
223
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
    """Constructs a feature dict of MSA features.

    Duplicate sequences (across all MSAs) are kept only the first time they
    appear; each kept row is integer-encoded with the HHblits alphabet.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    seen = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_index} must contain at least one sequence."
            )
        for sequence_index, sequence in enumerate(msa):
            if sequence in seen:
                continue
            seen.add(sequence)
            int_msa.append(
                [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]
            )
            deletion_matrix.append(
                deletion_matrices[msa_index][sequence_index]
            )

    num_res = len(msas[0][0])
    num_alignments = len(int_msa)
    return {
        "deletion_matrix_int": np.array(deletion_matrix, dtype=np.int32),
        "msa": np.array(int_msa, dtype=np.int32),
        "num_alignments": np.array(
            [num_alignments] * num_res, dtype=np.int32
        ),
    }
257
+
258
+
259
class AlignmentRunner:
    """Runs alignment tools and saves the results"""
    def __init__(
        self,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        hhsearch_binary_path: Optional[str] = None,
        uniref90_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        pdb70_database_path: Optional[str] = None,
        use_small_bfd: Optional[bool] = None,
        no_cpus: Optional[int] = None,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        """
        Args:
            jackhmmer_binary_path:
                Path to jackhmmer binary
            hhblits_binary_path:
                Path to hhblits binary
            hhsearch_binary_path:
                Path to hhsearch binary
            uniref90_database_path:
                Path to uniref90 database. If provided, jackhmmer_binary_path
                must also be provided
            mgnify_database_path:
                Path to mgnify database. If provided, jackhmmer_binary_path
                must also be provided
            bfd_database_path:
                Path to BFD database. Depending on the value of use_small_bfd,
                one of hhblits_binary_path or jackhmmer_binary_path must be
                provided.
            uniclust30_database_path:
                Path to uniclust30. Searched alongside BFD if use_small_bfd is
                false.
            pdb70_database_path:
                Path to pdb70 database.
            use_small_bfd:
                Whether to search the BFD database alone with jackhmmer or
                in conjunction with uniclust30 with hhblits.
            no_cpus:
                The number of CPUs available for alignment. By default, all
                CPUs are used.
            uniref_max_hits:
                Max number of uniref hits
            mgnify_max_hits:
                Max number of mgnify hits
        """
        # Pair each tool binary with the databases it would search, so that
        # a database given without its binary can be rejected up front.
        db_map = {
            "jackhmmer": {
                "binary": jackhmmer_binary_path,
                "dbs": [
                    uniref90_database_path,
                    mgnify_database_path,
                    bfd_database_path if use_small_bfd else None,
                ],
            },
            "hhblits": {
                "binary": hhblits_binary_path,
                "dbs": [
                    bfd_database_path if not use_small_bfd else None,
                ],
            },
            "hhsearch": {
                "binary": hhsearch_binary_path,
                "dbs": [
                    pdb70_database_path,
                ],
            },
        }

        for name, dic in db_map.items():
            binary, dbs = dic["binary"], dic["dbs"]
            if(binary is None and not all([x is None for x in dbs])):
                raise ValueError(
                    f"{name} DBs provided but {name} binary is None"
                )

        # Template search (hhsearch) consumes the uniref90 MSA as its query,
        # so it cannot run without the uniref90 database (see run()).
        if(not all([x is None for x in db_map["hhsearch"]["dbs"]])
            and uniref90_database_path is None):
            raise ValueError(
                """uniref90_database_path must be specified in order to perform
                template search"""
            )

        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits
        self.use_small_bfd = use_small_bfd

        # Default to all available CPUs
        if(no_cpus is None):
            no_cpus = cpu_count()

        # Each runner below is only constructed when its binary/database
        # pair was supplied; run() checks for None before using them.
        self.jackhmmer_uniref90_runner = None
        if(jackhmmer_binary_path is not None and
            uniref90_database_path is not None
        ):
            self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=uniref90_database_path,
                n_cpu=no_cpus,
            )

        self.jackhmmer_small_bfd_runner = None
        self.hhblits_bfd_uniclust_runner = None
        if(bfd_database_path is not None):
            if use_small_bfd:
                # Small BFD is searched with jackhmmer alone
                self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=bfd_database_path,
                    n_cpu=no_cpus,
                )
            else:
                # Full BFD is searched with hhblits, together with
                # uniclust30 when available
                dbs = [bfd_database_path]
                if(uniclust30_database_path is not None):
                    dbs.append(uniclust30_database_path)
                self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=dbs,
                    n_cpu=no_cpus,
                )

        self.jackhmmer_mgnify_runner = None
        if(mgnify_database_path is not None):
            self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=mgnify_database_path,
                n_cpu=no_cpus,
            )

        self.hhsearch_pdb70_runner = None
        if(pdb70_database_path is not None):
            self.hhsearch_pdb70_runner = hhsearch.HHSearch(
                binary_path=hhsearch_binary_path,
                databases=[pdb70_database_path],
                n_cpu=no_cpus,
            )

    def run(
        self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence

        Writes each tool's output to a conventionally-named file in
        output_dir (uniref90_hits.a3m, pdb70_hits.hhr, mgnify_hits.a3m,
        small_bfd_hits.sto / bfd_uniclust_hits.a3m). Tools whose runner was
        not configured are skipped.
        """
        if(self.jackhmmer_uniref90_runner is not None):
            jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
                fasta_path
            )[0]
            # Convert the Stockholm output to A3M, capped at uniref_max_hits
            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
                jackhmmer_uniref90_result["sto"],
                max_sequences=self.uniref_max_hits
            )
            uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
            with open(uniref90_out_path, "w") as f:
                f.write(uniref90_msa_as_a3m)

            # Template search is nested here because hhsearch queries the
            # uniref90 MSA; it is skipped when the uniref90 search is.
            if(self.hhsearch_pdb70_runner is not None):
                hhsearch_result = self.hhsearch_pdb70_runner.query(
                    uniref90_msa_as_a3m
                )
                pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
                with open(pdb70_out_path, "w") as f:
                    f.write(hhsearch_result)

        if(self.jackhmmer_mgnify_runner is not None):
            jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
                fasta_path
            )[0]
            mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
                jackhmmer_mgnify_result["sto"],
                max_sequences=self.mgnify_max_hits
            )
            mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
            with open(mgnify_out_path, "w") as f:
                f.write(mgnify_msa_as_a3m)

        # BFD: exactly one of the two search strategies runs
        if(self.use_small_bfd and self.jackhmmer_small_bfd_runner is not None):
            jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
                fasta_path
            )[0]
            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
            with open(bfd_out_path, "w") as f:
                f.write(jackhmmer_small_bfd_result["sto"])
        elif(self.hhblits_bfd_uniclust_runner is not None):
            hhblits_bfd_uniclust_result = (
                self.hhblits_bfd_uniclust_runner.query(fasta_path)
            )
            if output_dir is not None:
                bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
                with open(bfd_out_path, "w") as f:
                    f.write(hhblits_bfd_uniclust_result["a3m"])
452
+
453
+
454
class DataPipeline:
    """Assembles input features."""
    def __init__(
        self,
        template_featurizer: Optional[templates.TemplateHitFeaturizer],
    ):
        # May be None, in which case empty template features are produced
        self.template_featurizer = template_featurizer

    def _parse_msa_data(
        self,
        alignment_dir: str,
        alignment_index: Optional[Any] = None,
    ) -> Mapping[str, Any]:
        """Parse all .a3m/.sto alignments for one chain.

        With an alignment_index, reads (name, start, size) slices out of a
        single packed database file; otherwise scans alignment_dir for loose
        alignment files. Returns {file name: {"msa", "deletion_matrix"}}.
        """
        msa_data = {}
        if(alignment_index is not None):
            # Packed-database mode: one file, many byte ranges
            fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")

            def read_msa(start, size):
                fp.seek(start)
                msa = fp.read(size).decode("utf-8")
                return msa

            for (name, start, size) in alignment_index["files"]:
                ext = os.path.splitext(name)[-1]

                if(ext == ".a3m"):
                    msa, deletion_matrix = parsers.parse_a3m(
                        read_msa(start, size)
                    )
                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
                elif(ext == ".sto"):
                    msa, deletion_matrix, _ = parsers.parse_stockholm(
                        read_msa(start, size)
                    )
                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
                else:
                    # Non-alignment entries (e.g. .hhr) are skipped here
                    continue

                msa_data[name] = data

            fp.close()
        else:
            # Loose-file mode: parse every alignment file in the directory
            for f in os.listdir(alignment_dir):
                path = os.path.join(alignment_dir, f)
                ext = os.path.splitext(f)[-1]

                if(ext == ".a3m"):
                    with open(path, "r") as fp:
                        msa, deletion_matrix = parsers.parse_a3m(fp.read())
                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
                elif(ext == ".sto"):
                    with open(path, "r") as fp:
                        msa, deletion_matrix, _ = parsers.parse_stockholm(
                            fp.read()
                        )
                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
                else:
                    continue

                msa_data[f] = data

        return msa_data

    def _parse_template_hits(
        self,
        alignment_dir: str,
        alignment_index: Optional[Any] = None
    ) -> Mapping[str, Any]:
        """Parse all .hhr template-hit files, keyed by file name.

        Mirrors _parse_msa_data's two modes (packed index vs. loose files).
        """
        all_hits = {}
        if(alignment_index is not None):
            fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')

            def read_template(start, size):
                fp.seek(start)
                return fp.read(size).decode("utf-8")

            for (name, start, size) in alignment_index["files"]:
                ext = os.path.splitext(name)[-1]

                if(ext == ".hhr"):
                    hits = parsers.parse_hhr(read_template(start, size))
                    all_hits[name] = hits

            fp.close()
        else:
            for f in os.listdir(alignment_dir):
                path = os.path.join(alignment_dir, f)
                ext = os.path.splitext(f)[-1]

                if(ext == ".hhr"):
                    with open(path, "r") as fp:
                        hits = parsers.parse_hhr(fp.read())
                    all_hits[f] = hits

        return all_hits

    def _get_msas(self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        alignment_index: Optional[str] = None,
    ):
        """Return (msas, deletion_matrices) tuples parsed from alignment_dir.

        Falls back to a single-sequence "MSA" built from input_sequence when
        the directory contains no alignments.
        """
        msa_data = self._parse_msa_data(alignment_dir, alignment_index)
        if(len(msa_data) == 0):
            if(input_sequence is None):
                raise ValueError(
                    """
                    If the alignment dir contains no MSAs, an input sequence
                    must be provided.
                    """
                )
            # Degenerate MSA: the query alone, with zero deletions
            msa_data["dummy"] = {
                "msa": [input_sequence],
                "deletion_matrix": [[0 for _ in input_sequence]],
            }

        msas, deletion_matrices = zip(*[
            (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
        ])

        return msas, deletion_matrices

    def _process_msa_feats(
        self,
        alignment_dir: str,
        input_sequence: Optional[str] = None,
        alignment_index: Optional[str] = None
    ) -> Mapping[str, Any]:
        """Parse alignments and convert them into MSA feature arrays."""
        msas, deletion_matrices = self._get_msas(
            alignment_dir, input_sequence, alignment_index
        )
        msa_features = make_msa_features(
            msas=msas,
            deletion_matrices=deletion_matrices,
        )

        return msa_features

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """Assembles features for a single sequence in a FASTA file"""
        with open(fasta_path) as f:
            fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f"More than one input sequence found in {fasta_path}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        hits = self._parse_template_hits(alignment_dir, alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

        return {
            **sequence_features,
            **msa_features,
            **template_features
        }

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """
        Assembles features for a specific chain in an mmCIF object.

        If chain_id is None, it is assumed that there is only one chain
        in the object. Otherwise, a ValueError is thrown.
        """
        if chain_id is None:
            # Take the first chain present in the structure
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)

        input_sequence = mmcif.chain_to_seqres[chain_id]
        hits = self._parse_template_hits(alignment_dir, alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
            # Release date lets the featurizer exclude post-dated templates
            query_release_date=to_date(mmcif.header["release_date"])
        )

        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

        return {**mmcif_feats, **template_features, **msa_features}

    def process_pdb(
        self,
        pdb_path: str,
        alignment_dir: str,
        is_distillation: bool = True,
        chain_id: Optional[str] = None,
        _structure_index: Optional[str] = None,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """
        Assembles features for a protein in a PDB file.
        """
        if(_structure_index is not None):
            # Packed-database mode: slice the PDB text out of a shared file
            db_dir = os.path.dirname(pdb_path)
            db = _structure_index["db"]
            db_path = os.path.join(db_dir, db)
            fp = open(db_path, "rb")
            _, offset, length = _structure_index["files"][0]
            fp.seek(offset)
            pdb_str = fp.read(length).decode("utf-8")
            fp.close()
        else:
            with open(pdb_path, 'r') as f:
                pdb_str = f.read()

        protein_object = protein.from_pdb_string(pdb_str, chain_id)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)
        # File stem (upper-cased) serves as the domain name
        description = os.path.splitext(os.path.basename(pdb_path))[0].upper()
        pdb_feats = make_pdb_features(
            protein_object,
            description,
            is_distillation=is_distillation
        )

        hits = self._parse_template_hits(alignment_dir, alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

        return {**pdb_feats, **template_features, **msa_features}

    def process_core(
        self,
        core_path: str,
        alignment_dir: str,
        alignment_index: Optional[str] = None,
    ) -> FeatureDict:
        """
        Assembles features for a protein in a ProteinNet .core file.
        """
        with open(core_path, 'r') as f:
            core_str = f.read()

        protein_object = protein.from_proteinnet_string(core_str)
        input_sequence = _aatype_to_str_sequence(protein_object.aatype)
        description = os.path.splitext(os.path.basename(core_path))[0].upper()
        core_feats = make_protein_features(protein_object, description)

        hits = self._parse_template_hits(alignment_dir, alignment_index)
        template_features = make_template_features(
            input_sequence,
            hits,
            self.template_featurizer,
        )

        # NOTE(review): unlike the other process_* methods, alignment_index
        # is not forwarded here — presumably intentional for .core inputs;
        # confirm.
        msa_features = self._process_msa_feats(alignment_dir, input_sequence)

        return {**core_feats, **template_features, **msa_features}

    def process_multiseq_fasta(self,
        fasta_path: str,
        super_alignment_dir: str,
        ri_gap: int = 200,
    ) -> FeatureDict:
        """
        Assembles features for a multi-sequence FASTA. Uses Minkyung Baek's
        hack from Twitter (a.k.a. AlphaFold-Gap).
        """
        with open(fasta_path, 'r') as f:
            fasta_str = f.read()

        input_seqs, input_descs = parsers.parse_fasta(fasta_str)

        # No whitespace allowed
        input_descs = [i.split()[0] for i in input_descs]

        # Stitch all of the sequences together
        input_sequence = ''.join(input_seqs)
        input_description = '-'.join(input_descs)
        num_res = len(input_sequence)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        # Insert a residue-index gap of ri_gap after each chain so the model
        # treats the chains as discontinuous
        seq_lens = [len(s) for s in input_seqs]
        total_offset = 0
        for sl in seq_lens:
            total_offset += sl
            sequence_features["residue_index"][total_offset:] += ri_gap

        # Per-chain alignments live in super_alignment_dir/<description>
        msa_list = []
        deletion_mat_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            msas, deletion_mats = self._get_msas(
                alignment_dir, seq, None
            )
            msa_list.append(msas)
            deletion_mat_list.append(deletion_mats)

        # Pad each chain's MSA rows with gaps so every row spans the full
        # stitched sequence
        final_msa = []
        final_deletion_mat = []
        msa_it = enumerate(zip(msa_list, deletion_mat_list))
        for i, (msas, deletion_mats) in msa_it:
            prec, post = sum(seq_lens[:i]), sum(seq_lens[i + 1:])
            msas = [
                [prec * '-' + seq + post * '-' for seq in msa] for msa in msas
            ]
            deletion_mats = [
                [prec * [0] + dml + post * [0] for dml in deletion_mat]
                for deletion_mat in deletion_mats
            ]

            assert(len(msas[0][-1]) == len(input_sequence))

            final_msa.extend(msas)
            final_deletion_mat.extend(deletion_mats)

        msa_features = make_msa_features(
            msas=final_msa,
            deletion_matrices=final_deletion_mat,
        )

        # Template features are computed per chain, then merged
        template_feature_list = []
        for seq, desc in zip(input_seqs, input_descs):
            alignment_dir = os.path.join(
                super_alignment_dir, desc
            )
            hits = self._parse_template_hits(alignment_dir, alignment_index=None)
            template_features = make_template_features(
                seq,
                hits,
                self.template_featurizer,
            )
            template_feature_list.append(template_features)

        template_features = unify_template_features(template_feature_list)

        return {
            **sequence_features,
            **msa_features,
            **template_features,
        }
openfold/data/data_transforms.py ADDED
@@ -0,0 +1,1212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import itertools
17
+ from functools import reduce, wraps
18
+ from operator import add
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
24
+ from openfold.np import residue_constants as rc
25
+ from openfold.utils.rigid_utils import Rotation, Rigid
26
+ from openfold.utils.tensor_utils import (
27
+ tree_map,
28
+ tensor_tree_map,
29
+ batched_gather,
30
+ )
31
+
32
+
33
+ MSA_FEATURE_NAMES = [
34
+ "msa",
35
+ "deletion_matrix",
36
+ "msa_mask",
37
+ "msa_row_mask",
38
+ "bert_mask",
39
+ "true_msa",
40
+ ]
41
+
42
+
43
def cast_to_64bit_ints(protein):
    """Promote every int32 tensor in the feature dict to int64, in place."""
    # We keep all ints as int64
    for key, tensor in protein.items():
        if tensor.dtype == torch.int32:
            protein[key] = tensor.type(torch.int64)
    return protein
50
+
51
+
52
def make_one_hot(x, num_classes):
    """One-hot encode integer tensor `x` into a trailing float `num_classes` axis."""
    out_shape = tuple(x.shape) + (num_classes,)
    x_one_hot = torch.zeros(out_shape, device=x.device)
    x_one_hot.scatter_(-1, x[..., None], 1)
    return x_one_hot
56
+
57
+
58
def make_seq_mask(protein):
    """Attach an all-ones float residue mask (zero-padded later by make_fixed_size)."""
    residue_shape = protein["aatype"].shape
    protein["seq_mask"] = torch.ones(residue_shape, dtype=torch.float32)
    return protein
63
+
64
+
65
def make_template_mask(protein):
    """Mark every template row as present with a ones mask."""
    n_templ = protein["template_aatype"].shape[0]
    protein["template_mask"] = torch.ones(n_templ, dtype=torch.float32)
    return protein
70
+
71
+
72
def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(a, b)`` returns a one-argument callable ``x -> f(x, a, b)``,
    which is how the transform pipeline pre-binds configuration.
    """
    @wraps(f)
    def fc(*args, **kwargs):
        def partially_applied(x):
            return f(x, *args, **kwargs)
        return partially_applied

    return fc
79
+
80
+
81
def make_all_atom_aatype(protein):
    """Alias `aatype` under the key expected by the atom-level transforms."""
    aatype = protein["aatype"]
    protein["all_atom_aatype"] = aatype
    return protein
84
+
85
+
86
def fix_templates_aatype(protein):
    """Convert one-hot template aatypes to indices, remapped from HHblits
    residue ordering to this project's ordering."""
    num_templates = protein["template_aatype"].shape[0]
    if num_templates > 0:
        # One-hot -> index along the residue-type axis.
        template_indices = torch.argmax(protein["template_aatype"], dim=-1)
        protein["template_aatype"] = template_indices
        # Lookup table from hhblits ordering to ours, one row per template.
        mapping = torch.tensor(
            rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE,
            dtype=torch.int64,
            device=protein["aatype"].device,
        ).expand(num_templates, -1)
        protein["template_aatype"] = torch.gather(
            mapping, 1, index=protein["template_aatype"]
        )

    return protein
103
+
104
+
105
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Remaps every MSA index from the HHblits residue ordering to this
    project's ordering, and permutes the columns of any "*profile*"
    feature accordingly.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    # One mapping column per residue position; gather along dim 0 so that
    # result[i][j] = new_order[msa[i][j]][j].
    new_order = torch.tensor(
        [new_order_list] * protein["msa"].shape[1],
        device=protein["msa"].device,
    ).transpose(0, 1)
    protein["msa"] = torch.gather(new_order, 0, protein["msa"])

    # Permutation matrix applying the same reordering to profile columns.
    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            # BUG FIX: the original used `protein[k].shape.as_list()` (a
            # TensorFlow API; torch.Size has no .as_list()) and `torch.dot`,
            # which only accepts 1-D tensors and was handed a NumPy matrix.
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            protein[k] = torch.matmul(
                protein[k],
                torch.tensor(
                    perm_matrix[:num_dim, :num_dim],
                    dtype=protein[k].dtype,
                    device=protein[k].device,
                ),
            )

    return protein
128
+
129
+
130
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features."""
    # One-hot aatype -> index representation.
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for key in squeezable:
        if key not in protein:
            continue
        value = protein[key]
        final_dim = value.shape[-1]
        if isinstance(final_dim, int) and final_dim == 1:
            # Features may arrive either as torch tensors or numpy arrays.
            if torch.is_tensor(value):
                protein[key] = value.squeeze(dim=-1)
            else:
                protein[key] = np.squeeze(value, axis=-1)

    # These are stored once per residue but are scalar per example.
    for key in ("seq_length", "num_alignments"):
        if key in protein:
            protein[key] = protein[key][0]

    return protein
159
+
160
+
161
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'."""
    X_IDX = 20
    GAP_IDX = 21
    # Draw replacement positions for the MSA, but never overwrite gaps.
    replace_msa = torch.logical_and(
        torch.rand(protein["msa"].shape) < replace_proportion,
        protein["msa"] != GAP_IDX,
    )
    protein["msa"] = torch.where(
        replace_msa,
        torch.full_like(protein["msa"], X_IDX),
        protein["msa"],
    )

    # Independently replace positions in the target sequence.
    replace_seq = torch.rand(protein["aatype"].shape) < replace_proportion
    protein["aatype"] = torch.where(
        replace_seq,
        torch.full_like(protein["aatype"], X_IDX),
        protein["aatype"],
    )
    return protein
181
+
182
+
183
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample MSA randomly; the remaining sequences are stored as `extra_*`."""
    num_seq = protein["msa"].shape[0]
    g = torch.Generator(device=protein["msa"].device)
    if seed is not None:
        g.manual_seed(seed)
    # Row 0 (the target sequence) always stays first; shuffle the rest.
    shuffled_rest = torch.randperm(num_seq - 1, generator=g) + 1
    index_order = torch.cat(
        (torch.tensor([0], device=shuffled_rest.device), shuffled_rest),
        dim=0,
    )
    num_sel = min(max_seq, num_seq)
    sel_seq = index_order[:num_sel]
    not_sel_seq = index_order[num_sel:]

    for k in MSA_FEATURE_NAMES:
        if k not in protein:
            continue
        if keep_extra:
            protein["extra_" + k] = torch.index_select(
                protein[k], 0, not_sel_seq
            )
        protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein
209
+
210
+
211
@curry1
def add_distillation_flag(protein, distillation):
    """Record whether this example comes from the distillation set."""
    protein['is_distillation'] = distillation
    return protein
215
+
216
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA (dropping extras) only for distillation examples."""
    if protein["is_distillation"] == 1:
        return sample_msa(max_seq, keep_extra=False)(protein)
    return protein
221
+
222
+
223
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Keep at most `max_extra_msa` randomly-chosen extra MSA rows."""
    num_seq = protein["extra_msa"].shape[0]
    keep = torch.randperm(num_seq)[: min(max_extra_msa, num_seq)]
    for k in MSA_FEATURE_NAMES:
        extra_key = "extra_" + k
        if extra_key in protein:
            protein[extra_key] = torch.index_select(
                protein[extra_key], 0, keep
            )

    return protein
235
+
236
+
237
def delete_extra_msa(protein):
    """Drop all `extra_*` MSA features from the example."""
    for k in MSA_FEATURE_NAMES:
        protein.pop("extra_" + k, None)
    return protein
242
+
243
+
244
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Delete random contiguous blocks of MSA rows (training augmentation).

    The original body relied on TensorFlow-style calls that are invalid in
    PyTorch (`torch.range(block_num_seq)` with the wrong signature,
    `Uniform(...).sample(nb)` with an int sample shape, float block starts
    used as indices, `torch.unique(torch.sort(...))` applied to a tuple, and
    `torch.gather` without a `dim`), so it crashed whenever it executed.
    This is a working port with the same intent: delete `num_blocks` random
    blocks of `msa_fraction_per_block * num_seq` consecutive sequences each,
    while always keeping the target sequence (row 0).
    """
    num_seq = protein["msa"].shape[0]
    block_num_seq = int(num_seq * config.msa_fraction_per_block)

    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)).item())
    else:
        nb = config.num_blocks

    # Nothing to delete: degenerate configs or a single-row MSA.
    if nb == 0 or block_num_seq == 0 or num_seq <= 1:
        return protein

    del_block_starts = torch.randint(0, num_seq, (nb,))
    del_blocks = del_block_starts[:, None] + torch.arange(block_num_seq)
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1).reshape(-1)

    keep_mask = torch.ones(num_seq, dtype=torch.bool)
    keep_mask[del_blocks] = False
    keep_mask[0] = True  # never delete the target sequence

    keep_indices = torch.nonzero(keep_mask, as_tuple=True)[0]
    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
277
+
278
+
279
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra MSA sequence to its closest sampled MSA row."""
    device = protein["msa"].device
    # Per-class weights for the agreement score: amino acids count fully,
    # gaps with `gap_agreement_weight`, the mask class not at all.
    weights = torch.cat(
        [
            torch.ones(21, device=device),
            gap_agreement_weight * torch.ones(1, device=device),
            torch.zeros(1, device=device),
        ],
        0,
    )

    # Agreement score is a weighted Hamming distance over masked one-hots.
    msa_one_hot = make_one_hot(protein["msa"], 23)
    sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
    extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
    extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # Equivalent to einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot,
    # weights), computed as one flattened matmul to avoid a memory or
    # computation blowup.
    flat_extra = torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23])
    flat_weighted_sample = torch.reshape(
        sample_one_hot * weights, [num_seq, num_res * 23]
    )
    agreement = torch.matmul(
        flat_extra, flat_weighted_sample.transpose(0, 1)
    )

    # Closest cluster center for every extra sequence.
    protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
        torch.int64
    )

    return protein
314
+
315
+
316
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor.
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument.
    """
    assert (
        segment_ids.dim() == 1 and segment_ids.shape[0] == data.shape[0]
    )
    # Broadcast the 1-D ids over all trailing dims of `data`.
    trailing = (1,) * (data.dim() - 1)
    expanded_ids = segment_ids.view(segment_ids.shape[0], *trailing)
    expanded_ids = expanded_ids.expand(data.shape)

    out = torch.zeros(
        num_segments, *data.shape[1:], device=segment_ids.device
    ).scatter_add_(0, expanded_ids, data.float())
    return out.type(data.dtype)
341
+
342
+
343
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster."""
    num_seq = protein["msa"].shape[0]

    def cluster_sum(x):
        # Sum extra-MSA rows into their assigned cluster centers.
        return unsorted_segment_sum(
            x, protein["extra_cluster_assignment"], num_seq
        )

    mask = protein["extra_msa_mask"]
    # The +msa_mask term counts the cluster center itself; the epsilon
    # guards the division below.
    mask_counts = 1e-6 + protein["msa_mask"] + cluster_sum(mask)

    msa_sum = cluster_sum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
    msa_sum = msa_sum + make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]

    del_sum = cluster_sum(mask * protein["extra_deletion_matrix"])
    del_sum = del_sum + protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = del_sum / mask_counts

    return protein
367
+
368
+
369
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded."""
    msa_shape = protein["msa"].shape
    protein["msa_mask"] = torch.ones(msa_shape, dtype=torch.float32)
    protein["msa_row_mask"] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein
376
+
377
+
378
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features: CB positions, CA for glycine."""
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    is_gly = torch.eq(aatype, rc.restype_order["G"])

    # Broadcast the glycine flag over the xyz axis.
    gly_tile = torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3])
    pseudo_beta = torch.where(
        gly_tile,
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask
396
+
397
+
398
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ["", "template_"]
    if prefix:
        aatype_key = "template_aatype"
        mask_key = "template_all_atom_mask"
    else:
        aatype_key = "aatype"
        mask_key = "all_atom_mask"

    beta, beta_mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = beta
    protein[prefix + "pseudo_beta_mask"] = beta_mask
    return protein
411
+
412
+
413
@curry1
def add_constant_field(protein, key, value):
    """Store `value` as a constant tensor feature under `key`."""
    device = protein["msa"].device
    protein[key] = torch.tensor(value, device=device)
    return protein
417
+
418
+
419
def shaped_categorical(probs, epsilon=1e-10):
    """Sample class indices from `probs` along the last axis, keeping the
    leading shape. `epsilon` keeps zero-probability rows valid."""
    num_classes = probs.shape[-1]
    flat = torch.reshape(probs + epsilon, [-1, num_classes])
    samples = torch.distributions.categorical.Categorical(flat).sample()
    return torch.reshape(samples, probs.shape[:-1])
427
+
428
+
429
def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present."""
    if "hhblits_profile" not in protein:
        # Mean one-hot over all MSA rows -> per-residue residue-type profile.
        protein["hhblits_profile"] = make_one_hot(protein["msa"], 22).mean(dim=0)
    return protein
439
+
440
+
441
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Randomly masks `replace_fraction` of MSA positions; each masked position
    is replaced by a sample from a mixture of a uniform amino acid, the
    HHblits profile, the original residue, and a dedicated [MASK] class.
    Stores the mask as `bert_mask`, the untouched MSA as `true_msa`, and the
    corrupted MSA back under `msa`.
    """
    # Add a random amino acid uniformly (zero weight on gap/X classes).
    random_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0],
        dtype=torch.float32,
        device=protein["aatype"].device
    )

    # Mixture of uniform / profile / identity replacement distributions.
    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(protein["msa"], 22)
    )

    # Put all remaining probability on [MASK] which is a new column.
    # F.pad takes (left, right) pairs starting from the LAST dim, so
    # pad_shapes[1] = 1 appends one class column on the right.
    pad_shapes = list(
        reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
    )
    pad_shapes[1] = 1
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0

    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_shapes, value=mask_prob
    )

    # Choose which positions to corrupt.
    sh = protein["msa"].shape
    mask_position = torch.rand(sh) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein["msa"])

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = protein["msa"]
    protein["msa"] = bert_msa

    return protein
483
+
484
+
485
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size.

    Zero-pads every feature so that each symbolic dimension in its
    `shape_schema` entry reaches the corresponding fixed size.
    """
    # Map each symbolic dimension name to its requested padded size.
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == "extra_cluster_assignment":
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
        # Dimensions with no schema entry (or a requested size of 0) keep
        # their current size (`or s1` falls through for None/0).
        pad_size = [
            pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
        ]

        # F.pad expects (left, right) pairs ordered from the LAST dimension,
        # hence the reverse before flattening.
        padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
        padding.reverse()
        padding = list(itertools.chain(*padding))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)

    return protein
522
+
523
+
524
@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    Produces `target_feat` (break flag + one-hot aatype) and `msa_feat`
    (one-hot MSA + deletion features, plus cluster statistics when present),
    and derives `extra_has_deletion` / `extra_deletion_value` for the extra
    MSA when available.
    """
    # Whether there is a domain break. Always zero for chains, but keeping for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    aatype_1hot = make_one_hot(protein["aatype"], 21)

    target_feat = [
        torch.unsqueeze(has_break, dim=-1),
        aatype_1hot,  # Everyone gets the original sequence.
    ]

    msa_1hot = make_one_hot(protein["msa"], 23)
    has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
    # Squash unbounded deletion counts into (0, 1) via arctan.
    deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
        2.0 / np.pi
    )

    msa_feat = [
        msa_1hot,
        torch.unsqueeze(has_deletion, dim=-1),
        torch.unsqueeze(deletion_value, dim=-1),
    ]

    # Cluster statistics exist only when nearest_neighbor_clusters /
    # summarize_clusters ran earlier in the pipeline.
    if "cluster_profile" in protein:
        deletion_mean_value = torch.atan(
            protein["cluster_deletion_mean"] / 3.0
        ) * (2.0 / np.pi)
        msa_feat.extend(
            [
                protein["cluster_profile"],
                torch.unsqueeze(deletion_mean_value, dim=-1),
            ]
        )

    if "extra_deletion_matrix" in protein:
        protein["extra_has_deletion"] = torch.clip(
            protein["extra_deletion_matrix"], 0.0, 1.0
        )
        protein["extra_deletion_value"] = torch.atan(
            protein["extra_deletion_matrix"] / 3.0
        ) * (2.0 / np.pi)

    protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
    protein["target_feat"] = torch.cat(target_feat, dim=-1)
    return protein
573
+
574
+
575
@curry1
def select_feat(protein, feature_list):
    """Keep only the features named in `feature_list`."""
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
578
+
579
+
580
@curry1
def crop_templates(protein, max_templates):
    """Truncate every template feature to its first `max_templates` entries."""
    template_keys = [k for k in protein if k.startswith("template_")]
    for k in template_keys:
        protein[k] = protein[k][:max_templates]
    return protein
586
+
587
+
588
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-restype lookup tables between the compact atom14 layout and
    the full atom37 layout, then gathers them for this protein's sequence.

    Adds to `protein`:
        atom14_atom_exists:      [num_res, 14] atom-present mask per residue
        residx_atom14_to_atom37: [num_res, 14] atom37 index of each atom14 slot
        residx_atom37_to_atom14: [num_res, 37] atom14 index of each atom37 slot
        atom37_atom_exists:      [num_res, 37] atom-present mask per residue
    """
    # Per-restype tables, one entry per standard residue type.
    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
        # Empty atom14 slots ("" names) map to atom37 index 0.
        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        # atom37 slots the residue does not have map to atom14 index 0.
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in rc.atom_types
            ]
        )

        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask,
        dtype=torch.float32,
        device=protein["aatype"].device,
    )
    protein_aatype = protein['aatype'].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    # create the corresponding mask
    restype_atom37_mask = torch.zeros(
        [21, 37], dtype=torch.float32, device=protein["aatype"].device
    )
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein
660
+
661
+
662
def make_atom14_masks_np(batch):
    """NumPy wrapper around make_atom14_masks: ndarrays in, ndarrays out."""
    as_tensors = tree_map(
        lambda n: torch.tensor(n, device=batch["aatype"].device),
        batch,
        np.ndarray,
    )
    out = make_atom14_masks(as_tensors)
    return tensor_tree_map(lambda t: np.array(t), out)
671
+
672
+
673
def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Gathers ground-truth positions/masks into the atom14 layout and also
    produces "alternative" ground truth where chemically symmetric atom
    pairs (e.g. in ASP, GLU, PHE, TYR) are name-swapped, plus a mask of
    which atoms are ambiguous. Requires make_atom14_masks to have run.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms: identity for unambiguous
    # restypes, pair-swapping permutations for the ambiguous ones.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        correspondences = torch.arange(
            14, device=protein["all_atom_mask"].device
        )
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Turn the index correspondence into a 14x14 permutation matrix.
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein
774
+
775
+
776
def atom37_to_frames(protein, eps=1e-8):
    """Compute the 8 rigid-group frames per residue from atom37 coordinates.

    For every residue this builds the backbone frame, the psi frame, and up
    to four chi frames (8 groups total, unused slots masked out), plus
    "alternative" frames for residues with ambiguous atom naming.

    Adds to `protein`:
        rigidgroups_gt_frames:          [*, N_res, 8, 4, 4] frames as 4x4
        rigidgroups_gt_exists:          [*, N_res, 8] frame has all GT atoms
        rigidgroups_group_exists:       [*, N_res, 8] group defined for restype
        rigidgroups_group_is_ambiguous: [*, N_res, 8] naming-ambiguous flag
        rigidgroups_alt_gt_frames:      [*, N_res, 8, 4, 4] name-swapped frames
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Names of the 3 atoms defining each rigid group, per restype.
    # Group 0 is the backbone frame, group 3 the psi frame, 4-7 the chis.
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which of the 8 groups exist at all for each restype.
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Atom names -> atom37 indices ("" placeholders map to 0).
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    # Per-residue base-atom indices for this sequence.
    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    # Frames from the three base atoms of each group.
    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    # A frame's ground truth exists only if all three base atoms do.
    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Fix the backbone frame's handedness by flipping the x and z axes of
    # group 0 only.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    # Alternative frames: a 180-degree rotation of the last chi group of
    # restypes with ambiguous atom naming.
    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein
913
+
914
+
915
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A nested list of shape [residue_types=21][chis=4][atoms=4]. The
        residue types are in the order specified in rc.restypes + unknown
        residue type at the end. For chi angles which are not defined on the
        residue, the position indices default to 0.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        resname = rc.restype_1to3[one_letter]
        per_chi = [
            [rc.atom_order[atom] for atom in chi_atoms]
            for chi_atoms in rc.chi_angles_atoms[resname]
        ]
        # Pad residues with fewer than four chi angles with dummy indices.
        per_chi += [[0, 0, 0, 0]] * (4 - len(per_chi))
        chi_atom_indices.append(per_chi)

    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return chi_atom_indices
940
+
941
+
942
+ @curry1
943
+ def atom37_to_torsion_angles(
944
+ protein,
945
+ prefix="",
946
+ ):
947
+ """
948
+ Convert coordinates to torsion angles.
949
+
950
+ This function is extremely sensitive to floating point imprecisions
951
+ and should be run with double precision whenever possible.
952
+
953
+ Args:
954
+ Dict containing:
955
+ * (prefix)aatype:
956
+ [*, N_res] residue indices
957
+ * (prefix)all_atom_positions:
958
+ [*, N_res, 37, 3] atom positions (in atom37
959
+ format)
960
+ * (prefix)all_atom_mask:
961
+ [*, N_res, 37] atom position mask
962
+ Returns:
963
+ The same dictionary updated with the following features:
964
+
965
+ "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
966
+ Torsion angles
967
+ "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
968
+ Alternate torsion angles (accounting for 180-degree symmetry)
969
+ "(prefix)torsion_angles_mask" ([*, N_res, 7])
970
+ Torsion angles mask
971
+ """
972
+ aatype = protein[prefix + "aatype"]
973
+ all_atom_positions = protein[prefix + "all_atom_positions"]
974
+ all_atom_mask = protein[prefix + "all_atom_mask"]
975
+
976
+ aatype = torch.clamp(aatype, max=20)
977
+
978
+ pad = all_atom_positions.new_zeros(
979
+ [*all_atom_positions.shape[:-3], 1, 37, 3]
980
+ )
981
+ prev_all_atom_positions = torch.cat(
982
+ [pad, all_atom_positions[..., :-1, :, :]], dim=-3
983
+ )
984
+
985
+ pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
986
+ prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
987
+
988
+ pre_omega_atom_pos = torch.cat(
989
+ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
990
+ dim=-2,
991
+ )
992
+ phi_atom_pos = torch.cat(
993
+ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
994
+ dim=-2,
995
+ )
996
+ psi_atom_pos = torch.cat(
997
+ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
998
+ dim=-2,
999
+ )
1000
+
1001
+ pre_omega_mask = torch.prod(
1002
+ prev_all_atom_mask[..., 1:3], dim=-1
1003
+ ) * torch.prod(all_atom_mask[..., :2], dim=-1)
1004
+ phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
1005
+ all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
1006
+ )
1007
+ psi_mask = (
1008
+ torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
1009
+ * all_atom_mask[..., 4]
1010
+ )
1011
+
1012
+ chi_atom_indices = torch.as_tensor(
1013
+ get_chi_atom_indices(), device=aatype.device
1014
+ )
1015
+
1016
+ atom_indices = chi_atom_indices[..., aatype, :, :]
1017
+ chis_atom_pos = batched_gather(
1018
+ all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
1019
+ )
1020
+
1021
+ chi_angles_mask = list(rc.chi_angles_mask)
1022
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
1023
+ chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
1024
+
1025
+ chis_mask = chi_angles_mask[aatype, :]
1026
+
1027
+ chi_angle_atoms_mask = batched_gather(
1028
+ all_atom_mask,
1029
+ atom_indices,
1030
+ dim=-1,
1031
+ no_batch_dims=len(atom_indices.shape[:-2]),
1032
+ )
1033
+ chi_angle_atoms_mask = torch.prod(
1034
+ chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
1035
+ )
1036
+ chis_mask = chis_mask * chi_angle_atoms_mask
1037
+
1038
+ torsions_atom_pos = torch.cat(
1039
+ [
1040
+ pre_omega_atom_pos[..., None, :, :],
1041
+ phi_atom_pos[..., None, :, :],
1042
+ psi_atom_pos[..., None, :, :],
1043
+ chis_atom_pos,
1044
+ ],
1045
+ dim=-3,
1046
+ )
1047
+
1048
+ torsion_angles_mask = torch.cat(
1049
+ [
1050
+ pre_omega_mask[..., None],
1051
+ phi_mask[..., None],
1052
+ psi_mask[..., None],
1053
+ chis_mask,
1054
+ ],
1055
+ dim=-1,
1056
+ )
1057
+
1058
+ torsion_frames = Rigid.from_3_points(
1059
+ torsions_atom_pos[..., 1, :],
1060
+ torsions_atom_pos[..., 2, :],
1061
+ torsions_atom_pos[..., 0, :],
1062
+ eps=1e-8,
1063
+ )
1064
+
1065
+ fourth_atom_rel_pos = torsion_frames.invert().apply(
1066
+ torsions_atom_pos[..., 3, :]
1067
+ )
1068
+
1069
+ torsion_angles_sin_cos = torch.stack(
1070
+ [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
1071
+ )
1072
+
1073
+ denom = torch.sqrt(
1074
+ torch.sum(
1075
+ torch.square(torsion_angles_sin_cos),
1076
+ dim=-1,
1077
+ dtype=torsion_angles_sin_cos.dtype,
1078
+ keepdims=True,
1079
+ )
1080
+ + 1e-8
1081
+ )
1082
+ torsion_angles_sin_cos = torsion_angles_sin_cos / denom
1083
+
1084
+ torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
1085
+ [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
1086
+ )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
1087
+
1088
+ chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
1089
+ rc.chi_pi_periodic,
1090
+ )[aatype, ...]
1091
+
1092
+ mirror_torsion_angles = torch.cat(
1093
+ [
1094
+ all_atom_mask.new_ones(*aatype.shape, 3),
1095
+ 1.0 - 2.0 * chi_is_ambiguous,
1096
+ ],
1097
+ dim=-1,
1098
+ )
1099
+
1100
+ alt_torsion_angles_sin_cos = (
1101
+ torsion_angles_sin_cos * mirror_torsion_angles[..., None]
1102
+ )
1103
+
1104
+ protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
1105
+ protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
1106
+ protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
1107
+
1108
+ return protein
1109
+
1110
+
1111
def get_backbone_frames(protein):
    """Expose the backbone (group 0) rigid frame and its mask as features.

    Slices rigid-group index 0 out of the per-residue ground-truth frame
    tensors and stores the results under dedicated keys. Mutates and
    returns `protein`.
    """
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    gt_frames = protein["rigidgroups_gt_frames"]
    gt_exists = protein["rigidgroups_gt_exists"]
    protein["backbone_rigid_tensor"] = gt_frames[..., 0, :, :]
    protein["backbone_rigid_mask"] = gt_exists[..., 0]
    return protein
1119
+
1120
+
1121
def get_chi_angles(protein):
    """Extract the four side-chain (chi) torsions from the full torsion set.

    Torsion index order is [pre-omega, phi, psi, chi1..chi4], so indices
    3: select the chi angles. Both outputs are cast to the dtype of
    "all_atom_mask". Mutates and returns `protein`.
    """
    target_dtype = protein["all_atom_mask"].dtype
    sin_cos = protein["torsion_angles_sin_cos"]
    torsion_mask = protein["torsion_angles_mask"]
    protein["chi_angles_sin_cos"] = sin_cos[..., 3:, :].to(target_dtype)
    protein["chi_mask"] = torsion_mask[..., 3:].to(target_dtype)
    return protein
1129
+
1130
+
1131
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: Dict of feature tensors.
        crop_size: Maximum number of residues to keep.
        max_templates: Maximum number of templates to keep.
        shape_schema: Dict mapping feature names to per-dimension schema
            entries; dims equal to NUM_RES are cropped along the residue axis.
        subsample_templates: If True, randomly permute and subsample the
            template dimension as well.
        seed: Optional seed so every ensemble iteration crops identically.

    Returns:
        The same dict, cropped in place along residue/template dimensions,
        with "seq_length" updated to the cropped size.
    """
    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=protein["seq_length"].device)
    if seed is not None:
        g.manual_seed(seed)

    seq_length = protein["seq_length"]

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(int(seq_length), crop_size)

    def _randint(lower, upper):
        # Uniform integer in [lower, upper] drawn from the shared generator.
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=protein["seq_length"].device,
            generator=g,
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=protein["seq_length"].device, generator=g
        )
    else:
        templates_crop_start = 0

    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            if i == 0 and k.startswith("template"):
                # Leading template dimension: crop templates, not residues.
                # BUGFIX: use dedicated locals instead of clobbering the
                # `crop_size` argument as the original code did.
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(
                slice(dim_crop_start, dim_crop_start + dim_crop_size)
            )
        # BUGFIX: index with a tuple of slices. Indexing a tensor with a
        # plain list of slices is deprecated and raises in recent PyTorch.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein
openfold/data/errors.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """General-purpose errors used throughout the data pipeline"""
17
class Error(Exception):
    """Base class for exceptions raised by the data pipeline."""
19
+
20
+
21
class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID.

    Raised e.g. by mmcif_parsing.get_atom_coords when a chain id does not
    resolve to exactly one chain in the parsed structure.
    """
openfold/data/feature_pipeline.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from typing import Mapping, Tuple, List, Optional, Dict, Sequence
18
+
19
+ import ml_collections
20
+ import numpy as np
21
+ import torch
22
+
23
+ from openfold.data import input_pipeline
24
+
25
+
26
+ FeatureDict = Mapping[str, np.ndarray]
27
+ TensorDict = Dict[str, torch.Tensor]
28
+
29
+
30
def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: Names of the features to keep; all others are dropped.

    Returns:
        A dict mapping each retained feature name to a torch.Tensor.
    """
    wanted = set(features)
    tensor_dict = {}
    for name, array in np_example.items():
        if name in wanted:
            tensor_dict[name] = torch.tensor(array)
    return tensor_dict
49
+
50
+
51
def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Resolve the data config for `mode` and collect the feature names to load.

    Deep-copies `config` (the caller's object is untouched), fills in a
    missing crop size with the actual sequence length, and assembles the
    feature-name list implied by the config flags.
    """
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]

    # crop_size == None means "no crop": use the full sequence length.
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    if mode_cfg.supervised:
        feature_names += cfg.supervised.supervised_features

    return cfg, feature_names
71
+
72
+
73
def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
):
    """Run the full featurization pipeline on a raw NumPy example.

    Shallow-copies the input, renames the legacy integer deletion matrix to
    its float form, converts the configured features to tensors, and applies
    the configured transform pipeline under torch.no_grad().
    """
    np_example = dict(np_example)
    num_res = int(np_example["seq_length"][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if "deletion_matrix_int" in np_example:
        int_matrix = np_example.pop("deletion_matrix_int")
        np_example["deletion_matrix"] = int_matrix.astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
        )

    # Return a plain shallow copy of the resulting feature dict.
    return dict(features)
98
+
99
+
100
class FeaturePipeline:
    """Thin wrapper binding a data config to the featurization entry point."""

    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        # Stored as-is; np_example_to_features deep-copies before mutating.
        self.config = config

    def process_features(
        self,
        raw_features: FeatureDict,
        mode: str = "train",
    ) -> FeatureDict:
        """Featurize `raw_features` for the given mode using the stored config."""
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            mode=mode,
        )
openfold/data/input_pipeline.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+
18
+ import torch
19
+
20
+ from openfold.data import data_transforms
21
+
22
+
23
def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled.

    Returns the ordered list of transform callables applied exactly once
    per example (before any per-recycling-iteration transforms).
    """
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
        data_transforms.squeeze_features,
        data_transforms.randomly_replace_msa_with_unknown(0.0),
        data_transforms.make_seq_mask,
        data_transforms.make_msa_mask,
        data_transforms.make_hhblits_profile,
    ]

    if common_cfg.use_templates:
        transforms += [
            data_transforms.fix_templates_aatype,
            data_transforms.make_template_mask,
            data_transforms.make_pseudo_beta("template_"),
        ]
        if common_cfg.use_template_torsion_angles:
            transforms.append(
                data_transforms.atom37_to_torsion_angles("template_")
            )

    transforms.append(data_transforms.make_atom14_masks)

    if mode_cfg.supervised:
        # Ground-truth-dependent transforms, only needed when labels exist.
        transforms += [
            data_transforms.make_atom14_positions,
            data_transforms.atom37_to_frames,
            data_transforms.atom37_to_torsion_angles(""),
            data_transforms.make_pseudo_beta(""),
            data_transforms.get_backbone_frames,
            data_transforms.get_chi_angles,
        ]

    return transforms
68
+
69
+
70
def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged.

    Built fresh for every ensemble/recycling iteration; `ensemble_seed`
    makes the stochastic transforms reproducible across iterations where
    the config requires it.
    """
    fns = []

    if "max_distillation_msa_clusters" in mode_cfg:
        fns.append(
            data_transforms.sample_msa_distillation(
                mode_cfg.max_distillation_msa_clusters
            )
        )

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = mode_cfg.max_extra_msa

    # Fix the MSA sampling seed only when the MSA must be identical across
    # recycling iterations.
    if not common_cfg.resample_msa_in_recycling:
        msa_seed = ensemble_seed
    else:
        msa_seed = None

    fns.append(
        data_transforms.sample_msa(
            max_msa_clusters,
            keep_extra=True,
            seed=msa_seed,
        )
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        fns.append(
            data_transforms.make_masked_msa(
                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
            )
        )

    if common_cfg.msa_cluster_features:
        fns.append(data_transforms.nearest_neighbor_clusters())
        fns.append(data_transforms.summarize_clusters())

    # Crop after creating the cluster profiles.
    if max_extra_msa:
        fns.append(data_transforms.crop_extra_msa(max_extra_msa))
    else:
        fns.append(data_transforms.delete_extra_msa)

    fns.append(data_transforms.make_msa_feat())

    crop_feats = dict(common_cfg.feat)

    if mode_cfg.fixed_size:
        fns.append(data_transforms.select_feat(list(crop_feats)))
        fns.append(
            data_transforms.random_crop_to_size(
                mode_cfg.crop_size,
                mode_cfg.max_templates,
                crop_feats,
                mode_cfg.subsample_templates,
                seed=ensemble_seed + 1,
            )
        )
        fns.append(
            data_transforms.make_fixed_size(
                crop_feats,
                pad_msa_clusters,
                mode_cfg.max_extra_msa,
                mode_cfg.crop_size,
                mode_cfg.max_templates,
            )
        )
    else:
        fns.append(data_transforms.crop_templates(mode_cfg.max_templates))

    return fns
151
+
152
+
153
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data.

    Applies the non-ensembled transforms once, then maps the ensembled
    transforms over each recycling iteration, stacking every feature along
    a new trailing dimension.

    Args:
        tensors: Dict of feature tensors.
        common_cfg: Config section shared across modes.
        mode_cfg: Mode-specific (train/eval/predict) config section.

    Returns:
        Dict of tensors with a trailing recycling/ensemble dimension.
    """
    ensemble_seed = torch.Generator().seed()

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(
            common_cfg,
            mode_cfg,
            ensemble_seed,
        )
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    # (Removed dead code: the original computed an unused `no_templates`
    # flag from "template_aatype" here.)
    nonensembled = nonensembled_transform_fns(
        common_cfg,
        mode_cfg,
    )

    tensors = compose(nonensembled)(tensors)

    if "no_recycling_iters" in tensors:
        num_recycling = int(tensors["no_recycling_iters"])
    else:
        num_recycling = common_cfg.max_recycling_iters

    # One ensemble slot per recycling iteration, plus the initial pass.
    tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
    )

    return tensors
191
+
192
+
193
@data_transforms.curry1
def compose(x, fs):
    """Apply each transform in `fs` to `x` in order, threading the result."""
    result = x
    for transform in fs:
        result = transform(result)
    return result
198
+
199
+
200
def map_fn(fun, x):
    """Apply `fun` to each element of `x` and stack the resulting dicts.

    Every call to `fun` must return a dict with the same keys; for each key
    the per-element tensors are stacked along a new final dimension.
    """
    outputs = [fun(elem) for elem in x]
    return {
        key: torch.stack([out[key] for out in outputs], dim=-1)
        for key in outputs[0].keys()
    }
openfold/data/mmcif_parsing.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Parses the mmCIF file format."""
17
+ import collections
18
+ import dataclasses
19
+ import io
20
+ import json
21
+ import logging
22
+ import os
23
+ from typing import Any, Mapping, Optional, Sequence, Tuple
24
+
25
+ from Bio import PDB
26
+ from Bio.Data import SCOPData
27
+ import numpy as np
28
+
29
+ from openfold.data.errors import MultipleChainsError
30
+ import openfold.np.residue_constants as residue_constants
31
+
32
+
33
+ # Type aliases:
34
+ ChainId = str
35
+ PdbHeader = Mapping[str, Any]
36
+ PdbStructure = PDB.Structure.Structure
37
+ SeqRes = str
38
+ MmCIFDict = Mapping[str, Sequence[str]]
39
+
40
+
41
# One residue entry from an _entity_poly_seq loop (see _get_protein_chains).
@dataclasses.dataclass(frozen=True)
class Monomer:
    id: str   # Residue name (mon_id), e.g. a three-letter amino acid code.
    num: int  # Position in the mmCIF (SEQRES) numbering scheme.
45
+
46
+
47
+ # Note - mmCIF format provides no guarantees on the type of author-assigned
48
+ # sequence numbers. They need not be integers.
49
@dataclasses.dataclass(frozen=True)
class AtomSite:
    """One row of the mmCIF _atom_site loop (built by _get_atom_site_list)."""

    residue_name: str     # _atom_site.label_comp_id
    author_chain_id: str  # _atom_site.auth_asym_id
    mmcif_chain_id: str   # _atom_site.label_asym_id
    author_seq_num: str   # _atom_site.auth_seq_id
    mmcif_seq_num: int    # _atom_site.label_seq_id; NOTE(review): the parser
                          # supplies str values — parse() calls int() on it.
    insertion_code: str   # _atom_site.pdbx_PDB_ins_code
    hetatm_atom: str      # _atom_site.group_PDB ("ATOM"/"HETATM")
    model_num: int        # _atom_site.pdbx_PDB_model_num; compared as the
                          # string "1" in parse().
59
+
60
+
61
+ # Used to map SEQRES index to a residue in the structure.
62
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    chain_id: str        # Author-assigned chain id.
    residue_number: int  # Author-assigned residue number.
    insertion_code: str  # Insertion code; " " when unset (see parse()).
67
+
68
+
69
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    """A SEQRES residue, possibly absent from the resolved structure."""

    position: Optional[ResiduePosition]  # None when the residue is unresolved.
    name: str                            # Residue name (e.g. three-letter code).
    is_missing: bool                     # True if absent from the structure.
    hetflag: str                         # Biopython hetero flag: " ", "W"
                                         # (water), or "H_<resname>".
75
+
76
+
77
@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure.
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
        {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                          1: ResidueAtPosition,
                                                          ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    # NOTE(review): despite the name, parse() stores the parsed _mmcif_dict
    # here, not the original file string.
    raw_string: Any
101
+
102
+
103
@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be successfully
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    # Keyed by (file_id, chain_id); chain_id is "" for file-level errors.
    errors: Mapping[Tuple[str, str], Any]
115
+
116
+
117
class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""
119
+
120
+
121
def mmcif_loop_to_list(
    prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
    """
    selected = [
        (key, values)
        for key, values in parsed_info.items()
        if key.startswith(prefix)
    ]
    cols = [key for key, _ in selected]
    data = [values for _, values in selected]

    # Every column in a loop must have the same number of rows.
    assert all([len(xs) == len(data[0]) for xs in data]), (
        "mmCIF error: Not all loops are the same length: %s" % cols
    )

    # Transpose columns into one dict per row.
    return [dict(zip(cols, xs)) for xs in zip(*data)]
151
+
152
+
153
def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
        _entity_poly_seq.mon_id. Should include the trailing period.
      index: Which item of loop data should serve as the key.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
        parser.

    Returns:
      Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
      indexed by the index column.
    """
    return {
        entry[index]: entry
        for entry in mmcif_loop_to_list(prefix, parsed_info)
    }
174
+
175
+
176
def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        # the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            # String comparison: mmCIF dict values are parsed as strings.
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                # Zero-based index into the SEQRES sequence for this chain.
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        # Build the one-letter sequence per author chain; unknown residue
        # names map to "X".
        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
304
+
305
+
306
def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    # Only the first model is used downstream (see parse()).
    return next(structure.get_models())
309
+
310
+
311
# NOTE(review): not referenced anywhere in this module's visible code —
# the name suggests a length threshold for treating a chain as a peptide,
# presumably used by downstream filtering; confirm before removing.
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
312
+
313
+
314
def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest revision date.

    Dates are ISO-format strings, so lexicographic `min` is chronological.
    """
    return min(parsed_info["_pdbx_audit_revision_history.revision_date"])
318
+
319
+
320
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution."""
    experiments = mmcif_loop_to_list("_exptl.", parsed_info)
    methods = [experiment["_exptl.method"].lower() for experiment in experiments]
    header = {"structure_method": ",".join(methods)}

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    header["resolution"] = 0.00
    # When several resolution fields are present, later keys in this tuple
    # overwrite earlier ones; unparsable values are logged and skipped.
    for res_key in (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    ):
        if res_key in parsed_info:
            try:
                raw_resolution = parsed_info[res_key][0]
                header["resolution"] = float(raw_resolution)
            except ValueError:
                logging.info(
                    "Invalid resolution format: %s", parsed_info[res_key]
                )

    return header
354
+
355
+
356
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Returns list of atom sites; contains data not present in the structure."""
    # Column order must match the AtomSite field order.
    columns = (
        parsed_info["_atom_site.label_comp_id"],
        parsed_info["_atom_site.auth_asym_id"],
        parsed_info["_atom_site.label_asym_id"],
        parsed_info["_atom_site.auth_seq_id"],
        parsed_info["_atom_site.label_seq_id"],
        parsed_info["_atom_site.pdbx_PDB_ins_code"],
        parsed_info["_atom_site.group_PDB"],
        parsed_info["_atom_site.pdbx_PDB_model_num"],
    )
    # Transpose the per-column lists into one AtomSite per row.
    return [AtomSite(*row) for row in zip(*columns)]
371
+
372
+
373
def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    polymers = collections.defaultdict(list)
    for entry in mmcif_loop_to_list("_entity_poly_seq.", parsed_info):
        polymers[entry["_entity_poly_seq.entity_id"]].append(
            Monomer(
                id=entry["_entity_poly_seq.mon_id"],
                num=int(entry["_entity_poly_seq.num"]),
            )
        )

    # Chemical compositions let us identify which polymers are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Map each entity to its mmCIF chain ids, so the result can be keyed on
    # chain id rather than entity.
    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in mmcif_loop_to_list("_struct_asym.", parsed_info):
        entity_to_mmcif_chains[struct_asym["_struct_asym.entity_id"]].append(
            struct_asym["_struct_asym.id"]
        )

    # Keep only polymers with at least one peptide-like component; this
    # rejects pure DNA/RNA chains.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        has_peptide = any(
            "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
            for monomer in seq_info
        )
        if has_peptide:
            for chain_id in entity_to_mmcif_chains[entity_id]:
                valid_chains[chain_id] = seq_info
    return valid_chains
425
+
426
+
427
+ def _is_set(data: str) -> bool:
428
+ """Returns False if data is a special mmCIF character indicating 'unset'."""
429
+ return data not in (".", "?")
430
+
431
+
432
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts atom37-format coordinates and mask for one chain.

    Args:
        mmcif_object: A parsed structure (see `parse`).
        chain_id: Author chain id to extract.
        _zero_center_positions: If True, subtract the mean of all observed
            atom positions so the chain is centered at the origin.

    Returns:
        Tuple of (positions, mask):
            positions: [num_res, atom_type_num, 3] float32 coordinates.
            mask: [num_res, atom_type_num] float32; 1.0 where the atom was
                observed in the structure, 0.0 otherwise.

    Raises:
        MultipleChainsError: If `chain_id` does not match exactly one chain.
    """
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            # Fetch the residue from Biopython by its
            # (hetflag, residue number, insertion code) id tuple.
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
                    mask[residue_constants.atom_order["SD"]] = 1.0

        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    if _zero_center_positions:
        # Center using only observed atoms so padding zeros don't bias the mean.
        binary_mask = all_atom_mask.astype(bool)
        translation_vec = all_atom_positions[binary_mask].mean(axis=0)
        all_atom_positions[binary_mask] -= translation_vec

    return all_atom_positions, all_atom_mask
openfold/data/parsers.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Functions for parsing various file formats."""
17
+ import collections
18
+ import dataclasses
19
+ import re
20
+ import string
21
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
22
+
23
+
24
+ DeletionMatrix = Sequence[Sequence[int]]
25
+
26
+
27
@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """A single template hit parsed from an HHSearch/HHBlits .hhr file.

    Attributes:
        index: Rank of the hit within the .hhr file.
        name: Raw name line of the hit (PDB id, chain and description).
        aligned_cols: Number of columns aligned between query and hit.
        sum_probs: Sum of per-column posterior probabilities.
        query: Aligned query sequence (may contain gaps).
        hit_sequence: Aligned hit sequence (may contain gaps).
        indices_query: Per-column index into the original query (-1 for gaps).
        indices_hit: Per-column index into the original hit (-1 for gaps).
    """

    index: int
    name: str
    aligned_cols: int
    sum_probs: float
    query: str
    hit_sequence: str
    indices_query: List[int]
    indices_hit: List[int]
39
+
40
+
41
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Parses a FASTA-formatted string.

    Arguments:
        fasta_string: The string contents of a FASTA file.

    Returns:
        A tuple of two parallel lists:
          * The amino-acid sequences.
          * The sequence descriptions taken from the '>' comment lines, in
            the same order as the sequences.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    current = -1
    for raw_line in fasta_string.splitlines():
        stripped = raw_line.strip()
        if not stripped:
            continue  # Skip blank lines.
        if stripped.startswith(">"):
            # A new record begins; drop the leading '>' from the description.
            descriptions.append(stripped[1:])
            sequences.append("")
            current += 1
        else:
            # Continuation of the current record's sequence.
            sequences[current] += stripped

    return sequences, descriptions
68
+
69
+
70
def parse_stockholm(
    stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
    """Parses sequences and deletion matrix from a Stockholm alignment.

    Args:
        stockholm_string: The string contents of a stockholm file. The first
            sequence in the file should be the query sequence.

    Returns:
        A tuple of:
          * Sequences aligned to the query (may contain duplicates).
          * The deletion matrix: ``deletion_matrix[i][j]`` is the number of
            residues deleted from aligned sequence i at residue position j.
          * The names of the targets matched, including the jackhmmer
            subsequence suffix.
    """
    # Accumulate each sequence's fragments, preserving first-seen order.
    name_to_sequence: Dict[str, str] = collections.OrderedDict()
    for raw_line in stockholm_string.splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith(("#", "//")):
            continue
        seq_name, fragment = stripped.split()
        name_to_sequence[seq_name] = name_to_sequence.get(seq_name, "") + fragment

    msa: List[str] = []
    deletion_matrix: List[List[int]] = []
    query = ""
    keep_columns: List[int] = []

    for i, sequence in enumerate(name_to_sequence.values()):
        if i == 0:
            # The first sequence is the query; only its non-gap columns are kept.
            query = sequence
            keep_columns = [c for c, res in enumerate(query) if res != "-"]

        # Drop the columns that are gaps in the query.
        msa.append("".join(sequence[c] for c in keep_columns))

        # Count deletions w.r.t. the query: residues aligned to query gaps
        # accumulate until the next kept (query-residue) column.
        deletion_vec: List[int] = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res == "-" and query_res == "-":
                continue
            if query_res == "-":
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    return msa, deletion_matrix, list(name_to_sequence.keys())
128
+
129
+
130
def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
    """Parses sequences and deletion matrix from an a3m alignment.

    Args:
        a3m_string: The string contents of a a3m file. The first sequence in
            the file should be the query sequence.

    Returns:
        A tuple of:
          * Sequences aligned to the query (may contain duplicates).
          * The deletion matrix: ``deletion_matrix[i][j]`` is the number of
            residues deleted from aligned sequence i at residue position j.
    """
    sequences, _ = parse_fasta(a3m_string)

    # In a3m, lowercase letters mark insertions relative to the query; each
    # run of lowercase characters is recorded as a deletion count at the
    # following aligned (uppercase or gap) column.
    deletion_matrix = []
    for msa_sequence in sequences:
        deletion_vec = []
        deletion_count = 0
        for symbol in msa_sequence:
            if symbol.islower():
                deletion_count += 1
            else:
                deletion_vec.append(deletion_count)
                deletion_count = 0
        deletion_matrix.append(deletion_vec)

    # Strip the lowercase insertion characters to obtain the aligned
    # (deletion-free) MSA rows.
    strip_lowercase = str.maketrans("", "", string.ascii_lowercase)
    aligned_sequences = [s.translate(strip_lowercase) for s in sequences]
    return aligned_sequences, deletion_matrix
162
+
163
+
164
def _convert_sto_seq_to_a3m(
    query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
    """Converts one Stockholm-aligned sequence into a3m form.

    Columns where the query has a residue are yielded unchanged; columns
    where the query is gapped are yielded lowercase (a3m insertion notation),
    with the gap characters themselves dropped.
    """
    for keep_column, residue in zip(query_non_gaps, sto_seq):
        if keep_column:
            yield residue
        elif residue != "-":
            yield residue.lower()
172
+
173
+
174
def convert_stockholm_to_a3m(
    stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
    """Converts MSA in Stockholm format to the A3M format."""
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    # First pass: accumulate the aligned sequence fragments per name, capping
    # the number of distinct sequences at max_sequences when given.
    for line in stockholm_format.splitlines():
        reached_max_sequences = bool(
            max_sequences and len(sequences) >= max_sequences
        )
        if not line.strip() or line.startswith(("#", "//")):
            # Blank lines, markup and end symbols are not alignment rows.
            continue
        seqname, aligned_seq = line.split(maxsplit=1)
        if seqname not in sequences:
            if reached_max_sequences:
                continue
            sequences[seqname] = ""
        sequences[seqname] += aligned_seq

    # Second pass: pick up the '#=GS <name> DE <description>' rows.
    for line in stockholm_format.splitlines():
        if line[:4] == "#=GS":
            # Example: '#=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA ...'
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ""
            if feature != "DE":
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            if len(descriptions) == len(sequences):
                break

    # Convert each sto row to a3m; the query is assumed to be the first row.
    query_sequence = next(iter(sequences.values()))
    query_non_gaps = [res != "-" for res in query_sequence]
    a3m_sequences = {
        seqname: "".join(_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
        for seqname, sto_sequence in sequences.items()
    }

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
        for k in a3m_sequences
    )
    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
226
+
227
+
228
def _get_hhr_line_regex_groups(
    regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
    """Matches `regex_pattern` at the start of `line`, returning the groups.

    Raises:
        RuntimeError: If the pattern does not match the line.
    """
    parsed = re.match(regex_pattern, line)
    if parsed is None:
        raise RuntimeError(f"Could not parse query line {line}")
    return parsed.groups()
235
+
236
+
237
def _update_hhr_residue_indices_list(
    sequence: str, start_index: int, indices_list: List[int]
):
    """Appends per-column residue indices for one alignment fragment.

    For each column of `sequence`, appends that residue's index in the
    original (ungapped) sequence, counting from `start_index`; gap columns
    are recorded as -1.
    """
    next_index = start_index
    for residue in sequence:
        if residue == "-":
            indices_list.append(-1)
        else:
            indices_list.append(next_index)
            next_index += 1
248
+
249
+
250
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses one detailed HMM-HMM comparison section of a .hhr file.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
        detailed_lines: A list of lines from a single comparison section
            between 2 sequences (which each have their own HMM's).

    Returns:
        A TemplateHit with the information from that comparison section.

    Raises:
        RuntimeError: If a certain line cannot be processed.
    """
    # The first two lines carry the hit number and the hit name.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # The third line is the summary line.
    summary_pattern = (
        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
        "]*Template_Neff=(.*)"
    )
    summary_match = re.match(summary_pattern, detailed_lines[2])
    if summary_match is None:
        raise RuntimeError(
            "Could not parse section: %s. Expected this: \n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )
    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
        float(x) for x in summary_match.groups()
    ]

    # The remaining lines hold the detailed comparison in a fixed-width,
    # human-readable format. Each block starts with a query sequence line,
    # which is parsed with a regexp to deduce the block's fixed width.
    query = ""
    hit_sequence = ""
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        is_query_line = (
            line.startswith("Q ")
            and not line.startswith("Q ss_dssp")
            and not line.startswith("Q ss_pred")
            and not line.startswith("Q Consensus")
        )
        if is_query_line:
            # The first 17 characters must be 'Q <query_name> '; everything
            # after that is parsed.
            #         start      sequence   end   total_sequence_length
            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Derive the block length from the start/end indices and check it
            # against the actual length of the parsed fragment.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            num_insertions = sum(1 for x in delta_query if x == "-")
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith("T "):
            is_hit_line = (
                not line.startswith("T ss_dssp")
                and not line.startswith("T ss_pred")
                and not line.startswith("T Consensus")
            )
            if is_hit_line:
                # The first 17 characters must be 'T <hit_name> '; everything
                # after that is parsed.
                #         start      sequence   end   total_sequence_length
                patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(
                    delta_hit_sequence, start, indices_hit
                )

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )
+ )
354
+
355
+
356
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file.

    An .hhr file starts with a results table, followed by a sequence of hit
    "paragraphs", each beginning with a line 'No <hit number>'. Each
    paragraph is parsed into one TemplateHit.
    """
    lines = hhr_string.splitlines()
    block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]

    hits = []
    if block_starts:
        # A sentinel marks the end of the final paragraph.
        block_starts.append(len(lines))
        for begin, end in zip(block_starts[:-1], block_starts[1:]):
            hits.append(_parse_hhr_hit(lines[begin:end]))
    return hits
374
+
375
+
376
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parses a target-name -> e-value mapping from a Jackhmmer tblout string.

    The query itself is always present in the result with an e-value of 0.

    As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    space-delimited. Relevant fields are (1) target name: and
    (5) E-value (full sequence) (numbering from 1).
    """
    e_values = {"query": 0}
    # Skip '#' comment lines. Blank lines are also skipped explicitly: the
    # previous `line[0] != "#"` test raised IndexError on an empty line.
    lines = [
        line
        for line in tblout.splitlines()
        if line.strip() and not line.startswith("#")
    ]
    for line in lines:
        fields = line.split()
        e_value = fields[4]
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values
openfold/data/templates.py ADDED
@@ -0,0 +1,1108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Functions for getting templates and calculating template features."""
17
+ import dataclasses
18
+ import datetime
19
+ import glob
20
+ import json
21
+ import logging
22
+ import os
23
+ import re
24
+ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
25
+
26
+ import numpy as np
27
+
28
+ from openfold.data import parsers, mmcif_parsing
29
+ from openfold.data.errors import Error
30
+ from openfold.data.tools import kalign
31
+ from openfold.data.tools.utils import to_date
32
+ from openfold.np import residue_constants
33
+
34
+
35
class NoChainsError(Error):
    """Raised when a template mmCIF contains no chains."""


class SequenceNotInTemplateError(Error):
    """Raised when a template mmCIF does not contain the expected sequence."""


class NoAtomDataInTemplateError(Error):
    """Raised when a template mmCIF contains no atom positions."""


class TemplateAtomMaskAllZerosError(Error):
    """Raised when every atom position in a template mmCIF is masked."""


class QueryToTemplateAlignError(Error):
    """Raised when the query cannot be aligned to the template."""


class CaDistanceError(Error):
    """Raised when a CA atom distance exceeds a threshold."""


# Prefilter exceptions.
class PrefilterError(Exception):
    """A base class for template prefilter exceptions."""


class DateError(PrefilterError):
    """Raised when the hit date is after the max allowed date."""


class PdbIdError(PrefilterError):
    """Raised when the hit PDB ID is identical to the query's."""


class AlignRatioError(PrefilterError):
    """Raised when the hit's align ratio to the query is too small."""


class DuplicateError(PrefilterError):
    """Raised when the hit is an exact subsequence of the query."""


class LengthError(PrefilterError):
    """Raised when the hit is too short."""
82
+
83
+
84
# dtypes of the per-template feature arrays assembled for the model.
TEMPLATE_FEATURES = {
    "template_aatype": np.int64,
    "template_all_atom_mask": np.float32,
    "template_all_atom_positions": np.float32,
    # `np.object` was deprecated in NumPy 1.20 and removed in 1.24; the
    # builtin `object` is the exact equivalent dtype specifier.
    "template_domain_names": object,
    "template_sequence": object,
    "template_sum_probs": np.float32,
}
92
+
93
+
94
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
    """Extracts the lowercased PDB id and the chain id from an HHSearch hit.

    Raises:
        ValueError: If the hit name does not start with '<PDBID>_<chain>'.
    """
    # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
    id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
    if id_match is None:
        raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
    pdb_id, chain_id = id_match.group(0).split("_")
    return pdb_id.lower(), chain_id
102
+
103
+
104
def _is_after_cutoff(
    pdb_id: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: Optional[datetime.datetime],
) -> bool:
    """Checks whether a template was released after the cutoff date.

    Args:
        pdb_id: 4 letter pdb code.
        release_dates: Dictionary mapping (uppercase) PDB ids to their
            structure release dates.
        release_date_cutoff: Max release date that is valid for this query.

    Returns:
        True if the template release date is after the cutoff, False
        otherwise — including when the id is absent from `release_dates`.

    Raises:
        ValueError: If `release_date_cutoff` is None.
    """
    if release_date_cutoff is None:
        raise ValueError("The release_date_cutoff must not be None.")
    key = pdb_id.upper()
    if key in release_dates:
        return release_dates[key] > release_date_cutoff
    # This is only a quick prefilter to reduce the number of mmCIF files we
    # need to parse, so treating an unknown id as "not after cutoff" is safe.
    logging.info(
        "Template structure not in release dates dict: %s", pdb_id
    )
    return False
131
+
132
+
133
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
    """Parses the PDB data file that lists which PDB ids are obsolete.

    Returns a mapping from each obsolete (lowercased) PDB id to the id that
    replaced it. Obsolete entries with no replacement are skipped.
    """
    result = {}
    with open(obsolete_file_path) as f:
        for raw_line in f:
            line = raw_line.strip()
            # Fixed-width row, e.g.:  'OBSLTE    31-JUL-94 116L     216L'
            # Rows shorter than 31 chars carry no replacement id.
            if line.startswith("OBSLTE") and len(line) > 30:
                from_id = line[20:24].lower()
                to_id = line[29:33].lower()
                result[from_id] = to_id
    return result
147
+
148
+
149
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Writes a JSON cache mapping mmCIF file ids to their release dates.

    Args:
        mmcif_dir: Directory containing the .cif files to scan.
        out_path: Path of the JSON file to write.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if not f.endswith(".cif"):
            continue
        path = os.path.join(mmcif_dir, f)
        with open(path, "r") as fp:
            mmcif_string = fp.read()

        file_id = os.path.splitext(f)[0]
        mmcif = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )
        if mmcif.mmcif_object is None:
            logging.info(f"Failed to parse {f}. Skipping...")
            continue

        dates[file_id] = mmcif.mmcif_object.header["release_date"]

    # BUG FIX: the output file must be opened for writing. The original
    # opened it with mode "r" and then called fp.write, which raises
    # io.UnsupportedOperation (or FileNotFoundError if the file is missing).
    with open(out_path, "w") as fp:
        json.dump(dates, fp)
172
+
173
+
174
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses a release-dates JSON file into a PDB-id -> date mapping.

    The JSON is expected to map each pdb id to a dict that contains a
    "release_date" entry; ids are uppercased in the returned mapping.
    """
    with open(path, "r") as fp:
        data = json.load(fp)

    # NOTE(review): this expects entries shaped {pdb: {"release_date": ...}},
    # whereas generate_release_dates_cache writes {pdb: <date value>}
    # directly — confirm which cache format is canonical.
    return {
        pdb.upper(): to_date(v)
        for pdb, d in data.items()
        for k, v in d.items()
        if k == "release_date"
    }
185
+
186
+
187
def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    query_pdb_code: Optional[str],
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.1,
) -> bool:
    """Determines if a template is valid, without parsing its mmCIF file.

    Args:
        hit: HhrHit for the template.
        hit_pdb_code: The 4 letter pdb code of the template hit. This might
            be different from the value in the actual hit since the original
            pdb might have become obsolete.
        query_sequence: Amino acid sequence of the query.
        query_pdb_code: 4 letter pdb code of the query.
        release_dates: Dictionary mapping pdb codes to their structure
            release dates.
        release_date_cutoff: Max release date that is valid for this query.
        max_subsequence_ratio: Exclude any exact matches with this much
            overlap.
        min_align_ratio: Minimum overlap between the template and query.

    Returns:
        True if the hit passed the prefilter. Raises an exception otherwise.

    Raises:
        DateError: If the hit date was after the max allowed date.
        PdbIdError: If the hit PDB ID was identical to the query.
        AlignRatioError: If the hit align ratio to the query was too small.
        DuplicateError: If the hit was an exact subsequence of the query.
        LengthError: If the hit was too short.
    """
    align_ratio = hit.aligned_cols / len(query_sequence)

    template_sequence = hit.hit_sequence.replace("-", "")
    length_ratio = float(len(template_sequence)) / len(query_sequence)

    # A near-complete exact subsequence of the query usually indicates a
    # duplicate PDB entry rather than an informative template.
    is_duplicate = (
        template_sequence in query_sequence
        and length_ratio > max_subsequence_ratio
    )

    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        date = release_dates[hit_pdb_code.upper()]
        raise DateError(
            f"Date ({date}) > max template date "
            f"({release_date_cutoff})."
        )

    if query_pdb_code is not None:
        if query_pdb_code.lower() == hit_pdb_code.lower():
            raise PdbIdError("PDB code identical to Query PDB code.")

    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            "Proportion of residues aligned to query too small. "
            f"Align ratio: {align_ratio}."
        )

    if is_duplicate:
        raise DuplicateError(
            "Template is an exact subsequence of query with large "
            f"coverage. Length ratio: {length_ratio}."
        )

    if len(template_sequence) < 10:
        raise LengthError(
            f"Template too short. Length: {len(template_sequence)}."
        )

    return True
264
+
265
+
266
def _find_template_in_pdb(
    template_chain_id: str,
    template_sequence: str,
    mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
    """Tries to find the template chain in the given pdb file.

    Three strategies are attempted in order:
      1. An exact match on both the chain ID and the sequence.
      2. An exact match on the sequence only.
      3. A fuzzy sequence match, treating X as a wildcard.
    The first strategy that succeeds determines the returned chain.

    Args:
        template_chain_id: The template chain ID.
        template_sequence: The template chain sequence.
        mmcif_object: The PDB object to search for the template in.

    Returns:
        A tuple with:
        * The chain sequence that was found to match the template in the PDB
          object.
        * The ID of the chain that is being returned.
        * The offset where the template sequence starts in the chain sequence.

    Raises:
        SequenceNotInTemplateError: If no match is found after the steps
            described above.
    """
    pdb_id = mmcif_object.file_id

    # 1. Exact match in both the chain ID and the (sub)sequence.
    chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
    if chain_sequence and (template_sequence in chain_sequence):
        logging.info(
            "Found an exact template match %s_%s.", pdb_id, template_chain_id
        )
        return (
            chain_sequence,
            template_chain_id,
            chain_sequence.find(template_sequence),
        )

    # 2. Exact match in the (sub)sequence only, in any chain.
    for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
        if chain_sequence and (template_sequence in chain_sequence):
            logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
            return (
                chain_sequence,
                chain_id,
                chain_sequence.find(template_sequence),
            )

    # 3. Fuzzy match with X as a wildcard. Unnamed groups (?:_) avoid the
    # regex engine's 100 named-group limit.
    fuzzy_pattern = re.compile(
        "".join(
            "." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence
        )
    )
    for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
        match = fuzzy_pattern.search(chain_sequence)
        if match:
            logging.info(
                "Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
            )
            return chain_sequence, chain_id, match.start()

    # No hits, raise an error.
    raise SequenceNotInTemplateError(
        "Could not find the template sequence in %s_%s. Template sequence: %s, "
        "chain_to_seqres: %s"
        % (
            pdb_id,
            template_chain_id,
            template_sequence,
            mmcif_object.chain_to_seqres,
        )
    )
338
+
339
+
340
+ def _realign_pdb_template_to_query(
341
+ old_template_sequence: str,
342
+ template_chain_id: str,
343
+ mmcif_object: mmcif_parsing.MmcifObject,
344
+ old_mapping: Mapping[int, int],
345
+ kalign_binary_path: str,
346
+ ) -> Tuple[str, Mapping[int, int]]:
347
+ """Aligns template from the mmcif_object to the query.
348
+
349
+ In case PDB70 contains a different version of the template sequence, we need
350
+ to perform a realignment to the actual sequence that is in the mmCIF file.
351
+ This method performs such realignment, but returns the new sequence and
352
+ mapping only if the sequence in the mmCIF file is 90% identical to the old
353
+ sequence.
354
+
355
+ Note that the old_template_sequence comes from the hit, and contains only that
356
+ part of the chain that matches with the query while the new_template_sequence
357
+ is the full chain.
358
+
359
+ Args:
360
+ old_template_sequence: The template sequence that was returned by the PDB
361
+ template search (typically done using HHSearch).
362
+ template_chain_id: The template chain id was returned by the PDB template
363
+ search (typically done using HHSearch). This is used to find the right
364
+ chain in the mmcif_object chain_to_seqres mapping.
365
+ mmcif_object: A mmcif_object which holds the actual template data.
366
+ old_mapping: A mapping from the query sequence to the template sequence.
367
+ This mapping will be used to compute the new mapping from the query
368
+ sequence to the actual mmcif_object template sequence by aligning the
369
+ old_template_sequence and the actual template sequence.
370
+ kalign_binary_path: The path to a kalign executable.
371
+
372
+ Returns:
373
+ A tuple (new_template_sequence, new_query_to_template_mapping) where:
374
+ * new_template_sequence is the actual template sequence that was found in
375
+ the mmcif_object.
376
+ * new_query_to_template_mapping is the new mapping from the query to the
377
+ actual template found in the mmcif_object.
378
+
379
+ Raises:
380
+ QueryToTemplateAlignError:
381
+ * If there was an error thrown by the alignment tool.
382
+ * Or if the actual template sequence differs by more than 10% from the
383
+ old_template_sequence.
384
+ """
385
+ aligner = kalign.Kalign(binary_path=kalign_binary_path)
386
+ new_template_sequence = mmcif_object.chain_to_seqres.get(
387
+ template_chain_id, ""
388
+ )
389
+
390
+ # Sometimes the template chain id is unknown. But if there is only a single
391
+ # sequence within the mmcif_object, it is safe to assume it is that one.
392
+ if not new_template_sequence:
393
+ if len(mmcif_object.chain_to_seqres) == 1:
394
+ logging.info(
395
+ "Could not find %s in %s, but there is only 1 sequence, so "
396
+ "using that one.",
397
+ template_chain_id,
398
+ mmcif_object.file_id,
399
+ )
400
+ new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
401
+ 0
402
+ ]
403
+ else:
404
+ raise QueryToTemplateAlignError(
405
+ f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
406
+ "If there are no mmCIF parsing errors, it is possible it was not a "
407
+ "protein chain."
408
+ )
409
+
410
+ try:
411
+ (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
412
+ aligner.align([old_template_sequence, new_template_sequence])
413
+ )
414
+ except Exception as e:
415
+ raise QueryToTemplateAlignError(
416
+ "Could not align old template %s to template %s (%s_%s). Error: %s"
417
+ % (
418
+ old_template_sequence,
419
+ new_template_sequence,
420
+ mmcif_object.file_id,
421
+ template_chain_id,
422
+ str(e),
423
+ )
424
+ )
425
+
426
+ logging.info(
427
+ "Old aligned template: %s\nNew aligned template: %s",
428
+ old_aligned_template,
429
+ new_aligned_template,
430
+ )
431
+
432
+ old_to_new_template_mapping = {}
433
+ old_template_index = -1
434
+ new_template_index = -1
435
+ num_same = 0
436
+ for old_template_aa, new_template_aa in zip(
437
+ old_aligned_template, new_aligned_template
438
+ ):
439
+ if old_template_aa != "-":
440
+ old_template_index += 1
441
+ if new_template_aa != "-":
442
+ new_template_index += 1
443
+ if old_template_aa != "-" and new_template_aa != "-":
444
+ old_to_new_template_mapping[old_template_index] = new_template_index
445
+ if old_template_aa == new_template_aa:
446
+ num_same += 1
447
+
448
+ # Require at least 90 % sequence identity wrt to the shorter of the sequences.
449
+ if (
450
+ float(num_same)
451
+ / min(len(old_template_sequence), len(new_template_sequence))
452
+ < 0.9
453
+ ):
454
+ raise QueryToTemplateAlignError(
455
+ "Insufficient similarity of the sequence in the database: %s to the "
456
+ "actual sequence in the mmCIF file %s_%s: %s. We require at least "
457
+ "90 %% similarity wrt to the shorter of the sequences. This is not a "
458
+ "problem unless you think this is a template that should be included."
459
+ % (
460
+ old_template_sequence,
461
+ mmcif_object.file_id,
462
+ template_chain_id,
463
+ new_template_sequence,
464
+ )
465
+ )
466
+
467
+ new_query_to_template_mapping = {}
468
+ for query_index, old_template_index in old_mapping.items():
469
+ new_query_to_template_mapping[
470
+ query_index
471
+ ] = old_to_new_template_mapping.get(old_template_index, -1)
472
+
473
+ new_template_sequence = new_template_sequence.replace("-", "")
474
+
475
+ return new_template_sequence, new_query_to_template_mapping
476
+
477
+
478
def _check_residue_distances(
    all_positions: np.ndarray,
    all_positions_mask: np.ndarray,
    max_ca_ca_distance: float,
):
    """Checks if the distance between unmasked neighbor residues is ok.

    Walks the chain residue by residue and raises CaDistanceError whenever two
    consecutive residues that both have a resolved C-alpha atom are further
    apart than max_ca_ca_distance.
    """
    ca_idx = residue_constants.atom_order["CA"]
    last_unmasked = False
    last_ca = None  # C-alpha coords of the most recent unmasked residue.
    for res_idx, (coords, mask) in enumerate(
        zip(all_positions, all_positions_mask)
    ):
        unmasked = bool(mask[ca_idx])
        if unmasked:
            ca = coords[ca_idx]
            # Only compare against the previous residue when it was also
            # resolved; gaps in the structure are skipped silently.
            if last_unmasked:
                gap = np.linalg.norm(ca - last_ca)
                if gap > max_ca_ca_distance:
                    raise CaDistanceError(
                        "The distance between residues %d and %d is %f > limit %f."
                        % (res_idx, res_idx + 1, gap, max_ca_ca_distance)
                    )
            last_ca = ca
        last_unmasked = unmasked
501
+
502
def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask from a list of Biopython Residues.

    Reads per-atom coordinates and their presence mask for one chain, then
    validates that consecutive resolved residues are within
    max_ca_ca_distance of each other (raising CaDistanceError otherwise).
    """
    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object,
        chain_id=auth_chain_id,
        _zero_center_positions=_zero_center_positions,
    )
    # Sanity-check the chain geometry before handing it to the featurizer.
    _check_residue_distances(positions, mask, max_ca_ca_distance)
    return positions, mask
520
+
521
def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str,
    _zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
    """Parses atom positions in the target structure and aligns with the query.

    Atoms for each residue in the template structure are indexed to coincide
    with their corresponding residue in the query sequence, according to the
    alignment mapping provided.

    Args:
      mmcif_object: mmcif_parsing.MmcifObject representing the template.
      pdb_id: PDB code for the template.
      mapping: Dictionary mapping indices in the query sequence to indices in
        the template sequence.
      template_sequence: String describing the amino acid sequence for the
        template protein.
      query_sequence: String describing the amino acid sequence for the query
        protein.
      template_chain_id: String ID describing which chain in the structure proto
        should be used.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.
      _zero_center_positions: If True, forwarded to the coordinate reader so
        template atom positions are zero-centered.

    Returns:
      A tuple with:
      * A dictionary containing the extra features derived from the template
        protein structure.
      * A warning message if the hit was realigned to the actual mmCIF
        sequence. Otherwise None.

    Raises:
      NoChainsError: If the mmcif object doesn't contain any chains.
      SequenceNotInTemplateError: If the given chain id / sequence can't
        be found in the mmcif object.
      QueryToTemplateAlignError: If the actual template in the mmCIF file
        can't be aligned to the query.
      NoAtomDataInTemplateError: If the mmcif object doesn't contain
        atom positions.
      TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
        unmasked residues.
    """
    if mmcif_object is None or not mmcif_object.chain_to_seqres:
        raise NoChainsError(
            "No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
        )

    warning = None
    try:
        seqres, chain_id, mapping_offset = _find_template_in_pdb(
            template_chain_id=template_chain_id,
            template_sequence=template_sequence,
            mmcif_object=mmcif_object,
        )
    except SequenceNotInTemplateError:
        # If PDB70 contains a different version of the template, we use the
        # sequence from the mmcif_object instead and realign.
        chain_id = template_chain_id
        warning = (
            f"The exact sequence {template_sequence} was not found in "
            f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
        )
        logging.warning(warning)
        # This throws an exception if it fails to realign the hit.
        seqres, mapping = _realign_pdb_template_to_query(
            old_template_sequence=template_sequence,
            template_chain_id=template_chain_id,
            mmcif_object=mmcif_object,
            old_mapping=mapping,
            kalign_binary_path=kalign_binary_path,
        )
        logging.info(
            "Sequence in %s_%s: %s successfully realigned to %s",
            pdb_id,
            chain_id,
            template_sequence,
            seqres,
        )
        # The template sequence changed.
        template_sequence = seqres
        # No mapping offset, the query is aligned to the actual sequence.
        mapping_offset = 0

    try:
        # Essentially set to infinity - we don't want to reject templates unless
        # they're really really bad.
        all_atom_positions, all_atom_mask = _get_atom_positions(
            mmcif_object,
            chain_id,
            max_ca_ca_distance=150.0,
            _zero_center_positions=_zero_center_positions,
        )
    except (CaDistanceError, KeyError) as ex:
        raise NoAtomDataInTemplateError(
            "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
        ) from ex

    # Split into per-residue arrays so residues can be re-indexed onto the
    # query positions below.
    all_atom_positions = np.split(
        all_atom_positions, all_atom_positions.shape[0]
    )
    all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

    output_templates_sequence = []
    templates_all_atom_positions = []
    templates_all_atom_masks = []

    for _ in query_sequence:
        # Residues in the query_sequence that are not in the template_sequence:
        # all-zero coordinates/mask and a gap character.
        templates_all_atom_positions.append(
            np.zeros((residue_constants.atom_type_num, 3))
        )
        templates_all_atom_masks.append(
            np.zeros(residue_constants.atom_type_num)
        )
        output_templates_sequence.append("-")

    # Copy template residues into the slots of their aligned query residues.
    for k, v in mapping.items():
        template_index = v + mapping_offset
        templates_all_atom_positions[k] = all_atom_positions[template_index][0]
        templates_all_atom_masks[k] = all_atom_masks[template_index][0]
        output_templates_sequence[k] = template_sequence[v]

    # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
    if np.sum(templates_all_atom_masks) < 5:
        raise TemplateAtomMaskAllZerosError(
            "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
            % (
                pdb_id,
                chain_id,
                min(mapping.values()) + mapping_offset,
                max(mapping.values()) + mapping_offset,
            )
        )

    output_templates_sequence = "".join(output_templates_sequence)

    templates_aatype = residue_constants.sequence_to_onehot(
        output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    return (
        {
            "template_all_atom_positions": np.array(
                templates_all_atom_positions
            ),
            "template_all_atom_mask": np.array(templates_all_atom_masks),
            "template_sequence": output_templates_sequence.encode(),
            "template_aatype": np.array(templates_aatype),
            "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
        },
        warning,
    )
680
+
681
+ def _build_query_to_hit_index_mapping(
682
+ hit_query_sequence: str,
683
+ hit_sequence: str,
684
+ indices_hit: Sequence[int],
685
+ indices_query: Sequence[int],
686
+ original_query_sequence: str,
687
+ ) -> Mapping[int, int]:
688
+ """Gets mapping from indices in original query sequence to indices in the hit.
689
+
690
+ hit_query_sequence and hit_sequence are two aligned sequences containing gap
691
+ characters. hit_query_sequence contains only the part of the original query
692
+ sequence that matched the hit. When interpreting the indices from the .hhr, we
693
+ need to correct for this to recover a mapping from original query sequence to
694
+ the hit sequence.
695
+
696
+ Args:
697
+ hit_query_sequence: The portion of the query sequence that is in the .hhr
698
+ hit
699
+ hit_sequence: The portion of the hit sequence that is in the .hhr
700
+ indices_hit: The indices for each aminoacid relative to the hit sequence
701
+ indices_query: The indices for each aminoacid relative to the original query
702
+ sequence
703
+ original_query_sequence: String describing the original query sequence.
704
+
705
+ Returns:
706
+ Dictionary with indices in the original query sequence as keys and indices
707
+ in the hit sequence as values.
708
+ """
709
+ # If the hit is empty (no aligned residues), return empty mapping
710
+ if not hit_query_sequence:
711
+ return {}
712
+
713
+ # Remove gaps and find the offset of hit.query relative to original query.
714
+ hhsearch_query_sequence = hit_query_sequence.replace("-", "")
715
+ hit_sequence = hit_sequence.replace("-", "")
716
+ hhsearch_query_offset = original_query_sequence.find(
717
+ hhsearch_query_sequence
718
+ )
719
+
720
+ # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
721
+ min_idx = min(x for x in indices_hit if x > -1)
722
+ fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
723
+
724
+ min_idx = min(x for x in indices_query if x > -1)
725
+ fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
726
+
727
+ # Zip the corrected indices, ignore case where both seqs have gap characters.
728
+ mapping = {}
729
+ for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
730
+ if q_t != -1 and q_i != -1:
731
+ if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
732
+ original_query_sequence
733
+ ):
734
+ continue
735
+ mapping[q_i + hhsearch_query_offset] = q_t
736
+
737
+ return mapping
738
+
739
+
740
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
    """Outcome of running a single template hit through the prefilter."""
    # True if the hit passed the prefilter and may be featurized.
    valid: bool
    # Hard-error message (strict mode only), otherwise None.
    error: Optional[str]
    # Non-fatal diagnostic message, otherwise None.
    warning: Optional[str]
745
+
746
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
    """Outcome of featurizing a single template hit."""
    # Extracted template features, or None when the hit was rejected.
    features: Optional[Mapping[str, Any]]
    # Error message when feature extraction failed hard, otherwise None.
    error: Optional[str]
    # Non-fatal diagnostic (e.g. realignment notice), otherwise None.
    warning: Optional[str]
751
+
752
+
753
def _prefilter_hit(
    query_sequence: str,
    query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    strict_error_check: bool = False,
):
    """Runs prefilter checks on a hit and reports whether it may be used."""
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # If the PDB ID has no known release date it may be obsolete; map it onto
    # its replacement when one is known.
    if hit_pdb_code not in release_dates and hit_pdb_code in obsolete_pdbs:
        hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    try:
        # Pass hit_pdb_code since it might have changed due to the pdb being
        # obsolete.
        _assess_hhsearch_hit(
            hit=hit,
            hit_pdb_code=hit_pdb_code,
            query_sequence=query_sequence,
            query_pdb_code=query_pdb_code,
            release_dates=release_dates,
            release_date_cutoff=max_template_date,
        )
    except PrefilterError as e:
        hit_name = f"{hit_pdb_code}_{hit_chain_id}"
        msg = f"hit {hit_name} did not pass prefilter: {str(e)}"
        logging.info("%s: %s", query_pdb_code, msg)
        # In strict mode some prefilter failures are escalated to hard errors.
        hard_error = strict_error_check and isinstance(
            e, (DateError, PdbIdError, DuplicateError)
        )
        return PrefilterResult(
            valid=False,
            error=msg if hard_error else None,
            warning=None,
        )

    return PrefilterResult(valid=True, error=None, warning=None)
793
+
794
+
795
def _process_single_hit(
    query_sequence: str,
    query_pdb_code: Optional[str],
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    kalign_binary_path: str,
    strict_error_check: bool = False,
    _zero_center_positions: bool = True,
) -> SingleHitResult:
    """Tries to extract template features from a single HHSearch hit."""
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Replace an obsolete PDB ID (no release date known) with its successor.
    if hit_pdb_code not in release_dates:
        if hit_pdb_code in obsolete_pdbs:
            hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    # Map query positions to positions in the (ungapped) hit sequence.
    mapping = _build_query_to_hit_index_mapping(
        hit.query,
        hit.hit_sequence,
        hit.indices_hit,
        hit.indices_query,
        query_sequence,
    )

    # The mapping is from the query to the actual hit sequence, so we need to
    # remove gaps (which regardless have a missing confidence score).
    template_sequence = hit.hit_sequence.replace("-", "")

    cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
    logging.info(
        "Reading PDB entry from %s. Query: %s, template: %s",
        cif_path,
        query_sequence,
        template_sequence,
    )
    # Fail if we can't find the mmCIF file.
    with open(cif_path, "r") as cif_file:
        cif_string = cif_file.read()

    parsing_result = mmcif_parsing.parse(
        file_id=hit_pdb_code, mmcif_string=cif_string
    )

    # Reject templates released after the cutoff; in strict mode this is an
    # error, otherwise the hit is silently skipped.
    if parsing_result.mmcif_object is not None:
        hit_release_date = datetime.datetime.strptime(
            parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
        )
        if hit_release_date > max_template_date:
            error = "Template %s date (%s) > max template date (%s)." % (
                hit_pdb_code,
                hit_release_date,
                max_template_date,
            )
            if strict_error_check:
                return SingleHitResult(features=None, error=error, warning=None)
            else:
                logging.info(error)
                return SingleHitResult(features=None, error=None, warning=None)

    try:
        features, realign_warning = _extract_template_features(
            mmcif_object=parsing_result.mmcif_object,
            pdb_id=hit_pdb_code,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=hit_chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=_zero_center_positions,
        )
        features["template_sum_probs"] = [hit.sum_probs]

        # It is possible there were some errors when parsing the other chains in
        # the mmCIF file, but the template features for the chain we want were
        # still computed. In such case the mmCIF parsing errors are not relevant.
        return SingleHitResult(
            features=features, error=None, warning=realign_warning
        )
    except (
        NoChainsError,
        NoAtomDataInTemplateError,
        TemplateAtomMaskAllZerosError,
    ) as e:
        # These 3 errors indicate missing mmCIF experimental data rather than a
        # problem with the template search, so turn them into warnings.
        warning = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        if strict_error_check:
            return SingleHitResult(features=None, error=warning, warning=None)
        else:
            return SingleHitResult(features=None, error=None, warning=warning)
    except Error as e:
        # Any other project error is always reported as a hard error.
        error = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        return SingleHitResult(features=None, error=error, warning=None)
914
+
915
+
916
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    """Aggregated template features plus diagnostics for one query."""
    # Stacked per-template feature arrays, keyed by TEMPLATE_FEATURES names.
    features: Mapping[str, Any]
    # Hard errors collected while processing hits.
    errors: Sequence[str]
    # Non-fatal diagnostics collected while processing hits.
    warnings: Sequence[str]
921
+
922
+
923
class TemplateHitFeaturizer:
    """A class for turning hhr hits to template features."""

    def __init__(
        self,
        mmcif_dir: str,
        max_template_date: str,
        max_hits: int,
        kalign_binary_path: str,
        release_dates_path: Optional[str] = None,
        obsolete_pdbs_path: Optional[str] = None,
        strict_error_check: bool = False,
        _shuffle_top_k_prefiltered: Optional[int] = None,
        _zero_center_positions: bool = True,
    ):
        """Initializes the Template Search.

        Args:
          mmcif_dir: Path to a directory with mmCIF structures. Once a template
            ID is found by HHSearch, this directory is used to retrieve the
            template data.
          max_template_date: The maximum date permitted for template structures.
            No template with date higher than this date will be returned. In
            ISO8601 date format, YYYY-MM-DD.
          max_hits: The maximum number of templates that will be returned.
          kalign_binary_path: The path to a kalign executable used for template
            realignment.
          release_dates_path: An optional path to a file with a mapping from
            PDB IDs to their release dates. Thanks to this we don't have to
            redundantly parse mmCIF files to get that information.
          obsolete_pdbs_path: An optional path to a file containing a mapping
            from obsolete PDB IDs to the PDB IDs of their replacements.
          strict_error_check: If True, then the following will be treated as
            errors:
            * If any template date is after the max_template_date.
            * If any template has identical PDB ID to the query.
            * If any template is a duplicate of the query.
            * Any feature computation errors.
          _shuffle_top_k_prefiltered: If set, the top-k prefiltered hits are
            randomly permuted before featurization — presumably for training
            data augmentation; TODO confirm against callers.
          _zero_center_positions: Forwarded to the template coordinate reader
            so atom positions are zero-centered.
        """
        self._mmcif_dir = mmcif_dir
        # Fail fast if the template structure directory is empty or wrong.
        if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
            logging.error("Could not find CIFs in %s", self._mmcif_dir)
            raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")

        try:
            self._max_template_date = datetime.datetime.strptime(
                max_template_date, "%Y-%m-%d"
            )
        except ValueError:
            raise ValueError(
                "max_template_date must be set and have format YYYY-MM-DD."
            )
        self.max_hits = max_hits
        self._kalign_binary_path = kalign_binary_path
        self._strict_error_check = strict_error_check

        # Optional precomputed lookup tables; empty dicts mean "parse mmCIFs
        # on demand" / "no obsolete-ID remapping".
        if release_dates_path:
            logging.info(
                "Using precomputed release dates %s.", release_dates_path
            )
            self._release_dates = _parse_release_dates(release_dates_path)
        else:
            self._release_dates = {}

        if obsolete_pdbs_path:
            logging.info(
                "Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
            )
            self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
        else:
            self._obsolete_pdbs = {}

        self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
        self._zero_center_positions = _zero_center_positions

    def get_templates(
        self,
        query_sequence: str,
        query_pdb_code: Optional[str],
        query_release_date: Optional[datetime.datetime],
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for given query sequence (more details above)."""
        logging.info("Searching for template for: %s", query_pdb_code)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        # Always use a max_template_date. Set to query_release_date minus 60
        # days if that's earlier (prevents the model training on templates that
        # post-date the query structure).
        template_cutoff_date = self._max_template_date
        if query_release_date:
            delta = datetime.timedelta(days=60)
            if query_release_date - delta < template_cutoff_date:
                template_cutoff_date = query_release_date - delta
            assert template_cutoff_date < query_release_date
            assert template_cutoff_date <= self._max_template_date

        num_hits = 0
        errors = []
        warnings = []

        # Pass 1: cheap prefilter on every hit, collecting diagnostics.
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                query_pdb_code=query_pdb_code,
                hit=hit,
                max_template_date=template_cutoff_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        # Best hits first (highest sum of per-column probabilities).
        filtered = list(
            sorted(filtered, key=lambda x: x.sum_probs, reverse=True)
        )

        # Optionally shuffle the top-k surviving hits (see __init__ docstring).
        idx = list(range(len(filtered)))
        if(self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        # Pass 2: expensive featurization, stopping once max_hits succeeded.
        for i in idx:
            # We got all the templates we wanted, stop processing hits.
            if num_hits >= self.max_hits:
                break

            hit = filtered[i]

            result = _process_single_hit(
                query_sequence=query_sequence,
                query_pdb_code=query_pdb_code,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=template_cutoff_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path,
                _zero_center_positions=self._zero_center_positions,
            )

            if result.error:
                errors.append(result.error)

            # There could be an error even if there are some results, e.g.
            # thrown by other unparsable chains in the same mmCIF file.
            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.info(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name,
                    result.error,
                    result.warning,
                )
            else:
                # Increment the hit counter, since we got features out of this hit.
                num_hits += 1
                for k in template_features:
                    template_features[k].append(result.features[k])

        # Stack per-hit feature lists into arrays with the declared dtypes.
        for name in template_features:
            if num_hits > 0:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
            else:
                # Make sure the feature has correct dtype even if empty.
                template_features[name] = np.array(
                    [], dtype=TEMPLATE_FEATURES[name]
                )

        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )
openfold/data/tools/__init__.py ADDED
File without changes
openfold/data/tools/hhblits.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run HHblits from Python."""
17
+ import glob
18
+ import logging
19
+ import os
20
+ import subprocess
21
+ from typing import Any, Mapping, Optional, Sequence
22
+
23
+ from openfold.data.tools import utils
24
+
25
+
26
+ _HHBLITS_DEFAULT_P = 20
27
+ _HHBLITS_DEFAULT_Z = 500
28
+
29
+
30
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 4,
        n_iter: int = 3,
        e_value: float = 0.001,
        maxseq: int = 1_000_000,
        realign_max: int = 100_000,
        maxfilt: int = 100_000,
        min_prefilter_hits: int = 1000,
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
          binary_path: The path to the HHblits executable.
          databases: A sequence of HHblits database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to give HHblits.
          n_iter: The number of HHblits iterations.
          e_value: The E-value, see HHblits docs for more details.
          maxseq: The maximum number of rows in an input alignment. Note that
            this parameter is only supported in HHBlits version 3.1 and higher.
          realign_max: Max number of HMM-HMM hits to realign. HHblits
            default: 500.
          maxfilt: Max number of hits allowed to pass the 2nd prefilter.
            HHblits default: 20000.
          min_prefilter_hits: Min number of hits to pass prefilter.
            HHblits default: 100.
          all_seqs: Return all sequences in the MSA / Do not filter the result
            MSA. HHblits default: False.
          alt: Show up to this many alternative alignments.
          p: Minimum Prob for a hit to be included in the output hhr file.
            HHblits default: 20.
          z: Hard cap on number of hits reported in the hhr file.
            HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

        Raises:
          RuntimeError: If HHblits binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases

        # Each database is a file-name prefix; verify at least one of its
        # component files (e.g. <prefix>_hhm.ffindex) exists.
        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHBlits database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHBlits database {database_path}"
                )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using HHblits.

        Runs the binary in a temporary directory, raises RuntimeError on a
        non-zero exit code, and returns a dict with the resulting a3m text
        plus raw stdout/stderr and the key search parameters.
        """
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            # One "-d <path>" pair per database; appended after the fixed flags.
            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-cpu",
                str(self.n_cpu),
                "-oa3m",
                a3m_path,
                "-o",
                "/dev/null",
                "-n",
                str(self.n_iter),
                "-e",
                str(self.e_value),
                "-maxseq",
                str(self.maxseq),
                "-realign_max",
                str(self.realign_max),
                "-maxfilt",
                str(self.maxfilt),
                "-min_prefilter_hits",
                str(self.min_prefilter_hits),
            ]
            if self.all_seqs:
                cmd += ["-all"]
            if self.alt:
                cmd += ["-alt", str(self.alt)]
            # Only pass -p / -Z when they differ from the HHblits defaults.
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ["-p", str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ["-Z", str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("HHblits query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error("HHblits failed. HHblits stderr begin:")
                for error_line in stderr.decode("utf-8").splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error("HHblits stderr end")
                raise RuntimeError(
                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
                )

            # Read the resulting MSA before the temporary directory is removed.
            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )
        return raw_output
openfold/data/tools/hhsearch.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run HHsearch from Python."""
17
+ import glob
18
+ import logging
19
+ import os
20
+ import subprocess
21
+ from typing import Sequence
22
+
23
+ from openfold.data.tools import utils
24
+
25
+
26
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be
                the common prefix for the database files (i.e. up to but not
                including _hhm.ffindex etc.)
            n_cpu: The number of CPUs to use.
            maxseq: The maximum number of rows in an input alignment. Note
                that this parameter is only supported in HHBlits version 3.1
                and higher.

        Raises:
            ValueError: If any database prefix matches no files on disk.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        # Fail fast if a database prefix matches nothing on disk.
        for database_path in self.databases:
            if glob.glob(database_path + "_*"):
                continue
            logging.error(
                "Could not find HHsearch database %s", database_path
            )
            raise ValueError(
                f"Could not find HHsearch database {database_path}"
            )

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            query_path = os.path.join(query_tmp_dir, "query.a3m")
            result_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(query_path, "w") as handle:
                handle.write(a3m)

            cmd = [
                self.binary_path,
                "-i", query_path,
                "-o", result_path,
                "-maxseq", str(self.maxseq),
                "-cpu", str(self.n_cpu),
            ]
            # Each database is passed as its own -d flag.
            for database_path in self.databases:
                cmd.extend(["-d", database_path])

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
                )

            with open(result_path) as handle:
                return handle.read()
openfold/data/tools/jackhmmer.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run Jackhmmer from Python."""
17
+
18
+ from concurrent import futures
19
+ import glob
20
+ import logging
21
+ import os
22
+ import subprocess
23
+ from typing import Any, Callable, Mapping, Optional, Sequence
24
+ from urllib import request
25
+
26
+ from openfold.data.tools import utils
27
+
28
+
29
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
            binary_path: The path to the jackhmmer executable.
            database_path: The path to the jackhmmer database (FASTA format).
            n_cpu: The number of CPUs to give Jackhmmer.
            n_iter: The number of Jackhmmer iterations.
            e_value: The E-value, see Jackhmmer docs for more details.
            z_value: The Z-value, see Jackhmmer docs for more details.
            get_tblout: Whether to save tblout string.
            filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
            filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
            filter_f3: Forward pre-filter, set to >1.0 to turn off.
            incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
                round.
            dom_e: Domain e-value criteria for inclusion in tblout.
            num_streamed_chunks: Number of database chunks to stream over.
            streaming_callback: Callback function run after each chunk iteration with
                the iteration number as argument.

        Raises:
            ValueError: If `database_path` does not exist and chunk streaming
                is not enabled.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # A local database must exist unless chunks are streamed from a
        # remote location at query time.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self, input_fasta_path: str, database_path: str
    ) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer.

        Args:
            input_fasta_path: Path to the query FASTA file.
            database_path: Path to the (chunk of the) target database.

        Returns:
            A dict with keys `sto` (Stockholm alignment text), `tbl` (tblout
            text, empty unless `get_tblout`), `stderr`, `n_iter`, `e_value`.

        Raises:
            RuntimeError: If the jackhmmer subprocess exits non-zero.
        """
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expensive of sensitivity. They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                "--F1",
                str(self.filter_f1),
                "--F2",
                str(self.filter_f2),
                "--F3",
                str(self.filter_f3),
                "--incE",
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--cpu",
                str(self.n_cpu),
                "-N",
                str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            with open(sto_path) as f:
                sto = f.read()

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )

            return raw_output

    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
        """Queries the database using Jackhmmer.

        When chunk streaming is enabled, each database chunk is downloaded
        to /tmp/ramdisk while Jackhmmer runs on the previous chunk, then
        deleted after use.
        """
        if self.num_streamed_chunks is None:
            return [self._query_chunk(input_fasta_path, self.database_path)]

        db_basename = os.path.basename(self.database_path)
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_output = []
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                future.result()
                chunked_output.append(
                    self._query_chunk(input_fasta_path, db_local_chunk(i))
                )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Hand over the prefetch future for the next iteration. No
                # prefetch is started on the final chunk (and none at all
                # when num_streamed_chunks == 1), so guard the assignment —
                # an unconditional `future = next_future` raised NameError
                # in the single-chunk case.
                if i < self.num_streamed_chunks:
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_output
openfold/data/tools/kalign.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A Python wrapper for Kalign."""
17
+ import os
18
+ import subprocess
19
+ from typing import Sequence
20
+
21
+ from absl import logging
22
+
23
+ from openfold.data.tools import utils
24
+
25
+
26
+ def _to_a3m(sequences: Sequence[str]) -> str:
27
+ """Converts sequences to an a3m file."""
28
+ names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
29
+ a3m = []
30
+ for sequence, name in zip(sequences, names):
31
+ a3m.append(u">" + name + u"\n")
32
+ a3m.append(sequence + u"\n")
33
+ return "".join(a3m)
34
+
35
+
36
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary.

        Raises:
            RuntimeError: If Kalign binary not found within the path.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have to
                be at least 6 residues long (Kalign requires this). Note that
                the order in which you give the sequences might alter the
                output slightly as different alignment tree might get
                constructed.

        Returns:
            A string with the alignment in a3m format.

        Raises:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Kalign refuses very short sequences; validate up front.
        for seq in sequences:
            if len(seq) >= 6:
                continue
            raise ValueError(
                "Kalign requires all sequences to be at least 6 "
                "residues long. Got %s (%d residues)." % (seq, len(seq))
            )

        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            aligned_path = os.path.join(query_tmp_dir, "output.a3m")

            with open(fasta_path, "w") as handle:
                handle.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i", fasta_path,
                "-o", aligned_path,
                "-format", "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout.decode("utf-8"),
                    stderr.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(aligned_path) as handle:
                return handle.read()
openfold/data/tools/utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Common utilities for data pipeline tools."""
17
+ import contextlib
18
+ import datetime
19
+ import logging
20
+ import shutil
21
+ import tempfile
22
+ import time
23
+ from typing import Optional
24
+
25
+
26
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit.

    Args:
        base_dir: Parent directory for the temporary directory; the system
            default is used when None.

    Yields:
        The path of the freshly created temporary directory.
    """
    path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield path
    finally:
        # Best-effort cleanup; errors (e.g. files already removed) are ignored.
        shutil.rmtree(path, ignore_errors=True)
34
+
35
+
36
@contextlib.contextmanager
def timing(msg: str):
    """Context manager that logs the wall-clock duration of its block.

    Logs "Started <msg>" on entry and "Finished <msg> in N seconds" on a
    normal exit (no log is emitted if the block raises).
    """
    logging.info("Started %s", msg)
    start = time.perf_counter()
    yield
    elapsed = time.perf_counter() - start
    logging.info("Finished %s in %.3f seconds", msg, elapsed)
43
+
44
+
45
def to_date(s: str):
    """Parses the leading "YYYY?MM?DD" characters of `s` into a datetime.

    Only fixed character positions are read (0-3, 5-6, 8-9), so any single
    separator character (e.g. '-' or '/') is accepted.
    """
    year, month, day = int(s[0:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
openfold/np/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import glob
import importlib as importlib

# Expose every sibling module (excluding this __init__) as an attribute of
# the package, so that `import openfold.np` makes e.g. `openfold.np.protein`
# available without an explicit submodule import.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace. `_m` only exists if the
# loop above ran at least once; an unconditional `del _m` raised NameError
# for a package with no submodules.
del _files, _modules
if "_m" in globals():
    del _m
openfold/np/protein.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Protein data type."""
17
+ import dataclasses
18
+ import io
19
+ from typing import Any, Sequence, Mapping, Optional
20
+ import re
21
+ import string
22
+
23
+ from openfold.np import residue_constants
24
+ from Bio import PDB
25
+ import numpy as np
26
+
27
+
28
+ FeatureDict = Mapping[str, np.ndarray]
29
+ ModelOutput = Mapping[str, Any] # Is a nested dict.
30
+ PICO_TO_ANGSTROM = 0.01
31
+
32
@dataclasses.dataclass(frozen=True)
class Protein:
    """Protein structure representation.

    All per-residue array fields share a leading [num_res] dimension and are
    indexed consistently: row i of every field describes the same residue.
    Instances are frozen (immutable) after construction.
    """

    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Amino-acid type for each residue represented as an integer between 0 and
    # 20, where 20 is 'X' (unknown residue).
    aatype: np.ndarray  # [num_res]

    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
    # is present and 0.0 if not. This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
    # representing the displacement of the residue from its ground truth mean
    # value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    # Chain indices for multi-chain predictions. None for single-chain
    # proteins; otherwise integer indices mapped to chain letters (A, B, ...)
    # when rendering PDB output.
    chain_index: Optional[np.ndarray] = None

    # Optional remark about the protein. Included as a REMARK comment in
    # output PDB files.
    remark: Optional[str] = None

    # Templates used to generate this protein (prediction-only). Rendered as
    # PARENT records in output PDB files.
    parents: Optional[Sequence[str]] = None

    # Chain index corresponding to each entry of `parents`.
    parents_chain_index: Optional[Sequence[int]] = None
68
+
69
+
70
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
        pdb_str: The contents of the pdb file
        chain_id: If None, then the pdb file must contain a single chain (which
            will be parsed). If chain_id is specified (e.g. A), then only that
            chain is parsed.

    Returns:
        A new `Protein` parsed from the pdb contents.

    Raises:
        ValueError: If the PDB contains more than one model or any insertion
            codes.
    """
    # Bio.PDB's parser expects a file path or a file handle. Wrap the raw PDB
    # text in a StringIO so its *contents* are parsed — passing the string
    # directly would make the parser treat it as a filename.
    pdb_fh = io.StringIO(pdb_str)
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("none", pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f"Only single model PDBs are supported. Found {len(models)} models."
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != " ":
                raise ValueError(
                    f"PDB contains an insertion code at chain {chain.id} and residue "
                    f"index {res.id[1]}. These are not supported."
                )
            res_shortname = residue_constants.restype_3to1.get(res.resname, "X")
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num
            )
            pos = np.zeros((residue_constants.atom_type_num, 3))
            mask = np.zeros((residue_constants.atom_type_num,))
            res_b_factors = np.zeros((residue_constants.atom_type_num,))
            for atom in res:
                # Ignore atoms outside the canonical atom set.
                if atom.name not in residue_constants.atom_types:
                    continue
                pos[residue_constants.atom_order[atom.name]] = atom.coord
                mask[residue_constants.atom_order[atom.name]] = 1.0
                res_b_factors[
                    residue_constants.atom_order[atom.name]
                ] = atom.bfactor
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Parse PARENT records (template provenance emitted for predictions) into
    # a flat parent list plus a per-parent chain index. A dedicated counter is
    # used instead of re-using (shadowing) the `chain_id` parameter.
    parents = None
    parents_chain_index = None
    if "PARENT" in pdb_str:
        parents = []
        parents_chain_index = []
        cur_chain_idx = 0
        for l in pdb_str.split("\n"):
            if "PARENT" in l:
                if "N/A" not in l:
                    parent_names = l.split()[1:]
                    parents.extend(parent_names)
                    parents_chain_index.extend(
                        [cur_chain_idx for _ in parent_names]
                    )
                cur_chain_idx += 1

    # Map chain letters (A, B, ...) to consecutive integer chain indices.
    chain_id_mapping = {cid: n for n, cid in enumerate(string.ascii_uppercase)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors),
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
166
+
167
+
168
def from_proteinnet_string(proteinnet_str: str) -> Protein:
    """Parses a ProteinNet-format record into a `Protein`.

    Consumes the [PRIMARY] (sequence), [TERTIARY] (flattened N/CA/C
    coordinates) and [MASK] ('+'/'-' per-residue flags) sections; other
    sections are ignored.
    """
    tag_re = r'(\[[A-Z]+\]\n)'
    tags = [
        tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0
    ]
    # Pair each [TAG] with the lines of its following section body.
    groups = zip(tags[0::2], [l.split('\n') for l in tags[1::2]])

    atoms = ['N', 'CA', 'C']
    aatype = None
    atom_positions = None
    atom_mask = None
    for g in groups:
        if "[PRIMARY]" == g[0]:
            # Strings are immutable: build a sanitized list of residue
            # symbols instead of assigning into the string (the previous
            # `seq[i] = 'X'` raised TypeError on any non-standard residue).
            seq = [
                c if c in residue_constants.restypes else 'X'
                for c in g[1][0].strip()
            ]
            aatype = np.array([
                residue_constants.restype_order.get(
                    res_symbol, residue_constants.restype_num
                ) for res_symbol in seq
            ])
        elif "[TERTIARY]" == g[0]:
            tertiary = []
            for axis in range(3):
                tertiary.append(list(map(float, g[1][axis].split())))
            tertiary_np = np.array(tertiary)
            atom_positions = np.zeros(
                (len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_positions[:, residue_constants.atom_order[atom], :] = (
                    np.transpose(tertiary_np[:, i::3])
                )
            # ProteinNet coordinates are stored in picometers.
            atom_positions *= PICO_TO_ANGSTROM
        elif "[MASK]" == g[0]:
            mask = np.array(list(map({'-': 0, '+': 1}.get, g[1][0].strip())))
            atom_mask = np.zeros(
                (len(mask), residue_constants.atom_type_num,)
            ).astype(np.float32)
            for i, atom in enumerate(atoms):
                atom_mask[:, residue_constants.atom_order[atom]] = 1
            atom_mask *= mask[..., None]

    return Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=np.arange(len(aatype)),
        b_factors=None,
    )
219
+
220
+
221
def get_pdb_headers(prot: "Protein", chain_id: int = 0) -> Sequence[str]:
    """Builds the REMARK/PARENT header lines for one chain of `prot`.

    A PARENT line is always produced; it lists the parents assigned to
    `chain_id` (or "N/A" when there are none).
    """
    headers = []

    if prot.remark is not None:
        headers.append(f"REMARK {prot.remark}")

    parents = prot.parents
    if prot.parents_chain_index is not None:
        # Restrict to the parents belonging to the requested chain.
        parents = [
            p
            for idx, p in zip(prot.parents_chain_index, parents)
            if idx == chain_id
        ]

    if not parents:
        parents = ["N/A"]

    headers.append(f"PARENT {' '.join(parents)}")

    return headers
241
+
242
+
243
def add_pdb_headers(prot: "Protein", pdb_str: str) -> str:
    """ Add pdb headers to an existing PDB string. Useful during multi-chain
    recycling.

    Args:
        prot: Source of the REMARK/PARENT metadata.
        pdb_str: An already-rendered (possibly multi-chain) PDB string. Any
            existing PARENT/REMARK lines in it are replaced.

    Returns:
        `pdb_str` with REMARK/PARENT header lines inserted before each chain.
    """
    out_pdb_lines = []
    lines = pdb_str.split('\n')

    remark = prot.remark
    if remark is not None:
        out_pdb_lines.append(f"REMARK {remark}")

    # Build one parent list per chain index (dense from 0..max index, with
    # "N/A" filling chains that have no parents).
    parents_per_chain = None
    if prot.parents is not None and len(prot.parents) > 0:
        parents_per_chain = []
        if prot.parents_chain_index is not None:
            parent_dict = {}
            for p, i in zip(prot.parents, prot.parents_chain_index):
                parent_dict.setdefault(str(i), [])
                parent_dict[str(i)].append(p)

            max_idx = max([int(chain_idx) for chain_idx in parent_dict])
            for i in range(max_idx + 1):
                chain_parents = parent_dict.get(str(i), ["N/A"])
                parents_per_chain.append(chain_parents)
        else:
            parents_per_chain.append(prot.parents)
    else:
        parents_per_chain = [["N/A"]]

    make_parent_line = lambda p: f"PARENT {' '.join(p)}"

    out_pdb_lines.append(make_parent_line(parents_per_chain[0]))

    chain_counter = 0
    for i, l in enumerate(lines):
        # Drop any pre-existing PARENT/REMARK lines; keep everything else.
        if "PARENT" not in l and "REMARK" not in l:
            out_pdb_lines.append(l)
        # A TER record ends a chain: emit the next chain's PARENT line unless
        # the structure ends here. The i + 1 bound guards against an input
        # whose final line is "TER" (previously an IndexError).
        if "TER" in l and i + 1 < len(lines) and "END" not in lines[i + 1]:
            chain_counter += 1
            if not chain_counter >= len(parents_per_chain):
                chain_parents = parents_per_chain[chain_counter]
            else:
                chain_parents = ["N/A"]

            out_pdb_lines.append(make_parent_line(chain_parents))

    return '\n'.join(out_pdb_lines)
291
+
292
+
293
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
        prot: The protein to convert to PDB.

    Returns:
        PDB string.

    Raises:
        ValueError: If `prot.aatype` contains an out-of-range residue type.
    """
    restypes = residue_constants.restypes + ["X"]
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], "UNK")
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError("Invalid aatypes.")

    headers = get_pdb_headers(prot)
    if len(headers) > 0:
        pdb_lines.extend(headers)

    n = aatype.shape[0]
    atom_index = 1
    prev_chain_index = 0
    chain_tags = string.ascii_uppercase
    # Add all atom sites.
    for i in range(n):
        res_name_3 = res_1to3(aatype[i])
        # Resolve the chain tag once per residue. It is needed by both ATOM
        # and TER records — including residues whose atoms are all masked —
        # so it is computed here instead of being duplicated inside the atom
        # loop (the masked-atom skip branch previously recomputed it).
        chain_tag = "A"
        if chain_index is not None:
            chain_tag = chain_tags[chain_index[i]]

        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            # Skip atoms that are absent from this residue.
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            element = atom_name[
                0
            ]  # Protein supports only C, N, O, S, this works.
            charge = ""

            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        should_terminate = (i == n - 1)
        if chain_index is not None:
            if i != n - 1 and chain_index[i + 1] != prev_chain_index:
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if should_terminate:
            # Close the chain.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      "
                f"{res_1to3(aatype[i]):>3} "
                f"{chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            if i != n - 1:
                # "prev" is a misnomer here. This happens at the beginning of
                # each new chain.
                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))

    pdb_lines.append("END")
    pdb_lines.append("")
    return "\n".join(pdb_lines)
389
+
390
+
391
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Computes an ideal atom mask.

    `Protein.atom_mask` typically is defined according to the atoms that are
    reported in the PDB. This function computes a mask according to heavy atoms
    that should be present in the given sequence of amino acids.

    Args:
        prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
        An ideal atom mask.
    """
    # Index the canonical per-residue-type atom mask table by residue type.
    return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
405
+
406
+
407
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    remark: Optional[str] = None,
    parents: Optional[Sequence[str]] = None,
    parents_chain_index: Optional[Sequence[int]] = None
) -> Protein:
    """Assembles a protein from a prediction.

    Args:
        features: Dictionary holding model inputs. Must contain "aatype" and
            "residue_index".
        result: Dictionary holding model outputs. Must contain
            "final_atom_positions" and "final_atom_mask".
        b_factors: (Optional) B-factors to use for the protein. Defaults to
            all zeros with the shape of the predicted atom mask.
        chain_index: (Optional) Chain indices for multi-chain predictions
        remark: (Optional) Remark about the prediction
        parents: (Optional) List of template names
        parents_chain_index: (Optional) Chain index corresponding to each
            entry of `parents`.
    Returns:
        A protein instance.
    """
    if b_factors is None:
        b_factors = np.zeros_like(result["final_atom_mask"])

    return Protein(
        aatype=features["aatype"],
        atom_positions=result["final_atom_positions"],
        atom_mask=result["final_atom_mask"],
        # Shift the model's 0-based residue index to PDB's 1-based numbering.
        residue_index=features["residue_index"] + 1,
        b_factors=b_factors,
        chain_index=chain_index,
        remark=remark,
        parents=parents,
        parents_chain_index=parents_chain_index,
    )
openfold/np/relax/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Dynamically import every sibling module in this package so that each one is
# reachable as an attribute of the package immediately after import.
import os
import glob
import importlib as importlib

# All *.py files living next to this __init__, minus the __init__ itself.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
# Import each discovered module relative to this package ("." + name) and
# bind it into the package namespace under its bare module name.
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
openfold/np/relax/amber_minimize.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Restrained Amber Minimization of a structure."""
17
+
18
+ import io
19
+ import time
20
+ from typing import Collection, Optional, Sequence
21
+
22
+ from absl import logging
23
+ from openfold.np import (
24
+ protein,
25
+ residue_constants,
26
+ )
27
+ import openfold.utils.loss as loss
28
+ from openfold.np.relax import cleanup, utils
29
+ import ml_collections
30
+ import numpy as np
31
+ from simtk import openmm
32
+ from simtk import unit
33
+ from simtk.openmm import app as openmm_app
34
+ from simtk.openmm.app.internal.pdbstructure import PdbStructure
35
+
36
+ ENERGY = unit.kilocalories_per_mole
37
+ LENGTH = unit.angstroms
38
+
39
+
40
def will_restrain(atom: openmm_app.Atom, rset: str) -> bool:
    """Returns True if the atom will be restrained by the given restraint set.

    Args:
        atom: The OpenMM atom to test.
        rset: Restraint-set name, one of "non_hydrogen" or "c_alpha".

    Returns:
        True if the atom belongs to the restraint set, False otherwise.

    Raises:
        ValueError: If `rset` is not a recognized restraint set.
    """
    if rset == "non_hydrogen":
        # Restrain every heavy (non-hydrogen) atom.
        return atom.element.name != "hydrogen"
    elif rset == "c_alpha":
        # Restrain only backbone alpha carbons.
        return atom.name == "CA"
    # Previously an unknown restraint set fell through and returned None,
    # which is silently falsy. Fail loudly so misconfiguration is caught.
    raise ValueError(f"Unknown restraint set: {rset}")
47
+
48
+
49
def _add_restraints(
    system: openmm.System,
    reference_pdb: openmm_app.PDBFile,
    stiffness: unit.Unit,
    rset: str,
    exclude_residues: Sequence[int],
):
    """Adds a harmonic potential that restrains the system to a structure.

    Args:
        system: OpenMM system to modify in place (a force is appended).
        reference_pdb: Structure providing the anchor positions.
        stiffness: Spring constant `k` of the restraint (energy / length^2).
        rset: Which atoms to restrain; "non_hydrogen" or "c_alpha".
        exclude_residues: Zero-indexed residues whose atoms are left free.
    """
    assert rset in ["non_hydrogen", "c_alpha"]

    # Isotropic harmonic well centered at each restrained atom's reference
    # position (x0, y0, z0), with shared global stiffness k.
    force = openmm.CustomExternalForce(
        "0.5 * k * ((x-x0)^2 + (y-y0)^2 + (z-z0)^2)"
    )
    force.addGlobalParameter("k", stiffness)
    for p in ["x0", "y0", "z0"]:
        force.addPerParticleParameter(p)

    for i, atom in enumerate(reference_pdb.topology.atoms()):
        if atom.residue.index in exclude_residues:
            continue
        if will_restrain(atom, rset):
            # Anchor particle i at its current reference coordinates.
            force.addParticle(i, reference_pdb.positions[i])
    logging.info(
        "Restraining %d / %d particles.",
        force.getNumParticles(),
        system.getNumParticles(),
    )
    system.addForce(force)
77
+
78
+
79
def _openmm_minimize(
    pdb_str: str,
    max_iterations: int,
    tolerance: unit.Unit,
    stiffness: unit.Unit,
    restraint_set: str,
    exclude_residues: Sequence[int],
    use_gpu: bool,
):
    """Minimize energy via openmm.

    Args:
        pdb_str: PDB-format structure to minimize.
        max_iterations: Maximum L-BFGS iterations; 0 means no limit.
        tolerance: Energy tolerance for convergence.
        stiffness: Restraint spring constant; restraints are only added when
            it is strictly positive.
        restraint_set: Which atoms to restrain ("non_hydrogen" / "c_alpha").
        exclude_residues: Zero-indexed residues excluded from restraints.
        use_gpu: Use the CUDA platform instead of CPU.

    Returns:
        Dict with pre/post energies ("einit"/"efinal"), pre/post positions
        ("posinit"/"pos") and the minimized structure as PDB ("min_pdb").
    """

    pdb_file = io.StringIO(pdb_str)
    pdb = openmm_app.PDBFile(pdb_file)

    force_field = openmm_app.ForceField("amber99sb.xml")
    constraints = openmm_app.HBonds
    system = force_field.createSystem(pdb.topology, constraints=constraints)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

    # The integrator is required by the Simulation API but no dynamics are
    # run here — only energy minimization.
    integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
    platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
    simulation = openmm_app.Simulation(
        pdb.topology, system, integrator, platform
    )
    simulation.context.setPositions(pdb.positions)

    ret = {}
    # Snapshot state before minimization.
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["einit"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["posinit"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    simulation.minimizeEnergy(maxIterations=max_iterations, tolerance=tolerance)
    # Snapshot state after minimization.
    state = simulation.context.getState(getEnergy=True, getPositions=True)
    ret["efinal"] = state.getPotentialEnergy().value_in_unit(ENERGY)
    ret["pos"] = state.getPositions(asNumpy=True).value_in_unit(LENGTH)
    ret["min_pdb"] = _get_pdb_string(simulation.topology, state.getPositions())
    return ret
116
+
117
+
118
def _get_pdb_string(topology: openmm_app.Topology, positions: unit.Quantity):
    """Serialize an OpenMM topology plus positions to a PDB-format string."""
    buffer = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, positions, buffer)
        return buffer.getvalue()
    finally:
        buffer.close()
123
+
124
+
125
def _check_cleaned_atoms(pdb_cleaned_string: str, pdb_ref_string: str):
    """Checks that no atom positions have been altered by cleaning.

    Args:
        pdb_cleaned_string: PDB string produced by the cleaning step.
        pdb_ref_string: Original PDB string to compare against.

    Raises:
        ValueError: If any atom present in both structures changed position.
    """
    cleaned = openmm_app.PDBFile(io.StringIO(pdb_cleaned_string))
    reference = openmm_app.PDBFile(io.StringIO(pdb_ref_string))

    cl_xyz = np.array(cleaned.getPositions().value_in_unit(LENGTH))
    ref_xyz = np.array(reference.getPositions().value_in_unit(LENGTH))

    # Residues are compared pairwise in order; within a residue pair only
    # atoms sharing a name are compared, since cleaning may add atoms.
    for ref_res, cl_res in zip(
        reference.topology.residues(), cleaned.topology.residues()
    ):
        assert ref_res.name == cl_res.name
        for rat in ref_res.atoms():
            for cat in cl_res.atoms():
                if cat.name == rat.name:
                    if not np.array_equal(
                        cl_xyz[cat.index], ref_xyz[rat.index]
                    ):
                        raise ValueError(
                            f"Coordinates of cleaned atom {cat} do not match "
                            f"coordinates of reference atom {rat}."
                        )
147
+
148
+
149
def _check_residues_are_well_defined(prot: protein.Protein):
    """Raise if any residue in the protein has no resolved atoms at all."""
    # Count masked-in atoms per residue; a residue with zero atoms cannot
    # be handled by the Amber minimization pipeline.
    atoms_per_residue = prot.atom_mask.sum(axis=-1)
    if np.any(atoms_per_residue == 0):
        raise ValueError(
            "Amber minimization can only be performed on proteins with"
            " well-defined residues. This protein contains at least"
            " one residue with no atoms."
        )
157
+
158
+
159
def _check_atom_mask_is_ideal(prot):
    """Sanity-check the atom mask is ideal, up to a possible OXT."""
    # Compare the protein's observed mask against the mask implied by its
    # sequence; the helper tolerates differences at terminal atoms only.
    utils.assert_equal_nonterminal_atom_types(
        prot.atom_mask, protein.ideal_atom_mask(prot)
    )
164
+
165
+
166
def clean_protein(prot: protein.Protein, checks: bool = True):
    """Adds missing atoms to Protein instance.

    Args:
        prot: A `protein.Protein` instance.
        checks: A `bool` specifying whether to add additional checks to the cleaning
            process.

    Returns:
        pdb_string: A string of the cleaned protein.
    """
    _check_atom_mask_is_ideal(prot)

    # Clean pdb: serialize, run pdbfixer-based fixes, then structure-level
    # fixes (clean_structure mutates pdb_structure in place).
    prot_pdb_string = protein.to_pdb(prot)
    pdb_file = io.StringIO(prot_pdb_string)
    alterations_info = {}
    fixed_pdb = cleanup.fix_pdb(pdb_file, alterations_info)
    fixed_pdb_file = io.StringIO(fixed_pdb)
    pdb_structure = PdbStructure(fixed_pdb_file)
    cleanup.clean_structure(pdb_structure, alterations_info)

    logging.info("alterations info: %s", alterations_info)

    # Write pdb file of cleaned structure.
    as_file = openmm_app.PDBFile(pdb_structure)
    pdb_string = _get_pdb_string(as_file.getTopology(), as_file.getPositions())
    if checks:
        # Verify cleaning did not move any atom that existed before.
        _check_cleaned_atoms(pdb_string, prot_pdb_string)

    # Prepend any PDB header lines derived from the protein's metadata,
    # which are not carried through the OpenMM round-trip.
    headers = protein.get_pdb_headers(prot)
    if(len(headers) > 0):
        pdb_string = '\n'.join(['\n'.join(headers), pdb_string])

    return pdb_string
201
+
202
+
203
def make_atom14_positions(prot):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Builds per-residue mappings between the sparse 37-slot atom layout and
    the dense 14-slot layout, gathers ground-truth positions/masks into the
    dense layout, and adds alternative ("renamed") ground truth for residues
    with ambiguous atom naming.

    Args:
        prot: Feature dict with at least "aatype", "all_atom_positions" and
            "all_atom_mask". Mutated in place.

    Returns:
        The same dict, extended with atom14/atom37 mapping, mask, position
        and ambiguity features.
    """
    restype_atom14_to_atom37 = []  # mapping (restype, atom14) --> atom37
    restype_atom37_to_atom14 = []  # mapping (restype, atom37) --> atom14
    restype_atom14_mask = []

    for rt in residue_constants.restypes:
        atom_names = residue_constants.restype_name_to_atom14_names[
            residue_constants.restype_1to3[rt]
        ]

        restype_atom14_to_atom37.append(
            [
                (residue_constants.atom_order[name] if name else 0)
                for name in atom_names
            ]
        )

        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [
                (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
                for name in residue_constants.atom_types
            ]
        )

        # Empty atom names are padding slots in the 14-atom layout.
        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

    # Add dummy mapping for restype 'UNK'.
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)

    restype_atom14_to_atom37 = np.array(
        restype_atom14_to_atom37, dtype=np.int32
    )
    restype_atom37_to_atom14 = np.array(
        restype_atom37_to_atom14, dtype=np.int32
    )
    restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32)

    # Create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein.
    residx_atom14_to_atom37 = restype_atom14_to_atom37[prot["aatype"]]
    residx_atom14_mask = restype_atom14_mask[prot["aatype"]]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis(
        prot["all_atom_mask"], residx_atom14_to_atom37, axis=1
    ).astype(np.float32)

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * (
        np.take_along_axis(
            prot["all_atom_positions"],
            residx_atom14_to_atom37[..., None],
            axis=1,
        )
    )

    prot["atom14_atom_exists"] = residx_atom14_mask
    prot["atom14_gt_exists"] = residx_atom14_gt_mask
    prot["atom14_gt_positions"] = residx_atom14_gt_positions

    prot["residx_atom14_to_atom37"] = residx_atom14_to_atom37.astype(np.int64)

    # Create the gather indices for mapping back.
    residx_atom37_to_atom14 = restype_atom37_to_atom14[prot["aatype"]]
    prot["residx_atom37_to_atom14"] = residx_atom37_to_atom14.astype(np.int64)

    # Create the corresponding mask.
    restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
    for restype, restype_letter in enumerate(residue_constants.restypes):
        restype_name = residue_constants.restype_1to3[restype_letter]
        atom_names = residue_constants.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = residue_constants.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[prot["aatype"]]
    prot["atom37_atom_exists"] = residx_atom37_mask

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [
        residue_constants.restype_1to3[res]
        for res in residue_constants.restypes
    ]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms.
    all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3}
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        correspondences = np.arange(14)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(source_atom_swap)
            target_index = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Turn the index correspondence into a 14x14 permutation matrix.
        renaming_matrix = np.zeros((14, 14), dtype=np.float32)
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix.astype(np.float32)
    renaming_matrices = np.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[prot["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = np.einsum(
        "rac,rab->rbc", residx_atom14_gt_positions, renaming_transform
    )
    prot["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = np.einsum(
        "ra,rab->rb", residx_atom14_gt_mask, renaming_transform
    )

    prot["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32)
    for resname, swap in residue_constants.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = residue_constants.restype_order[
                residue_constants.restype_3to1[resname]
            ]
            atom_idx1 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name1)
            atom_idx2 = residue_constants.restype_name_to_atom14_names[
                resname
            ].index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    prot["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        prot["aatype"]
    ]

    return prot
357
+
358
+
359
def find_violations(prot_np: protein.Protein):
    """Analyzes a protein and returns structural violation information.

    Args:
        prot_np: A protein.

    Returns:
        violations: A `dict` of structure components with structural violations.
        violation_metrics: A `dict` of violation metrics.
    """
    # Build the minimal feature batch the loss functions expect.
    batch = {
        "aatype": prot_np.aatype,
        "all_atom_positions": prot_np.atom_positions.astype(np.float32),
        "all_atom_mask": prot_np.atom_mask.astype(np.float32),
        "residue_index": prot_np.residue_index,
    }

    batch["seq_mask"] = np.ones_like(batch["aatype"], np.float32)
    # Adds atom14 features (incl. atom14_gt_positions used below).
    batch = make_atom14_positions(batch)

    violations = loss.find_structural_violations_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        config=ml_collections.ConfigDict(
            {
                "violation_tolerance_factor": 12,  # Taken from model config.
                "clash_overlap_tolerance": 1.5,  # Taken from model config.
            }
        ),
    )
    violation_metrics = loss.compute_violation_metrics_np(
        batch=batch,
        atom14_pred_positions=batch["atom14_gt_positions"],
        violations=violations,
    )

    return violations, violation_metrics
396
+
397
+
398
def get_violation_metrics(prot: protein.Protein):
    """Computes violation and alignment metrics."""
    violations, metrics = find_violations(prot)
    # Indices of residues that participate in at least one violation.
    offending = np.flatnonzero(
        violations["total_per_residue_violations_mask"]
    )
    metrics["residue_violations"] = offending
    metrics["num_residue_violations"] = len(offending)
    metrics["structural_violations"] = violations
    return metrics
409
+
410
+
411
def _run_one_iteration(
    *,
    pdb_string: str,
    max_iterations: int,
    tolerance: float,
    stiffness: float,
    restraint_set: str,
    max_attempts: int,
    exclude_residues: Optional[Collection[int]] = None,
    use_gpu: bool,
):
    """Runs the minimization pipeline.

    Args:
        pdb_string: A pdb string.
        max_iterations: An `int` specifying the maximum number of L-BFGS iterations.
            A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts.
        exclude_residues: An optional list of zero-indexed residues to exclude from
            restraints.
        use_gpu: Whether to run relaxation on GPU

    Returns:
        A `dict` of minimization info.

    Raises:
        ValueError: If minimization fails `max_attempts` times; chained to
            the last underlying OpenMM error for diagnostics.
    """
    exclude_residues = exclude_residues or []

    # Assign physical dimensions.
    tolerance = tolerance * ENERGY
    stiffness = stiffness * ENERGY / (LENGTH ** 2)

    start = time.perf_counter()
    minimized = False
    attempts = 0
    last_exception = None
    while not minimized and attempts < max_attempts:
        attempts += 1
        try:
            logging.info(
                "Minimizing protein, attempt %d of %d.", attempts, max_attempts
            )
            ret = _openmm_minimize(
                pdb_string,
                max_iterations=max_iterations,
                tolerance=tolerance,
                stiffness=stiffness,
                restraint_set=restraint_set,
                exclude_residues=exclude_residues,
                use_gpu=use_gpu,
            )
            minimized = True
        except Exception as e:  # pylint: disable=broad-except
            # Minimization can fail transiently (e.g. numerical issues in
            # OpenMM); remember the error and retry up to max_attempts.
            # (A stray debug print(e) was removed here — logging suffices.)
            last_exception = e
            logging.info(e)
    if not minimized:
        # Chain the last failure so callers see the root cause.
        raise ValueError(
            f"Minimization failed after {max_attempts} attempts."
        ) from last_exception
    ret["opt_time"] = time.perf_counter() - start
    ret["min_attempts"] = attempts
    return ret
472
+
473
+
474
def run_pipeline(
    prot: protein.Protein,
    stiffness: float,
    use_gpu: bool,
    max_outer_iterations: int = 1,
    place_hydrogens_every_iteration: bool = True,
    max_iterations: int = 0,
    tolerance: float = 2.39,
    restraint_set: str = "non_hydrogen",
    max_attempts: int = 100,
    checks: bool = True,
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Run iterative amber relax.

    Successive relax iterations are performed until all violations have been
    resolved. Each iteration involves a restrained Amber minimization, with
    restraint exclusions determined by violation-participating residues.

    Args:
        prot: A protein to be relaxed.
        stiffness: kcal/mol A**2, the restraint stiffness.
        use_gpu: Whether to run on GPU
        max_outer_iterations: The maximum number of iterative minimization.
        place_hydrogens_every_iteration: Whether hydrogens are re-initialized
            prior to every minimization.
        max_iterations: An `int` specifying the maximum number of L-BFGS steps
            per relax iteration. A value of 0 specifies no limit.
        tolerance: kcal/mol, the energy tolerance of L-BFGS.
            The default value is the OpenMM default.
        restraint_set: The set of atoms to restrain.
        max_attempts: The maximum number of minimization attempts per iteration.
        checks: Whether to perform cleaning checks.
        exclude_residues: An optional list of zero-indexed residues to exclude from
            restraints.

    Returns:
        out: A dictionary of output values.
    """

    # `protein.to_pdb` will strip any poorly-defined residues so we need to
    # perform this check before `clean_protein`.
    _check_residues_are_well_defined(prot)
    pdb_string = clean_protein(prot, checks=checks)

    # We keep the input around to restore metadata deleted by the relaxer
    input_prot = prot

    exclude_residues = exclude_residues or []
    exclude_residues = set(exclude_residues)
    # Start with "infinitely many" violations so the loop runs at least once.
    violations = np.inf
    iteration = 0

    while violations > 0 and iteration < max_outer_iterations:
        ret = _run_one_iteration(
            pdb_string=pdb_string,
            exclude_residues=exclude_residues,
            max_iterations=max_iterations,
            tolerance=tolerance,
            stiffness=stiffness,
            restraint_set=restraint_set,
            max_attempts=max_attempts,
            use_gpu=use_gpu,
        )

        # Re-attach header lines derived from the (original) protein, which
        # the minimizer's PDB round-trip does not preserve.
        headers = protein.get_pdb_headers(prot)
        if(len(headers) > 0):
            ret["min_pdb"] = '\n'.join(['\n'.join(headers), ret["min_pdb"]])

        prot = protein.from_pdb_string(ret["min_pdb"])
        if place_hydrogens_every_iteration:
            pdb_string = clean_protein(prot, checks=True)
        else:
            pdb_string = ret["min_pdb"]
        ret.update(get_violation_metrics(prot))
        ret.update(
            {
                "num_exclusions": len(exclude_residues),
                "iteration": iteration,
            }
        )
        violations = ret["violations_per_residue"]
        # Residues that violated this round are excluded from restraints in
        # the next round, letting the minimizer move them freely.
        exclude_residues = exclude_residues.union(ret["residue_violations"])

        logging.info(
            "Iteration completed: Einit %.2f Efinal %.2f Time %.2f s "
            "num residue violations %d num residue exclusions %d ",
            ret["einit"],
            ret["efinal"],
            ret["opt_time"],
            ret["num_residue_violations"],
            ret["num_exclusions"],
        )
        iteration += 1
    return ret
569
+
570
+
571
def get_initial_energies(
    pdb_strs: Sequence[str],
    stiffness: float = 0.0,
    restraint_set: str = "non_hydrogen",
    exclude_residues: Optional[Sequence[int]] = None,
):
    """Returns initial potential energies for a sequence of PDBs.

    Assumes the input PDBs are ready for minimization, and all have the same
    topology.
    Allows time to be saved by not pdbfixing / rebuilding the system.

    Args:
        pdb_strs: List of PDB strings.
        stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
        restraint_set: Which atom types to restrain.
        exclude_residues: An optional list of zero-indexed residues to exclude from
            restraints.

    Returns:
        A list of initial energies in the same order as pdb_strs.
    """
    exclude_residues = exclude_residues or []

    openmm_pdbs = [
        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
    ]
    # Build the force field / system / simulation once from the first
    # structure; all inputs are assumed to share its topology.
    force_field = openmm_app.ForceField("amber99sb.xml")
    system = force_field.createSystem(
        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
    )
    stiffness = stiffness * ENERGY / (LENGTH ** 2)
    if stiffness > 0 * ENERGY / (LENGTH ** 2):
        _add_restraints(
            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
        )
    simulation = openmm_app.Simulation(
        openmm_pdbs[0].topology,
        system,
        openmm.LangevinIntegrator(0, 0.01, 0.0),
        openmm.Platform.getPlatformByName("CPU"),
    )
    energies = []
    for pdb in openmm_pdbs:
        try:
            # Only positions change between inputs; reuse the same context.
            simulation.context.setPositions(pdb.positions)
            state = simulation.context.getState(getEnergy=True)
            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
        except Exception as e:  # pylint: disable=broad-except
            # Best-effort: a failed evaluation contributes a sentinel large
            # value instead of aborting the whole batch.
            logging.error(
                "Error getting initial energy, returning large value %s", e
            )
            energies.append(unit.Quantity(1e20, ENERGY))
    return energies
openfold/np/relax/cleanup.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
16
+
17
+ fix_pdb uses a third-party tool. We also support fixing some additional edge
18
+ cases like removing chains of length one (see clean_structure).
19
+ """
20
+ import io
21
+
22
+ import pdbfixer
23
+ from simtk.openmm import app
24
+ from simtk.openmm.app import element
25
+
26
+
27
def fix_pdb(pdbfile, alterations_info):
    """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

    1) Replaces nonstandard residues.
    2) Removes heterogens (non protein residues) including water.
    3) Adds missing residues and missing atoms within existing residues.
    4) Adds hydrogens assuming pH=7.0.
    5) KeepIds is currently true, so the fixer must keep the existing chain and
    residue identifiers. This will fail for some files in wider PDB that have
    invalid IDs.

    Args:
        pdbfile: Input PDB file handle.
        alterations_info: A dict that will store details of changes made.

    Returns:
        A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()
    _remove_heterogens(fixer, alterations_info, keep_water=False)
    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    # NOTE(review): seed=0 presumably makes placement of added atoms
    # deterministic — confirm against the pdbfixer API docs.
    fixer.addMissingAtoms(seed=0)
    fixer.addMissingHydrogens()
    out_handle = io.StringIO()
    # keepIds=True preserves the original chain/residue identifiers.
    app.PDBFile.writeFile(
        fixer.topology, fixer.positions, out_handle, keepIds=True
    )
    return out_handle.getvalue()
62
+
63
+
64
def clean_structure(pdb_structure, alterations_info):
    """Applies additional fixes to an OpenMM structure, to handle edge cases.

    Args:
        pdb_structure: An OpenMM structure to modify and fix.
        alterations_info: A dict that will store details of changes made.
    """
    # Apply each structure-level fix in order; every fix records what it
    # changed into alterations_info.
    for fix in (_replace_met_se, _remove_chains_of_length_one):
        fix(pdb_structure, alterations_info)
73
+
74
+
75
+ def _remove_heterogens(fixer, alterations_info, keep_water):
76
+ """Removes the residues that Pdbfixer considers to be heterogens.
77
+
78
+ Args:
79
+ fixer: A Pdbfixer instance.
80
+ alterations_info: A dict that will store details of changes made.
81
+ keep_water: If True, water (HOH) is not considered to be a heterogen.
82
+ """
83
+ initial_resnames = set()
84
+ for chain in fixer.topology.chains():
85
+ for residue in chain.residues():
86
+ initial_resnames.add(residue.name)
87
+ fixer.removeHeterogens(keepWater=keep_water)
88
+ final_resnames = set()
89
+ for chain in fixer.topology.chains():
90
+ for residue in chain.residues():
91
+ final_resnames.add(residue.name)
92
+ alterations_info["removed_heterogens"] = initial_resnames.difference(
93
+ final_resnames
94
+ )
95
+
96
+
97
+ def _replace_met_se(pdb_structure, alterations_info):
98
+ """Replace the Se in any MET residues that were not marked as modified."""
99
+ modified_met_residues = []
100
+ for res in pdb_structure.iter_residues():
101
+ name = res.get_name_with_spaces().strip()
102
+ if name == "MET":
103
+ s_atom = res.get_atom("SD")
104
+ if s_atom.element_symbol == "Se":
105
+ s_atom.element_symbol = "S"
106
+ s_atom.element = element.get_by_symbol("S")
107
+ modified_met_residues.append(s_atom.residue_number)
108
+ alterations_info["Se_in_MET"] = modified_met_residues
109
+
110
+
111
+ def _remove_chains_of_length_one(pdb_structure, alterations_info):
112
+ """Removes chains that correspond to a single amino acid.
113
+
114
+ A single amino acid in a chain is both N and C terminus. There is no force
115
+ template for this case.
116
+
117
+ Args:
118
+ pdb_structure: An OpenMM pdb_structure to modify and fix.
119
+ alterations_info: A dict that will store details of changes made.
120
+ """
121
+ removed_chains = {}
122
+ for model in pdb_structure.iter_models():
123
+ valid_chains = [c for c in model.iter_chains() if len(c) > 1]
124
+ invalid_chain_ids = [
125
+ c.chain_id for c in model.iter_chains() if len(c) <= 1
126
+ ]
127
+ model.chains = valid_chains
128
+ for chain_id in invalid_chain_ids:
129
+ model.chains_by_id.pop(chain_id)
130
+ removed_chains[model.number] = invalid_chain_ids
131
+ alterations_info["removed_chains"] = removed_chains
openfold/np/relax/relax.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Amber relaxation."""
17
+ from typing import Any, Dict, Sequence, Tuple
18
+ from openfold.np import protein
19
+ from openfold.np.relax import amber_minimize, utils
20
+ import numpy as np
21
+
22
+
23
class AmberRelaxation(object):
    """Amber relaxation."""
    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Initialize Amber Relaxer.

        Args:
            max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
            tolerance: kcal/mol, the energy tolerance of L-BFGS.
            stiffness: kcal/mol A**2, spring constant of heavy atom restraining
                potential.
            exclude_residues: Residues to exclude from per-atom restraining.
                Zero-indexed.
            max_outer_iterations: Maximum number of violation-informed relax
                iterations. A value of 1 will run the non-iterative procedure used in
                CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
                as soon as there are no violations, hence in most cases this causes no
                slowdown. In the worst case we do 20 outer iterations.
            use_gpu: Whether to run on GPU
        """

        # Settings are stored verbatim and forwarded to
        # amber_minimize.run_pipeline in process().
        self._max_iterations = max_iterations
        self._tolerance = tolerance
        self._stiffness = stiffness
        self._exclude_residues = exclude_residues
        self._max_outer_iterations = max_outer_iterations
        self._use_gpu = use_gpu

    def process(
        self, *, prot: protein.Protein
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.

        Returns:
            Tuple of (relaxed PDB string, debug-info dict, per-residue
            violation mask from the final iteration).
        """
        out = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
            use_gpu=self._use_gpu,
        )
        min_pos = out["pos"]
        start_pos = out["posinit"]
        # RMSD between pre- and post-minimization coordinates.
        rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0])
        debug_data = {
            "initial_energy": out["einit"],
            "final_energy": out["efinal"],
            "attempts": out["min_attempts"],
            "rmsd": rmsd,
        }
        # Rebuild a clean PDB of the input, then splice in the minimized
        # coordinates and restore the original per-residue B-factors.
        pdb_str = amber_minimize.clean_protein(prot)
        min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
        min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
        # Sanity check: relaxation must not add/remove non-terminal atoms.
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask
        )
        violations = out["structural_violations"][
            "total_per_residue_violations_mask"
        ]

        min_pdb = protein.add_pdb_headers(prot, min_pdb)

        return min_pdb, debug_data, violations
openfold/np/relax/utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Utils for minimization."""
17
+ import io
18
+ from openfold.np import residue_constants
19
+ from Bio import PDB
20
+ import numpy as np
21
+ from simtk.openmm import app as openmm_app
22
+ from simtk.openmm.app.internal.pdbstructure import PdbStructure
23
+
24
+
25
def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
    """Return a new PDB string with atom coordinates replaced by `pos`.

    The topology (atoms, residues, chains) is taken from `pdb_str`; only the
    coordinates written out come from `pos`.
    """
    structure = PdbStructure(io.StringIO(pdb_str))
    topology = openmm_app.PDBFile(structure).getTopology()
    buf = io.StringIO()
    try:
        openmm_app.PDBFile.writeFile(topology, pos, buf)
        return buf.getvalue()
    finally:
        buf.close()
32
+
33
+
34
def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
    """Overwrites the B-factors in pdb_str with contents of bfactors array.

    Args:
        pdb_str: An input PDB string.
        bfactors: A numpy array with shape [n_residues, 37] (one row per
            residue, one column per atom type). We assume the B-factors are per
            residue; i.e. that the nonzero entries are identical in [i, :] —
            only the CA column is actually read.

    Returns:
        A new PDB string with the B-factors replaced.

    Raises:
        ValueError: If the final dimension of `bfactors` is not
            residue_constants.atom_type_num, or if the PDB contains more
            residues than `bfactors` has rows.
    """
    if bfactors.shape[-1] != residue_constants.atom_type_num:
        raise ValueError(
            f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}."
        )

    parser = PDB.PDBParser(QUIET=True)
    handle = io.StringIO(pdb_str)
    structure = parser.get_structure("", handle)

    curr_resid = ("", "", "")
    idx = -1
    for atom in structure.get_atoms():
        atom_resid = atom.parent.get_id()
        if atom_resid != curr_resid:
            idx += 1
            if idx >= bfactors.shape[0]:
                # BUG FIX: this message was a plain string containing literal
                # "{shape}"/"{idx}" placeholders (missing f-prefix); format the
                # actual values instead.
                raise ValueError(
                    "Index into bfactors exceeds number of residues. "
                    f"B-factors shape: {bfactors.shape}, idx: {idx}."
                )
            curr_resid = atom_resid
        # Every atom of a residue gets that residue's CA-column B-factor.
        atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]]

    new_pdb = io.StringIO()
    pdb_io = PDB.PDBIO()
    pdb_io.set_structure(structure)
    pdb_io.save(new_pdb)
    return new_pdb.getvalue()
74
+
75
+
76
def assert_equal_nonterminal_atom_types(
    atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
    """Checks that pre- and post-minimized proteins have same atom set.

    Args:
        atom_mask: [..., atom_type_num] atom-presence mask of the minimized
            protein.
        ref_atom_mask: Matching mask of the reference (pre-minimization)
            protein.

    Raises:
        AssertionError: If the masks differ anywhere outside the OXT column.
    """
    # Ignore any terminal OXT atoms which may have been added by minimization.
    oxt = residue_constants.atom_order["OXT"]
    # BUG FIX: np.bool (deprecated in NumPy 1.20, removed in 1.24) replaced
    # with the builtin bool, which is the supported dtype spelling.
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
    no_oxt_mask[..., oxt] = False
    np.testing.assert_almost_equal(
        ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
    )
openfold/np/residue_constants.py ADDED
@@ -0,0 +1,1310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Constants used in AlphaFold."""
17
+
18
+ import collections
19
+ import functools
20
+ from typing import Mapping, List, Tuple
21
+ from importlib import resources
22
+
23
+ import numpy as np
24
+ import tree
25
+
26
+ # Internal import (35fd).
27
+
28
+
29
+ # Distance from one CA to next CA [trans configuration: omega = 180].
30
+ ca_ca = 3.80209737096
31
+
32
+ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
33
+ # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
34
+ # chi angles so their chi angle lists are empty.
35
+ chi_angles_atoms = {
36
+ "ALA": [],
37
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
38
+ "ARG": [
39
+ ["N", "CA", "CB", "CG"],
40
+ ["CA", "CB", "CG", "CD"],
41
+ ["CB", "CG", "CD", "NE"],
42
+ ["CG", "CD", "NE", "CZ"],
43
+ ],
44
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
45
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
46
+ "CYS": [["N", "CA", "CB", "SG"]],
47
+ "GLN": [
48
+ ["N", "CA", "CB", "CG"],
49
+ ["CA", "CB", "CG", "CD"],
50
+ ["CB", "CG", "CD", "OE1"],
51
+ ],
52
+ "GLU": [
53
+ ["N", "CA", "CB", "CG"],
54
+ ["CA", "CB", "CG", "CD"],
55
+ ["CB", "CG", "CD", "OE1"],
56
+ ],
57
+ "GLY": [],
58
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
59
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
60
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
61
+ "LYS": [
62
+ ["N", "CA", "CB", "CG"],
63
+ ["CA", "CB", "CG", "CD"],
64
+ ["CB", "CG", "CD", "CE"],
65
+ ["CG", "CD", "CE", "NZ"],
66
+ ],
67
+ "MET": [
68
+ ["N", "CA", "CB", "CG"],
69
+ ["CA", "CB", "CG", "SD"],
70
+ ["CB", "CG", "SD", "CE"],
71
+ ],
72
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
73
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
74
+ "SER": [["N", "CA", "CB", "OG"]],
75
+ "THR": [["N", "CA", "CB", "OG1"]],
76
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
77
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
78
+ "VAL": [["N", "CA", "CB", "CG1"]],
79
+ }
80
+
81
+ # If chi angles given in fixed-length array, this matrix determines how to mask
82
+ # them for each AA type. The order is as per restype_order (see below).
83
+ chi_angles_mask = [
84
+ [0.0, 0.0, 0.0, 0.0], # ALA
85
+ [1.0, 1.0, 1.0, 1.0], # ARG
86
+ [1.0, 1.0, 0.0, 0.0], # ASN
87
+ [1.0, 1.0, 0.0, 0.0], # ASP
88
+ [1.0, 0.0, 0.0, 0.0], # CYS
89
+ [1.0, 1.0, 1.0, 0.0], # GLN
90
+ [1.0, 1.0, 1.0, 0.0], # GLU
91
+ [0.0, 0.0, 0.0, 0.0], # GLY
92
+ [1.0, 1.0, 0.0, 0.0], # HIS
93
+ [1.0, 1.0, 0.0, 0.0], # ILE
94
+ [1.0, 1.0, 0.0, 0.0], # LEU
95
+ [1.0, 1.0, 1.0, 1.0], # LYS
96
+ [1.0, 1.0, 1.0, 0.0], # MET
97
+ [1.0, 1.0, 0.0, 0.0], # PHE
98
+ [1.0, 1.0, 0.0, 0.0], # PRO
99
+ [1.0, 0.0, 0.0, 0.0], # SER
100
+ [1.0, 0.0, 0.0, 0.0], # THR
101
+ [1.0, 1.0, 0.0, 0.0], # TRP
102
+ [1.0, 1.0, 0.0, 0.0], # TYR
103
+ [1.0, 0.0, 0.0, 0.0], # VAL
104
+ ]
105
+
106
+ # The following chi angles are pi periodic: they can be rotated by a multiple
107
+ # of pi without affecting the structure.
108
+ chi_pi_periodic = [
109
+ [0.0, 0.0, 0.0, 0.0], # ALA
110
+ [0.0, 0.0, 0.0, 0.0], # ARG
111
+ [0.0, 0.0, 0.0, 0.0], # ASN
112
+ [0.0, 1.0, 0.0, 0.0], # ASP
113
+ [0.0, 0.0, 0.0, 0.0], # CYS
114
+ [0.0, 0.0, 0.0, 0.0], # GLN
115
+ [0.0, 0.0, 1.0, 0.0], # GLU
116
+ [0.0, 0.0, 0.0, 0.0], # GLY
117
+ [0.0, 0.0, 0.0, 0.0], # HIS
118
+ [0.0, 0.0, 0.0, 0.0], # ILE
119
+ [0.0, 0.0, 0.0, 0.0], # LEU
120
+ [0.0, 0.0, 0.0, 0.0], # LYS
121
+ [0.0, 0.0, 0.0, 0.0], # MET
122
+ [0.0, 1.0, 0.0, 0.0], # PHE
123
+ [0.0, 0.0, 0.0, 0.0], # PRO
124
+ [0.0, 0.0, 0.0, 0.0], # SER
125
+ [0.0, 0.0, 0.0, 0.0], # THR
126
+ [0.0, 0.0, 0.0, 0.0], # TRP
127
+ [0.0, 1.0, 0.0, 0.0], # TYR
128
+ [0.0, 0.0, 0.0, 0.0], # VAL
129
+ [0.0, 0.0, 0.0, 0.0], # UNK
130
+ ]
131
+
132
+ # Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
133
+ # psi and chi angles:
134
+ # 0: 'backbone group',
135
+ # 1: 'pre-omega-group', (empty)
136
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
137
+ # 3: 'psi-group',
138
+ # 4,5,6,7: 'chi1,2,3,4-group'
139
+ # The atom positions are relative to the axis-end-atom of the corresponding
140
+ # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
141
+ # is defined such that the dihedral-angle-definiting atom (the last entry in
142
+ # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
143
+ # format: [atomname, group_idx, rel_position]
144
+ rigid_group_atom_positions = {
145
+ "ALA": [
146
+ ["N", 0, (-0.525, 1.363, 0.000)],
147
+ ["CA", 0, (0.000, 0.000, 0.000)],
148
+ ["C", 0, (1.526, -0.000, -0.000)],
149
+ ["CB", 0, (-0.529, -0.774, -1.205)],
150
+ ["O", 3, (0.627, 1.062, 0.000)],
151
+ ],
152
+ "ARG": [
153
+ ["N", 0, (-0.524, 1.362, -0.000)],
154
+ ["CA", 0, (0.000, 0.000, 0.000)],
155
+ ["C", 0, (1.525, -0.000, -0.000)],
156
+ ["CB", 0, (-0.524, -0.778, -1.209)],
157
+ ["O", 3, (0.626, 1.062, 0.000)],
158
+ ["CG", 4, (0.616, 1.390, -0.000)],
159
+ ["CD", 5, (0.564, 1.414, 0.000)],
160
+ ["NE", 6, (0.539, 1.357, -0.000)],
161
+ ["NH1", 7, (0.206, 2.301, 0.000)],
162
+ ["NH2", 7, (2.078, 0.978, -0.000)],
163
+ ["CZ", 7, (0.758, 1.093, -0.000)],
164
+ ],
165
+ "ASN": [
166
+ ["N", 0, (-0.536, 1.357, 0.000)],
167
+ ["CA", 0, (0.000, 0.000, 0.000)],
168
+ ["C", 0, (1.526, -0.000, -0.000)],
169
+ ["CB", 0, (-0.531, -0.787, -1.200)],
170
+ ["O", 3, (0.625, 1.062, 0.000)],
171
+ ["CG", 4, (0.584, 1.399, 0.000)],
172
+ ["ND2", 5, (0.593, -1.188, 0.001)],
173
+ ["OD1", 5, (0.633, 1.059, 0.000)],
174
+ ],
175
+ "ASP": [
176
+ ["N", 0, (-0.525, 1.362, -0.000)],
177
+ ["CA", 0, (0.000, 0.000, 0.000)],
178
+ ["C", 0, (1.527, 0.000, -0.000)],
179
+ ["CB", 0, (-0.526, -0.778, -1.208)],
180
+ ["O", 3, (0.626, 1.062, -0.000)],
181
+ ["CG", 4, (0.593, 1.398, -0.000)],
182
+ ["OD1", 5, (0.610, 1.091, 0.000)],
183
+ ["OD2", 5, (0.592, -1.101, -0.003)],
184
+ ],
185
+ "CYS": [
186
+ ["N", 0, (-0.522, 1.362, -0.000)],
187
+ ["CA", 0, (0.000, 0.000, 0.000)],
188
+ ["C", 0, (1.524, 0.000, 0.000)],
189
+ ["CB", 0, (-0.519, -0.773, -1.212)],
190
+ ["O", 3, (0.625, 1.062, -0.000)],
191
+ ["SG", 4, (0.728, 1.653, 0.000)],
192
+ ],
193
+ "GLN": [
194
+ ["N", 0, (-0.526, 1.361, -0.000)],
195
+ ["CA", 0, (0.000, 0.000, 0.000)],
196
+ ["C", 0, (1.526, 0.000, 0.000)],
197
+ ["CB", 0, (-0.525, -0.779, -1.207)],
198
+ ["O", 3, (0.626, 1.062, -0.000)],
199
+ ["CG", 4, (0.615, 1.393, 0.000)],
200
+ ["CD", 5, (0.587, 1.399, -0.000)],
201
+ ["NE2", 6, (0.593, -1.189, -0.001)],
202
+ ["OE1", 6, (0.634, 1.060, 0.000)],
203
+ ],
204
+ "GLU": [
205
+ ["N", 0, (-0.528, 1.361, 0.000)],
206
+ ["CA", 0, (0.000, 0.000, 0.000)],
207
+ ["C", 0, (1.526, -0.000, -0.000)],
208
+ ["CB", 0, (-0.526, -0.781, -1.207)],
209
+ ["O", 3, (0.626, 1.062, 0.000)],
210
+ ["CG", 4, (0.615, 1.392, 0.000)],
211
+ ["CD", 5, (0.600, 1.397, 0.000)],
212
+ ["OE1", 6, (0.607, 1.095, -0.000)],
213
+ ["OE2", 6, (0.589, -1.104, -0.001)],
214
+ ],
215
+ "GLY": [
216
+ ["N", 0, (-0.572, 1.337, 0.000)],
217
+ ["CA", 0, (0.000, 0.000, 0.000)],
218
+ ["C", 0, (1.517, -0.000, -0.000)],
219
+ ["O", 3, (0.626, 1.062, -0.000)],
220
+ ],
221
+ "HIS": [
222
+ ["N", 0, (-0.527, 1.360, 0.000)],
223
+ ["CA", 0, (0.000, 0.000, 0.000)],
224
+ ["C", 0, (1.525, 0.000, 0.000)],
225
+ ["CB", 0, (-0.525, -0.778, -1.208)],
226
+ ["O", 3, (0.625, 1.063, 0.000)],
227
+ ["CG", 4, (0.600, 1.370, -0.000)],
228
+ ["CD2", 5, (0.889, -1.021, 0.003)],
229
+ ["ND1", 5, (0.744, 1.160, -0.000)],
230
+ ["CE1", 5, (2.030, 0.851, 0.002)],
231
+ ["NE2", 5, (2.145, -0.466, 0.004)],
232
+ ],
233
+ "ILE": [
234
+ ["N", 0, (-0.493, 1.373, -0.000)],
235
+ ["CA", 0, (0.000, 0.000, 0.000)],
236
+ ["C", 0, (1.527, -0.000, -0.000)],
237
+ ["CB", 0, (-0.536, -0.793, -1.213)],
238
+ ["O", 3, (0.627, 1.062, -0.000)],
239
+ ["CG1", 4, (0.534, 1.437, -0.000)],
240
+ ["CG2", 4, (0.540, -0.785, -1.199)],
241
+ ["CD1", 5, (0.619, 1.391, 0.000)],
242
+ ],
243
+ "LEU": [
244
+ ["N", 0, (-0.520, 1.363, 0.000)],
245
+ ["CA", 0, (0.000, 0.000, 0.000)],
246
+ ["C", 0, (1.525, -0.000, -0.000)],
247
+ ["CB", 0, (-0.522, -0.773, -1.214)],
248
+ ["O", 3, (0.625, 1.063, -0.000)],
249
+ ["CG", 4, (0.678, 1.371, 0.000)],
250
+ ["CD1", 5, (0.530, 1.430, -0.000)],
251
+ ["CD2", 5, (0.535, -0.774, 1.200)],
252
+ ],
253
+ "LYS": [
254
+ ["N", 0, (-0.526, 1.362, -0.000)],
255
+ ["CA", 0, (0.000, 0.000, 0.000)],
256
+ ["C", 0, (1.526, 0.000, 0.000)],
257
+ ["CB", 0, (-0.524, -0.778, -1.208)],
258
+ ["O", 3, (0.626, 1.062, -0.000)],
259
+ ["CG", 4, (0.619, 1.390, 0.000)],
260
+ ["CD", 5, (0.559, 1.417, 0.000)],
261
+ ["CE", 6, (0.560, 1.416, 0.000)],
262
+ ["NZ", 7, (0.554, 1.387, 0.000)],
263
+ ],
264
+ "MET": [
265
+ ["N", 0, (-0.521, 1.364, -0.000)],
266
+ ["CA", 0, (0.000, 0.000, 0.000)],
267
+ ["C", 0, (1.525, 0.000, 0.000)],
268
+ ["CB", 0, (-0.523, -0.776, -1.210)],
269
+ ["O", 3, (0.625, 1.062, -0.000)],
270
+ ["CG", 4, (0.613, 1.391, -0.000)],
271
+ ["SD", 5, (0.703, 1.695, 0.000)],
272
+ ["CE", 6, (0.320, 1.786, -0.000)],
273
+ ],
274
+ "PHE": [
275
+ ["N", 0, (-0.518, 1.363, 0.000)],
276
+ ["CA", 0, (0.000, 0.000, 0.000)],
277
+ ["C", 0, (1.524, 0.000, -0.000)],
278
+ ["CB", 0, (-0.525, -0.776, -1.212)],
279
+ ["O", 3, (0.626, 1.062, -0.000)],
280
+ ["CG", 4, (0.607, 1.377, 0.000)],
281
+ ["CD1", 5, (0.709, 1.195, -0.000)],
282
+ ["CD2", 5, (0.706, -1.196, 0.000)],
283
+ ["CE1", 5, (2.102, 1.198, -0.000)],
284
+ ["CE2", 5, (2.098, -1.201, -0.000)],
285
+ ["CZ", 5, (2.794, -0.003, -0.001)],
286
+ ],
287
+ "PRO": [
288
+ ["N", 0, (-0.566, 1.351, -0.000)],
289
+ ["CA", 0, (0.000, 0.000, 0.000)],
290
+ ["C", 0, (1.527, -0.000, 0.000)],
291
+ ["CB", 0, (-0.546, -0.611, -1.293)],
292
+ ["O", 3, (0.621, 1.066, 0.000)],
293
+ ["CG", 4, (0.382, 1.445, 0.0)],
294
+ # ['CD', 5, (0.427, 1.440, 0.0)],
295
+ ["CD", 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
296
+ ],
297
+ "SER": [
298
+ ["N", 0, (-0.529, 1.360, -0.000)],
299
+ ["CA", 0, (0.000, 0.000, 0.000)],
300
+ ["C", 0, (1.525, -0.000, -0.000)],
301
+ ["CB", 0, (-0.518, -0.777, -1.211)],
302
+ ["O", 3, (0.626, 1.062, -0.000)],
303
+ ["OG", 4, (0.503, 1.325, 0.000)],
304
+ ],
305
+ "THR": [
306
+ ["N", 0, (-0.517, 1.364, 0.000)],
307
+ ["CA", 0, (0.000, 0.000, 0.000)],
308
+ ["C", 0, (1.526, 0.000, -0.000)],
309
+ ["CB", 0, (-0.516, -0.793, -1.215)],
310
+ ["O", 3, (0.626, 1.062, 0.000)],
311
+ ["CG2", 4, (0.550, -0.718, -1.228)],
312
+ ["OG1", 4, (0.472, 1.353, 0.000)],
313
+ ],
314
+ "TRP": [
315
+ ["N", 0, (-0.521, 1.363, 0.000)],
316
+ ["CA", 0, (0.000, 0.000, 0.000)],
317
+ ["C", 0, (1.525, -0.000, 0.000)],
318
+ ["CB", 0, (-0.523, -0.776, -1.212)],
319
+ ["O", 3, (0.627, 1.062, 0.000)],
320
+ ["CG", 4, (0.609, 1.370, -0.000)],
321
+ ["CD1", 5, (0.824, 1.091, 0.000)],
322
+ ["CD2", 5, (0.854, -1.148, -0.005)],
323
+ ["CE2", 5, (2.186, -0.678, -0.007)],
324
+ ["CE3", 5, (0.622, -2.530, -0.007)],
325
+ ["NE1", 5, (2.140, 0.690, -0.004)],
326
+ ["CH2", 5, (3.028, -2.890, -0.013)],
327
+ ["CZ2", 5, (3.283, -1.543, -0.011)],
328
+ ["CZ3", 5, (1.715, -3.389, -0.011)],
329
+ ],
330
+ "TYR": [
331
+ ["N", 0, (-0.522, 1.362, 0.000)],
332
+ ["CA", 0, (0.000, 0.000, 0.000)],
333
+ ["C", 0, (1.524, -0.000, -0.000)],
334
+ ["CB", 0, (-0.522, -0.776, -1.213)],
335
+ ["O", 3, (0.627, 1.062, -0.000)],
336
+ ["CG", 4, (0.607, 1.382, -0.000)],
337
+ ["CD1", 5, (0.716, 1.195, -0.000)],
338
+ ["CD2", 5, (0.713, -1.194, -0.001)],
339
+ ["CE1", 5, (2.107, 1.200, -0.002)],
340
+ ["CE2", 5, (2.104, -1.201, -0.003)],
341
+ ["OH", 5, (4.168, -0.002, -0.005)],
342
+ ["CZ", 5, (2.791, -0.001, -0.003)],
343
+ ],
344
+ "VAL": [
345
+ ["N", 0, (-0.494, 1.373, -0.000)],
346
+ ["CA", 0, (0.000, 0.000, 0.000)],
347
+ ["C", 0, (1.527, -0.000, -0.000)],
348
+ ["CB", 0, (-0.533, -0.795, -1.213)],
349
+ ["O", 3, (0.627, 1.062, -0.000)],
350
+ ["CG1", 4, (0.540, 1.429, -0.000)],
351
+ ["CG2", 4, (0.533, -0.776, 1.203)],
352
+ ],
353
+ }
354
+
355
+ # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
356
+ residue_atoms = {
357
+ "ALA": ["C", "CA", "CB", "N", "O"],
358
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
359
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
360
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
361
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
362
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
363
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
364
+ "GLY": ["C", "CA", "N", "O"],
365
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
366
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
367
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
368
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
369
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
370
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
371
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
372
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
373
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
374
+ "TRP": [
375
+ "C",
376
+ "CA",
377
+ "CB",
378
+ "CG",
379
+ "CD1",
380
+ "CD2",
381
+ "CE2",
382
+ "CE3",
383
+ "CZ2",
384
+ "CZ3",
385
+ "CH2",
386
+ "N",
387
+ "NE1",
388
+ "O",
389
+ ],
390
+ "TYR": [
391
+ "C",
392
+ "CA",
393
+ "CB",
394
+ "CG",
395
+ "CD1",
396
+ "CD2",
397
+ "CE1",
398
+ "CE2",
399
+ "CZ",
400
+ "N",
401
+ "O",
402
+ "OH",
403
+ ],
404
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
405
+ }
406
+
407
+ # Naming swaps for ambiguous atom names.
408
+ # Due to symmetries in the amino acids the naming of atoms is ambiguous in
409
+ # 4 of the 20 amino acids.
410
+ # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
411
+ # in LEU, VAL and ARG can be resolved by using the 3d constellations of
412
+ # the 'ambiguous' atoms and their neighbours)
413
+ # TODO: ^ interpret this
414
+ residue_atom_renaming_swaps = {
415
+ "ASP": {"OD1": "OD2"},
416
+ "GLU": {"OE1": "OE2"},
417
+ "PHE": {"CD1": "CD2", "CE1": "CE2"},
418
+ "TYR": {"CD1": "CD2", "CE1": "CE2"},
419
+ }
420
+
421
+ # Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
422
+ van_der_waals_radius = {
423
+ "C": 1.7,
424
+ "N": 1.55,
425
+ "O": 1.52,
426
+ "S": 1.8,
427
+ }
428
+
429
# A literature bond length between two named atoms, with its standard
# deviation (Angstroem). Also reused for "virtual bonds" derived from angles.
Bond = collections.namedtuple(
    "Bond", ["atom1_name", "atom2_name", "length", "stddev"]
)
# A literature bond angle over three named atoms, in radians.
# NOTE(review): "atom3name" (no underscore) is inconsistent with the other
# field names, but load_stereo_chemical_props accesses it by this exact name,
# so renaming it would be a breaking change.
BondAngle = collections.namedtuple(
    "BondAngle",
    ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
)
436
+
437
+
438
@functools.lru_cache(maxsize=None)  # parsed once; result is reused thereafter
def load_stereo_chemical_props() -> Tuple[
    Mapping[str, List[Bond]],
    Mapping[str, List[Bond]],
    Mapping[str, List[BondAngle]],
]:
    """Load stereo_chemical_props.txt into a nice structure.

    Load literature values for bond lengths and bond angles and translate
    bond angles into the length of the opposite edge of the triangle
    ("residue_virtual_bonds").

    Returns:
        residue_bonds: dict that maps resname --> list of Bond tuples
        residue_virtual_bonds: dict that maps resname --> list of Bond tuples
        residue_bond_angles: dict that maps resname --> list of BondAngle tuples
    """
    # TODO: this file should be downloaded in a setup script
    # NOTE(review): importlib.resources.read_text is deprecated since Python
    # 3.11 in favor of resources.files(...).read_text() — consider migrating.
    stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")

    # The file has two whitespace-delimited sections, each terminated by a
    # line containing only "-": bond lengths first, then bond angles.
    lines_iter = iter(stereo_chemical_props.splitlines())
    # Load bond lengths.
    residue_bonds = {}
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        # Each row: "ATOM1-ATOM2 RESNAME LENGTH STDDEV".
        bond, resname, length, stddev = line.split()
        atom1, atom2 = bond.split("-")
        if resname not in residue_bonds:
            residue_bonds[resname] = []
        residue_bonds[resname].append(
            Bond(atom1, atom2, float(length), float(stddev))
        )
    residue_bonds["UNK"] = []

    # Load bond angles.
    residue_bond_angles = {}
    next(lines_iter)  # Skip empty line.
    next(lines_iter)  # Skip header line.
    for line in lines_iter:
        if line.strip() == "-":
            break
        # Each row: "ATOM1-ATOM2-ATOM3 RESNAME ANGLE_DEG STDDEV_DEG".
        bond, resname, angle_degree, stddev_degree = line.split()
        atom1, atom2, atom3 = bond.split("-")
        if resname not in residue_bond_angles:
            residue_bond_angles[resname] = []
        residue_bond_angles[resname].append(
            BondAngle(
                atom1,
                atom2,
                atom3,
                float(angle_degree) / 180.0 * np.pi,
                float(stddev_degree) / 180.0 * np.pi,
            )
        )
    residue_bond_angles["UNK"] = []

    def make_bond_key(atom1_name, atom2_name):
        """Unique key to lookup bonds."""
        # Sorted so that A-B and B-A map to the same key.
        return "-".join(sorted([atom1_name, atom2_name]))

    # Translate bond angles into distances ("virtual bonds").
    residue_virtual_bonds = {}
    for resname, bond_angles in residue_bond_angles.items():
        # Create a fast lookup dict for bond lengths.
        bond_cache = {}
        for b in residue_bonds[resname]:
            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
        residue_virtual_bonds[resname] = []
        for ba in bond_angles:
            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]

            # Compute distance between atom1 and atom3 using the law of cosines
            # c^2 = a^2 + b^2 - 2ab*cos(gamma).
            gamma = ba.angle_rad
            length = np.sqrt(
                bond1.length ** 2
                + bond2.length ** 2
                - 2 * bond1.length * bond2.length * np.cos(gamma)
            )

            # Propagation of uncertainty assuming uncorrelated errors:
            # partial derivatives of `length` w.r.t. gamma and the two bond
            # lengths, combined in quadrature with their stddevs.
            dl_outer = 0.5 / length
            dl_dgamma = (
                2 * bond1.length * bond2.length * np.sin(gamma)
            ) * dl_outer
            dl_db1 = (
                2 * bond1.length - 2 * bond2.length * np.cos(gamma)
            ) * dl_outer
            dl_db2 = (
                2 * bond2.length - 2 * bond1.length * np.cos(gamma)
            ) * dl_outer
            stddev = np.sqrt(
                (dl_dgamma * ba.stddev) ** 2
                + (dl_db1 * bond1.stddev) ** 2
                + (dl_db2 * bond2.stddev) ** 2
            )
            residue_virtual_bonds[resname].append(
                Bond(ba.atom1_name, ba.atom3name, length, stddev)
            )

    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
542
+
543
+
544
+ # Between-residue bond lengths for general bonds (first element) and for Proline
545
+ # (second element).
546
+ between_res_bond_length_c_n = [1.329, 1.341]
547
+ between_res_bond_length_stddev_c_n = [0.014, 0.016]
548
+
549
+ # Between-residue cos_angles.
550
+ between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
551
+ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
552
+
553
+ # This mapping is used when we need to store atom data in a format that requires
554
+ # fixed atom data size for every residue (e.g. a numpy array).
555
+ atom_types = [
556
+ "N",
557
+ "CA",
558
+ "C",
559
+ "CB",
560
+ "O",
561
+ "CG",
562
+ "CG1",
563
+ "CG2",
564
+ "OG",
565
+ "OG1",
566
+ "SG",
567
+ "CD",
568
+ "CD1",
569
+ "CD2",
570
+ "ND1",
571
+ "ND2",
572
+ "OD1",
573
+ "OD2",
574
+ "SD",
575
+ "CE",
576
+ "CE1",
577
+ "CE2",
578
+ "CE3",
579
+ "NE",
580
+ "NE1",
581
+ "NE2",
582
+ "OE1",
583
+ "OE2",
584
+ "CH2",
585
+ "NH1",
586
+ "NH2",
587
+ "OH",
588
+ "CZ",
589
+ "CZ2",
590
+ "CZ3",
591
+ "NZ",
592
+ "OXT",
593
+ ]
594
+ atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
595
+ atom_type_num = len(atom_types) # := 37.
596
+
597
+ # A compact atom encoding with 14 columns
598
+ # pylint: disable=line-too-long
599
+ # pylint: disable=bad-whitespace
600
+ restype_name_to_atom14_names = {
601
+ "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
602
+ "ARG": [
603
+ "N",
604
+ "CA",
605
+ "C",
606
+ "O",
607
+ "CB",
608
+ "CG",
609
+ "CD",
610
+ "NE",
611
+ "CZ",
612
+ "NH1",
613
+ "NH2",
614
+ "",
615
+ "",
616
+ "",
617
+ ],
618
+ "ASN": [
619
+ "N",
620
+ "CA",
621
+ "C",
622
+ "O",
623
+ "CB",
624
+ "CG",
625
+ "OD1",
626
+ "ND2",
627
+ "",
628
+ "",
629
+ "",
630
+ "",
631
+ "",
632
+ "",
633
+ ],
634
+ "ASP": [
635
+ "N",
636
+ "CA",
637
+ "C",
638
+ "O",
639
+ "CB",
640
+ "CG",
641
+ "OD1",
642
+ "OD2",
643
+ "",
644
+ "",
645
+ "",
646
+ "",
647
+ "",
648
+ "",
649
+ ],
650
+ "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
651
+ "GLN": [
652
+ "N",
653
+ "CA",
654
+ "C",
655
+ "O",
656
+ "CB",
657
+ "CG",
658
+ "CD",
659
+ "OE1",
660
+ "NE2",
661
+ "",
662
+ "",
663
+ "",
664
+ "",
665
+ "",
666
+ ],
667
+ "GLU": [
668
+ "N",
669
+ "CA",
670
+ "C",
671
+ "O",
672
+ "CB",
673
+ "CG",
674
+ "CD",
675
+ "OE1",
676
+ "OE2",
677
+ "",
678
+ "",
679
+ "",
680
+ "",
681
+ "",
682
+ ],
683
+ "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
684
+ "HIS": [
685
+ "N",
686
+ "CA",
687
+ "C",
688
+ "O",
689
+ "CB",
690
+ "CG",
691
+ "ND1",
692
+ "CD2",
693
+ "CE1",
694
+ "NE2",
695
+ "",
696
+ "",
697
+ "",
698
+ "",
699
+ ],
700
+ "ILE": [
701
+ "N",
702
+ "CA",
703
+ "C",
704
+ "O",
705
+ "CB",
706
+ "CG1",
707
+ "CG2",
708
+ "CD1",
709
+ "",
710
+ "",
711
+ "",
712
+ "",
713
+ "",
714
+ "",
715
+ ],
716
+ "LEU": [
717
+ "N",
718
+ "CA",
719
+ "C",
720
+ "O",
721
+ "CB",
722
+ "CG",
723
+ "CD1",
724
+ "CD2",
725
+ "",
726
+ "",
727
+ "",
728
+ "",
729
+ "",
730
+ "",
731
+ ],
732
+ "LYS": [
733
+ "N",
734
+ "CA",
735
+ "C",
736
+ "O",
737
+ "CB",
738
+ "CG",
739
+ "CD",
740
+ "CE",
741
+ "NZ",
742
+ "",
743
+ "",
744
+ "",
745
+ "",
746
+ "",
747
+ ],
748
+ "MET": [
749
+ "N",
750
+ "CA",
751
+ "C",
752
+ "O",
753
+ "CB",
754
+ "CG",
755
+ "SD",
756
+ "CE",
757
+ "",
758
+ "",
759
+ "",
760
+ "",
761
+ "",
762
+ "",
763
+ ],
764
+ "PHE": [
765
+ "N",
766
+ "CA",
767
+ "C",
768
+ "O",
769
+ "CB",
770
+ "CG",
771
+ "CD1",
772
+ "CD2",
773
+ "CE1",
774
+ "CE2",
775
+ "CZ",
776
+ "",
777
+ "",
778
+ "",
779
+ ],
780
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
781
+ "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
782
+ "THR": [
783
+ "N",
784
+ "CA",
785
+ "C",
786
+ "O",
787
+ "CB",
788
+ "OG1",
789
+ "CG2",
790
+ "",
791
+ "",
792
+ "",
793
+ "",
794
+ "",
795
+ "",
796
+ "",
797
+ ],
798
+ "TRP": [
799
+ "N",
800
+ "CA",
801
+ "C",
802
+ "O",
803
+ "CB",
804
+ "CG",
805
+ "CD1",
806
+ "CD2",
807
+ "CE2",
808
+ "CE3",
809
+ "NE1",
810
+ "CZ2",
811
+ "CZ3",
812
+ "CH2",
813
+ ],
814
+ "TYR": [
815
+ "N",
816
+ "CA",
817
+ "C",
818
+ "O",
819
+ "CB",
820
+ "CG",
821
+ "CD1",
822
+ "CD2",
823
+ "CE1",
824
+ "CE2",
825
+ "CZ",
826
+ "OH",
827
+ "",
828
+ "",
829
+ ],
830
+ "VAL": [
831
+ "N",
832
+ "CA",
833
+ "C",
834
+ "O",
835
+ "CB",
836
+ "CG1",
837
+ "CG2",
838
+ "",
839
+ "",
840
+ "",
841
+ "",
842
+ "",
843
+ "",
844
+ "",
845
+ ],
846
+ "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
847
+ }
848
+ # pylint: enable=line-too-long
849
+ # pylint: enable=bad-whitespace
850
+
851
+
852
+ # This is the standard residue order when coding AA type as a number.
853
+ # Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
854
+ restypes = [
855
+ "A",
856
+ "R",
857
+ "N",
858
+ "D",
859
+ "C",
860
+ "Q",
861
+ "E",
862
+ "G",
863
+ "H",
864
+ "I",
865
+ "L",
866
+ "K",
867
+ "M",
868
+ "F",
869
+ "P",
870
+ "S",
871
+ "T",
872
+ "W",
873
+ "Y",
874
+ "V",
875
+ ]
876
+ restype_order = {restype: i for i, restype in enumerate(restypes)}
877
+ restype_num = len(restypes) # := 20.
878
+ unk_restype_index = restype_num # Catch-all index for unknown restypes.
879
+
880
+ restypes_with_x = restypes + ["X"]
881
+ restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
882
+
883
+
884
def sequence_to_onehot(
    sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
        sequence: An amino acid sequence.
        mapping: A dictionary mapping amino acids to integers.
        map_unknown_to_x: If True, any uppercase letter absent from the mapping
            is encoded as the unknown amino acid 'X' (which must then be a key
            of the mapping); non-letter/lowercase characters raise. If False,
            any amino acid not in the mapping raises a KeyError.

    Returns:
        A numpy int32 array of shape (seq_len, num_unique_aas) with a one-hot
        encoding of the sequence.

    Raises:
        ValueError: If the mapping's values are not exactly
            0..num_unique_aas - 1, or (with map_unknown_to_x) the sequence
            contains a non-uppercase-letter character.
    """
    num_classes = max(mapping.values()) + 1
    if sorted(set(mapping.values())) != list(range(num_classes)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    encoding = np.zeros((len(sequence), num_classes), dtype=np.int32)
    for row, aa_type in enumerate(sequence):
        if not map_unknown_to_x:
            # Strict mode: unmapped residues raise KeyError.
            col = mapping[aa_type]
        elif aa_type.isalpha() and aa_type.isupper():
            # Fall back to the catch-all 'X' class for unmapped residues.
            col = mapping.get(aa_type, mapping["X"])
        else:
            raise ValueError(
                f"Invalid character in the sequence: {aa_type}"
            )
        encoding[row, col] = 1

    return encoding
928
+
929
+
930
# One-letter -> three-letter residue-name table for the 20 canonical amino
# acids, in the same order as `restypes`.
restype_1to3 = dict(
    zip(
        "ARNDCQEGHILKMFPSTWYV",
        (
            "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
    )
)
952
+
953
+
954
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 in that it is a
# strict 1-to-1 inverse of restype_1to3. The Biopython table also contains many
# rarer three-letter codes and maps several of them onto the same one-letter
# code (including 'X' and 'U', which are not used here).
restype_3to1 = {three: one for one, three in restype_1to3.items()}

# Residue name used for every unknown residue type.
unk_restype = "UNK"

# Three-letter names in `restypes` order, with UNK appended at index 20.
resnames = [restype_1to3[letter] for letter in restypes] + [unk_restype]
resname_to_idx = {name: idx for idx, name in enumerate(resnames)}
965
+
966
+
967
# hhblits-convention letter -> id mapping: B maps to D, J and O map to X,
# U maps to C, and Z maps to E; the remaining 20 amino acids keep their
# alphabetical order. Two non-amino-acid codes exist — X (any amino acid)
# and "-" (a gap in an alignment) — and are placed at the end (ids 20 and
# 21) so they are easy to ignore when desired.
HHBLITS_AA_TO_ID = {
    "A": 0, "B": 2, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5,
    "H": 6, "I": 7, "J": 20, "K": 8, "L": 9, "M": 10, "N": 11,
    "O": 20, "P": 12, "Q": 13, "R": 14, "S": 15, "T": 16,
    "U": 1, "V": 17, "W": 18, "X": 20, "Y": 19, "Z": 3,
    "-": 21,
}
1003
+
1004
# Partial inversion of HHBLITS_AA_TO_ID. Ambiguous ids resolve to a single
# representative letter: 1 covers U as well as C, 2 covers B as well as D,
# 3 covers Z as well as E, and 20 ('X') also covers J and O.
ID_TO_HHBLITS_AA = {hh_id: aa for hh_id, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
ID_TO_HHBLITS_AA[20] = "X"  # Includes J and O.
ID_TO_HHBLITS_AA[21] = "-"
1029
+
1030
# Alphabet extended with the catch-all 'X' and the alignment gap '-'.
restypes_with_x_and_gap = restypes + ["X", "-"]
# For each hhblits id, the index of the corresponding letter in our alphabet.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[hh_id])
    for hh_id in range(len(restypes_with_x_and_gap))
)
1035
+
1036
+
1037
def _make_standard_atom_mask() -> np.ndarray:
    """Build the [num_res_types, num_atom_types] standard-atom mask.

    Entry [r, a] is 1 when residue type r contains atom type a. The extra
    final row (index restype_num) is the all-zero unknown residue.
    """
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    for res_idx, one_letter in enumerate(restypes):
        for atom_name in residue_atoms[restype_1to3[one_letter]]:
            mask[res_idx, atom_order[atom_name]] = 1
    return mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()
1051
+
1052
+
1053
# A one-hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    Returns an array of shape [restype_num + 1, atom_type_num, 4]; the last
    entry along axis 0 is all zeros for residue `X`. Chi groups a residue
    lacks are padded with index -1 (i.e. the last atom-type row).
    """
    chi_indices_by_resname = {}
    for resname, chi_groups in chi_angles_atoms.items():
        indices = [atom_types.index(group[atom_index]) for group in chi_groups]
        indices.extend([-1] * (4 - len(indices)))
        chi_indices_by_resname[resname] = indices

    one_hots = [
        np.eye(atom_type_num)[chi_indices_by_resname[restype_1to3[letter]]]
        for letter in restypes
    ]
    one_hots.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.

    stacked = np.stack(one_hots, axis=0)
    return np.transpose(stacked, [0, 2, 1])
1075
+
1076
+
1077
chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)

# Same content as chi_angles_atoms, but with atom indices instead of names
# and every residue padded out to 4 chi groups with [0, 0, 0, 0].
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[letter]] for letter in restypes]
chi_angles_atom_indices = tree.map_structure(
    lambda atom_name: atom_order[atom_name], chi_angles_atom_indices
)
chi_angles_atom_indices = np.array(
    [
        chis + [[0, 0, 0, 0]] * (4 - len(chis))
        for chis in chi_angles_atom_indices
    ]
)

# Maps each (res_name, atom_name) pair to the list of (chi group index,
# atom index within that group) positions where the atom appears.
chi_groups_for_atom = collections.defaultdict(list)
for resname, chi_groups in chi_angles_atoms.items():
    for group_idx, group in enumerate(chi_groups):
        for pos_in_group, atom_name in enumerate(group):
            chi_groups_for_atom[(resname, atom_name)].append((group_idx, pos_in_group))
chi_groups_for_atom = dict(chi_groups_for_atom)
1100
+
1101
+
1102
+ def _make_rigid_transformation_4x4(ex, ey, translation):
1103
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
1104
+ # Normalize ex.
1105
+ ex_normalized = ex / np.linalg.norm(ex)
1106
+
1107
+ # make ey perpendicular to ex
1108
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
1109
+ ey_normalized /= np.linalg.norm(ey_normalized)
1110
+
1111
+ # compute ez as cross product
1112
+ eznorm = np.cross(ex_normalized, ey_normalized)
1113
+ m = np.stack(
1114
+ [ex_normalized, ey_normalized, eznorm, translation]
1115
+ ).transpose()
1116
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
1117
+ return m
1118
+
1119
+
1120
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
#
# NOTE: the original code used `dtype=np.int`, an alias for the builtin `int`
# that was deprecated in NumPy 1.20 and removed in NumPy 1.24 — it raises
# AttributeError on modern NumPy. Using `int` directly is the exact
# behavior-preserving replacement.
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
1131
+
1132
+
1133
def _make_rigid_group_constants():
    """Fill the module-level rigid-group lookup tables declared above.

    Populates, in place:
      - restype_atom37_to_rigid_group / restype_atom14_to_rigid_group
      - restype_atom37_mask / restype_atom14_mask
      - restype_atom37_rigid_group_positions /
        restype_atom14_rigid_group_positions
      - restype_rigid_group_default_frame

    All data comes from `rigid_group_atom_positions`, `chi_angles_atoms` and
    `chi_angles_mask`; row 20 (UNK) is left as zeros.
    """
    # First pass: per-atom group indices, existence masks, and literature
    # positions in both the atom37 and atom14 layouts.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[
            resname
        ]:
            atomtype = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_rigid_group_positions[
                restype, atomtype, :
            ] = atom_position

            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_rigid_group_positions[
                restype, atom14idx, :
            ] = atom_position

    # Second pass: default 4x4 frames for each of the 8 rigid groups.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions = {
            name: np.array(pos)
            for name, _, pos in rigid_group_atom_positions[resname]
        }

        # backbone to backbone is the identity transform
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=atom_positions["N"],
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],
            ey=atom_positions["CA"] - atom_positions["N"],
            translation=atom_positions["C"],
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone (only for residues that have a chi1 angle)
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [
                atom_positions[name] for name in base_atom_names
            ]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                # The third atom of the chi group is the far end of the
                # rotation axis, expressed in the previous chi frame.
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                restype_rigid_group_default_frame[
                    restype, 4 + chi_idx, :, :
                ] = mat


_make_rigid_group_constants()
1216
+
1217
+
1218
def make_atom14_dists_bounds(
    overlap_tolerance=1.5, bond_length_tolerance_factor=15
):
    """Compute per-residue lower/upper distance bounds between atom14 atoms.

    Args:
        overlap_tolerance: Allowed van der Waals overlap (Angstrom) before a
            non-bonded pair counts as a clash.
        bond_length_tolerance_factor: Number of standard deviations around the
            literature bond length that is still considered valid.

    Returns:
        Dict with "lower_bound", "upper_bound" and "stddev" arrays, each of
        shape (21, 14, 14).
    """
    lower_bounds = np.zeros([21, 14, 14], np.float32)
    upper_bounds = np.zeros([21, 14, 14], np.float32)
    stddevs = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()

    for res_idx, one_letter in enumerate(restypes):
        resname = restype_1to3[one_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # Clash bounds: any two atoms must stay at least the sum of their
        # van der Waals radii (minus the tolerance) apart; no upper limit.
        for i, name_i in enumerate(atom_list):
            if not name_i:
                continue
            radius_i = van_der_waals_radius[name_i[0]]
            for j, name_j in enumerate(atom_list):
                if (not name_j) or i == j:
                    continue
                radius_j = van_der_waals_radius[name_j[0]]
                clash_lower = radius_i + radius_j - overlap_tolerance
                lower_bounds[res_idx, i, j] = clash_lower
                lower_bounds[res_idx, j, i] = clash_lower
                upper_bounds[res_idx, i, j] = 1e10
                upper_bounds[res_idx, j, i] = 1e10

        # Bonded (and angle-constrained "virtual bond") pairs overwrite the
        # clash bounds with a tight, symmetric window around the bond length.
        for bond in residue_bonds[resname] + residue_virtual_bonds[resname]:
            i = atom_list.index(bond.atom1_name)
            j = atom_list.index(bond.atom2_name)
            slack = bond_length_tolerance_factor * bond.stddev
            lower_bounds[res_idx, i, j] = bond.length - slack
            lower_bounds[res_idx, j, i] = bond.length - slack
            upper_bounds[res_idx, i, j] = bond.length + slack
            upper_bounds[res_idx, j, i] = bond.length + slack
            stddevs[res_idx, i, j] = bond.stddev
            stddevs[res_idx, j, i] = bond.stddev

    return {
        "lower_bound": lower_bounds,  # shape (21,14,14)
        "upper_bound": upper_bounds,  # shape (21,14,14)
        "stddev": stddevs,  # shape (21,14,14)
    }
1279
+
1280
+
1281
# Per-(restype, atom14 slot) flag for symmetry-ambiguous atoms, and the index
# each slot swaps with; both are filled in by _make_atom14_ambiguity_feats().
# By default no atom is ambiguous and every slot swaps with itself.
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
# NOTE: `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
# alias it referred to, so this is behavior-preserving.
restype_atom14_ambiguous_atoms_swap_idx = np.tile(
    np.arange(14, dtype=int), (21, 1)
)
1285
+
1286
+
1287
def _make_atom14_ambiguity_feats():
    """Fill the ambiguity tables above from `residue_atom_renaming_swaps`.

    For every residue with swappable atom-name pairs, marks both atoms as
    ambiguous in `restype_atom14_ambiguous_atoms` and records each atom's
    swap partner in `restype_atom14_ambiguous_atoms_swap_idx` (both
    directions). Mutates the two module-level arrays in place.
    """
    for res, pairs in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[res]]
        for atom1, atom2 in pairs.items():
            # Positions of the pair within the residue's atom14 layout.
            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
            restype_atom14_ambiguous_atoms_swap_idx[
                res_idx, atom1_idx
            ] = atom2_idx
            restype_atom14_ambiguous_atoms_swap_idx[
                res_idx, atom2_idx
            ] = atom1_idx


_make_atom14_ambiguity_feats()
1304
+
1305
+
1306
def aatype_to_str_sequence(aatype):
    """Convert a sequence of residue-type indices into a one-letter string."""
    return ''.join(restypes_with_x[idx] for idx in aatype)
openfold/resources/__init__.py ADDED
File without changes
openfold/utils/feats.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ from typing import Dict
22
+
23
+ from openfold.np import protein
24
+ import openfold.np.residue_constants as rc
25
+ from openfold.utils.rigid_utils import Rotation, Rigid
26
+ from openfold.utils.tensor_utils import (
27
+ batched_gather,
28
+ one_hot,
29
+ tree_map,
30
+ tensor_tree_map,
31
+ )
32
+
33
+
34
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
35
+ is_gly = aatype == rc.restype_order["G"]
36
+ ca_idx = rc.atom_order["CA"]
37
+ cb_idx = rc.atom_order["CB"]
38
+ pseudo_beta = torch.where(
39
+ is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
40
+ all_atom_positions[..., ca_idx, :],
41
+ all_atom_positions[..., cb_idx, :],
42
+ )
43
+
44
+ if all_atom_masks is not None:
45
+ pseudo_beta_mask = torch.where(
46
+ is_gly,
47
+ all_atom_masks[..., ca_idx],
48
+ all_atom_masks[..., cb_idx],
49
+ )
50
+ return pseudo_beta, pseudo_beta_mask
51
+ else:
52
+ return pseudo_beta
53
+
54
+
55
+ def atom14_to_atom37(atom14, batch):
56
+ atom37_data = batched_gather(
57
+ atom14,
58
+ batch["residx_atom37_to_atom14"],
59
+ dim=-2,
60
+ no_batch_dims=len(atom14.shape[:-2]),
61
+ )
62
+
63
+ atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
64
+
65
+ return atom37_data
66
+
67
+
68
+ def build_template_angle_feat(template_feats):
69
+ template_aatype = template_feats["template_aatype"]
70
+ torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
71
+ alt_torsion_angles_sin_cos = template_feats[
72
+ "template_alt_torsion_angles_sin_cos"
73
+ ]
74
+ torsion_angles_mask = template_feats["template_torsion_angles_mask"]
75
+ template_angle_feat = torch.cat(
76
+ [
77
+ nn.functional.one_hot(template_aatype, 22),
78
+ torsion_angles_sin_cos.reshape(
79
+ *torsion_angles_sin_cos.shape[:-2], 14
80
+ ),
81
+ alt_torsion_angles_sin_cos.reshape(
82
+ *alt_torsion_angles_sin_cos.shape[:-2], 14
83
+ ),
84
+ torsion_angles_mask,
85
+ ],
86
+ dim=-1,
87
+ )
88
+
89
+ return template_angle_feat
90
+
91
+
92
+ def build_template_pair_feat(
93
+ batch,
94
+ min_bin, max_bin, no_bins,
95
+ use_unit_vector=False,
96
+ eps=1e-20, inf=1e8
97
+ ):
98
+ template_mask = batch["template_pseudo_beta_mask"]
99
+ template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
100
+
101
+ # Compute distogram (this seems to differ slightly from Alg. 5)
102
+ tpb = batch["template_pseudo_beta"]
103
+ dgram = torch.sum(
104
+ (tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True
105
+ )
106
+ lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
107
+ upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
108
+ dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
109
+
110
+ to_concat = [dgram, template_mask_2d[..., None]]
111
+
112
+ aatype_one_hot = nn.functional.one_hot(
113
+ batch["template_aatype"],
114
+ rc.restype_num + 2,
115
+ )
116
+
117
+ n_res = batch["template_aatype"].shape[-1]
118
+ to_concat.append(
119
+ aatype_one_hot[..., None, :, :].expand(
120
+ *aatype_one_hot.shape[:-2], n_res, -1, -1
121
+ )
122
+ )
123
+ to_concat.append(
124
+ aatype_one_hot[..., None, :].expand(
125
+ *aatype_one_hot.shape[:-2], -1, n_res, -1
126
+ )
127
+ )
128
+
129
+ n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
130
+ rigids = Rigid.make_transform_from_reference(
131
+ n_xyz=batch["template_all_atom_positions"][..., n, :],
132
+ ca_xyz=batch["template_all_atom_positions"][..., ca, :],
133
+ c_xyz=batch["template_all_atom_positions"][..., c, :],
134
+ eps=eps,
135
+ )
136
+ points = rigids.get_trans()[..., None, :, :]
137
+ rigid_vec = rigids[..., None].invert_apply(points)
138
+
139
+ inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec ** 2, dim=-1))
140
+
141
+ t_aa_masks = batch["template_all_atom_mask"]
142
+ template_mask = (
143
+ t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
144
+ )
145
+ template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
146
+
147
+ inv_distance_scalar = inv_distance_scalar * template_mask_2d
148
+ unit_vector = rigid_vec * inv_distance_scalar[..., None]
149
+
150
+ if(not use_unit_vector):
151
+ unit_vector = unit_vector * 0.
152
+
153
+ to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
154
+ to_concat.append(template_mask_2d[..., None])
155
+
156
+ act = torch.cat(to_concat, dim=-1)
157
+ act = act * template_mask_2d[..., None]
158
+
159
+ return act
160
+
161
+
162
+ def build_extra_msa_feat(batch):
163
+ msa_1hot = nn.functional.one_hot(batch["extra_msa"], 23)
164
+ msa_feat = [
165
+ msa_1hot,
166
+ batch["extra_has_deletion"].unsqueeze(-1),
167
+ batch["extra_deletion_value"].unsqueeze(-1),
168
+ ]
169
+ return torch.cat(msa_feat, dim=-1)
170
+
171
+
172
+ def torsion_angles_to_frames(
173
+ r: Rigid,
174
+ alpha: torch.Tensor,
175
+ aatype: torch.Tensor,
176
+ rrgdf: torch.Tensor,
177
+ ):
178
+ # [*, N, 8, 4, 4]
179
+ default_4x4 = rrgdf[aatype, ...]
180
+
181
+ # [*, N, 8] transformations, i.e.
182
+ # One [*, N, 8, 3, 3] rotation matrix and
183
+ # One [*, N, 8, 3] translation matrix
184
+ default_r = r.from_tensor_4x4(default_4x4)
185
+
186
+ bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
187
+ bb_rot[..., 1] = 1
188
+
189
+ # [*, N, 8, 2]
190
+ alpha = torch.cat(
191
+ [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
192
+ )
193
+
194
+ # [*, N, 8, 3, 3]
195
+ # Produces rotation matrices of the form:
196
+ # [
197
+ # [1, 0 , 0 ],
198
+ # [0, a_2,-a_1],
199
+ # [0, a_1, a_2]
200
+ # ]
201
+ # This follows the original code rather than the supplement, which uses
202
+ # different indices.
203
+
204
+ all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
205
+ all_rots[..., 0, 0] = 1
206
+ all_rots[..., 1, 1] = alpha[..., 1]
207
+ all_rots[..., 1, 2] = -alpha[..., 0]
208
+ all_rots[..., 2, 1:] = alpha
209
+
210
+ all_rots = Rigid(Rotation(rot_mats=all_rots), None)
211
+
212
+ all_frames = default_r.compose(all_rots)
213
+
214
+ chi2_frame_to_frame = all_frames[..., 5]
215
+ chi3_frame_to_frame = all_frames[..., 6]
216
+ chi4_frame_to_frame = all_frames[..., 7]
217
+
218
+ chi1_frame_to_bb = all_frames[..., 4]
219
+ chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
220
+ chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
221
+ chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
222
+
223
+ all_frames_to_bb = Rigid.cat(
224
+ [
225
+ all_frames[..., :5],
226
+ chi2_frame_to_bb.unsqueeze(-1),
227
+ chi3_frame_to_bb.unsqueeze(-1),
228
+ chi4_frame_to_bb.unsqueeze(-1),
229
+ ],
230
+ dim=-1,
231
+ )
232
+
233
+ all_frames_to_global = r[..., None].compose(all_frames_to_bb)
234
+
235
+ return all_frames_to_global
236
+
237
+
238
+ def frames_and_literature_positions_to_atom14_pos(
239
+ r: Rigid,
240
+ aatype: torch.Tensor,
241
+ default_frames,
242
+ group_idx,
243
+ atom_mask,
244
+ lit_positions,
245
+ ):
246
+ # [*, N, 14, 4, 4]
247
+ default_4x4 = default_frames[aatype, ...]
248
+
249
+ # [*, N, 14]
250
+ group_mask = group_idx[aatype, ...]
251
+
252
+ # [*, N, 14, 8]
253
+ group_mask = nn.functional.one_hot(
254
+ group_mask,
255
+ num_classes=default_frames.shape[-3],
256
+ )
257
+
258
+ # [*, N, 14, 8]
259
+ t_atoms_to_global = r[..., None, :] * group_mask
260
+
261
+ # [*, N, 14]
262
+ t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
263
+ lambda x: torch.sum(x, dim=-1)
264
+ )
265
+
266
+ # [*, N, 14, 1]
267
+ atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
268
+
269
+ # [*, N, 14, 3]
270
+ lit_positions = lit_positions[aatype, ...]
271
+ pred_positions = t_atoms_to_global.apply(lit_positions)
272
+ pred_positions = pred_positions * atom_mask
273
+
274
+ return pred_positions
openfold/utils/loss.py ADDED
@@ -0,0 +1,1614 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ import logging
18
+ import ml_collections
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.distributions.bernoulli import Bernoulli
23
+ from typing import Dict, Optional, Tuple
24
+
25
+ from openfold.np import residue_constants
26
+ from openfold.utils import feats
27
+ from openfold.utils.rigid_utils import Rotation, Rigid
28
+ from openfold.utils.tensor_utils import (
29
+ tree_map,
30
+ tensor_tree_map,
31
+ masked_mean,
32
+ permute_final_dims,
33
+ batched_gather,
34
+ )
35
+
36
+
37
def softmax_cross_entropy(logits, labels):
    """Cross entropy between soft `labels` and `logits`, summed over the last dim."""
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    return -torch.sum(labels * log_probs, dim=-1)
43
+
44
+
45
def sigmoid_cross_entropy(logits, labels):
    """Elementwise binary cross entropy on raw logits.

    Computed in float64 for numerical stability, then cast back to the input
    dtype. Uses logsigmoid rather than log(sigmoid(x)) to avoid underflow.
    """
    out_dtype = logits.dtype
    logits64 = logits.double()
    labels64 = labels.double()
    log_p = torch.nn.functional.logsigmoid(logits64)
    log_not_p = torch.nn.functional.logsigmoid(-logits64)
    loss = -labels64 * log_p - (1.0 - labels64) * log_not_p
    return loss.to(dtype=out_dtype)
56
+
57
+
58
def torsion_angle_loss(
    a,  # [*, N, 7, 2]
    a_gt,  # [*, N, 7, 2]
    a_alt_gt,  # [*, N, 7, 2]
):
    """Torsion-angle loss: squared chord distance to the closer of the two
    ground truths (handling pi-symmetric angles), plus a small penalty that
    keeps the raw sin/cos predictions near unit norm.
    """
    # [*, N, 7]
    norm = torch.norm(a, dim=-1)

    # Project the raw predictions onto the unit circle before comparing.
    a_unit = a / norm.unsqueeze(-1)

    # [*, N, 7] — take the smaller squared distance of the two labelings.
    sq_diff_gt = torch.norm(a_unit - a_gt, dim=-1) ** 2
    sq_diff_alt_gt = torch.norm(a_unit - a_alt_gt, dim=-1) ** 2
    min_diff = torch.minimum(sq_diff_gt, sq_diff_alt_gt)

    # [*]
    l_torsion = torch.mean(min_diff, dim=(-1, -2))
    l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))

    an_weight = 0.02
    return l_torsion + an_weight * l_angle_norm
80
+
81
+
82
def compute_fape(
    pred_frames: Rigid,
    target_frames: Rigid,
    frames_mask: torch.Tensor,
    pred_positions: torch.Tensor,
    target_positions: torch.Tensor,
    positions_mask: torch.Tensor,
    length_scale: float,
    l1_clamp_distance: Optional[float] = None,
    eps=1e-8,
) -> torch.Tensor:
    """
    Computes FAPE loss.

    Args:
        pred_frames:
            [*, N_frames] Rigid object of predicted frames
        target_frames:
            [*, N_frames] Rigid object of ground truth frames
        frames_mask:
            [*, N_frames] binary mask for the frames
        pred_positions:
            [*, N_pts, 3] predicted atom positions
        target_positions:
            [*, N_pts, 3] ground truth positions
        positions_mask:
            [*, N_pts] positions mask
        length_scale:
            Length scale by which the loss is divided
        l1_clamp_distance:
            Cutoff above which distance errors are disregarded
        eps:
            Small value used to regularize denominators
    Returns:
        [*] loss tensor
    """
    # Express both point sets in every frame's local coordinates.
    # [*, N_frames, N_pts, 3]
    local_pred_pos = pred_frames.invert()[..., None].apply(
        pred_positions[..., None, :, :],
    )
    local_target_pos = target_frames.invert()[..., None].apply(
        target_positions[..., None, :, :],
    )

    error_dist = torch.sqrt(
        torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
    )

    if l1_clamp_distance is not None:
        error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)

    normed_error = error_dist / length_scale
    normed_error = normed_error * frames_mask[..., None]
    normed_error = normed_error * positions_mask[..., None, :]

    # FP16-friendly averaging. Roughly equivalent to:
    #
    # norm_factor = (
    #     torch.sum(frames_mask, dim=-1) *
    #     torch.sum(positions_mask, dim=-1)
    # )
    # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
    #
    # ("roughly" because eps is necessarily duplicated in the latter)
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = (
        normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    )
    normed_error = torch.sum(normed_error, dim=-1)
    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

    return normed_error
154
+
155
+
156
def backbone_loss(
    backbone_rigid_tensor: torch.Tensor,
    backbone_rigid_mask: torch.Tensor,
    traj: torch.Tensor,
    use_clamped_fape: Optional[torch.Tensor] = None,
    clamp_distance: float = 10.0,
    loss_unit_distance: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """Backbone FAPE loss over a trajectory of predicted frames.

    Args:
        backbone_rigid_tensor: Ground-truth backbone frames as 4x4 tensors.
        backbone_rigid_mask: Binary mask over the backbone frames.
        traj: Predicted frame trajectory in tensor_7 format; the leading
            dimension (trajectory steps) broadcasts against the [None]-expanded
            ground truth.
        use_clamped_fape: Optional blend weight between clamped and unclamped
            FAPE (1 = fully clamped).
        clamp_distance: l1 clamp cutoff passed to compute_fape.
        loss_unit_distance: Length scale dividing the error.
        eps: Denominator regularizer.

    Returns:
        Scalar loss averaged over the batch dimension.
    """
    pred_aff = Rigid.from_tensor_7(traj)
    # Rebuild from rotation matrices so downstream ops use the matrix form.
    pred_aff = Rigid(
        Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
        pred_aff.get_trans(),
    )

    # DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
    # backbone tensor, normalizes it, and then turns it back to a rotation
    # matrix. To avoid a potentially numerically unstable rotation matrix
    # to quaternion conversion, we just use the original rotation matrix
    # outright. This one hasn't been composed a bunch of times, though, so
    # it might be fine.
    gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    fape_loss = compute_fape(
        pred_aff,
        gt_aff[None],
        backbone_rigid_mask[None],
        pred_aff.get_trans(),
        gt_aff[None].get_trans(),
        backbone_rigid_mask[None],
        l1_clamp_distance=clamp_distance,
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if use_clamped_fape is not None:
        # Linear blend between the clamped and unclamped losses.
        unclamped_fape_loss = compute_fape(
            pred_aff,
            gt_aff[None],
            backbone_rigid_mask[None],
            pred_aff.get_trans(),
            gt_aff[None].get_trans(),
            backbone_rigid_mask[None],
            l1_clamp_distance=None,
            length_scale=loss_unit_distance,
            eps=eps,
        )

        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
            1 - use_clamped_fape
        )

    # Average over the batch dimension
    fape_loss = torch.mean(fape_loss)

    return fape_loss
212
+
213
+
214
def sidechain_loss(
    sidechain_frames: torch.Tensor,
    sidechain_atom_pos: torch.Tensor,
    rigidgroups_gt_frames: torch.Tensor,
    rigidgroups_alt_gt_frames: torch.Tensor,
    rigidgroups_gt_exists: torch.Tensor,
    renamed_atom14_gt_positions: torch.Tensor,
    renamed_atom14_gt_exists: torch.Tensor,
    alt_naming_is_better: torch.Tensor,
    clamp_distance: float = 10.0,
    length_scale: float = 10.0,
    eps: float = 1e-4,
    **kwargs,
) -> torch.Tensor:
    """Sidechain FAPE loss over rigid-group frames and atom14 positions.

    Only the final structure-module iteration (index -1 of the leading
    dimension of `sidechain_frames` / `sidechain_atom_pos`) is supervised.

    Returns:
        Per-example FAPE value (no batch averaging here).
    """
    # Select, per residue, the ground-truth frames under whichever atom
    # naming (original vs. alternative) matches the prediction better.
    alt = alt_naming_is_better[..., None, None, None]
    renamed_gt_frames = (
        (1.0 - alt) * rigidgroups_gt_frames + alt * rigidgroups_alt_gt_frames
    )

    # Steamroll the inputs: collapse residue and rigid-group/atom axes
    # into a single "frames"/"points" axis for compute_fape.
    last_frames = sidechain_frames[-1]
    batch_dims = last_frames.shape[:-4]
    pred_frames = Rigid.from_tensor_4x4(
        last_frames.view(*batch_dims, -1, 4, 4)
    )
    gt_frames = Rigid.from_tensor_4x4(
        renamed_gt_frames.view(*batch_dims, -1, 4, 4)
    )
    frames_mask = rigidgroups_gt_exists.reshape(*batch_dims, -1)
    pred_positions = sidechain_atom_pos[-1].view(*batch_dims, -1, 3)
    gt_positions = renamed_atom14_gt_positions.view(*batch_dims, -1, 3)
    positions_mask = renamed_atom14_gt_exists.view(*batch_dims, -1)

    return compute_fape(
        pred_frames,
        gt_frames,
        frames_mask,
        pred_positions,
        gt_positions,
        positions_mask,
        l1_clamp_distance=clamp_distance,
        length_scale=length_scale,
        eps=eps,
    )
262
+
263
+
264
def fape_loss(
    out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: ml_collections.ConfigDict,
) -> torch.Tensor:
    """Combined FAPE loss: weighted backbone + sidechain terms.

    Args:
        out: model output dict; reads out["sm"]["frames"],
            out["sm"]["sidechain_frames"], out["sm"]["positions"]
        batch: feature dict forwarded to the two sub-losses
        config: config with `backbone` and `sidechain` sections, each
            carrying a `weight` plus the sub-loss hyperparameters
    Returns:
        Scalar loss averaged over the batch dimension
    """
    sm = out["sm"]

    bb_term = backbone_loss(
        traj=sm["frames"],
        **{**batch, **config.backbone},
    )
    sc_term = sidechain_loss(
        sm["sidechain_frames"],
        sm["positions"],
        **{**batch, **config.sidechain},
    )

    total = config.backbone.weight * bb_term + config.sidechain.weight * sc_term

    # Average over the batch dimension
    return torch.mean(total)
286
+
287
+
288
def supervised_chi_loss(
    angles_sin_cos: torch.Tensor,
    unnormalized_angles_sin_cos: torch.Tensor,
    aatype: torch.Tensor,
    seq_mask: torch.Tensor,
    chi_mask: torch.Tensor,
    chi_angles_sin_cos: torch.Tensor,
    chi_weight: float,
    angle_norm_weight: float,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """
    Implements Algorithm 27 (torsionAngleLoss)

    Args:
        angles_sin_cos:
            [*, N, 7, 2] predicted angles
        unnormalized_angles_sin_cos:
            The same angles, but unnormalized
        aatype:
            [*, N] residue indices
        seq_mask:
            [*, N] sequence mask
        chi_mask:
            [*, N, 7] angle mask
        chi_angles_sin_cos:
            [*, N, 7, 2] ground truth angles
        chi_weight:
            Weight for the angle component of the loss
        angle_norm_weight:
            Weight for the normalization component of the loss
        eps:
            Small constant stabilizing the sqrt in the norm term
    Returns:
        [*] loss tensor
    """
    # The last 4 of the 7 torsions are the chi angles being supervised.
    pred_angles = angles_sin_cos[..., 3:, :]
    residue_type_one_hot = torch.nn.functional.one_hot(
        aatype,
        residue_constants.restype_num + 1,
    )
    # Per-residue lookup of which chi angles are pi-periodic.
    chi_pi_periodic = torch.einsum(
        "...ij,jk->ik",
        residue_type_one_hot.type(angles_sin_cos.dtype),
        angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
    )

    true_chi = chi_angles_sin_cos[None]

    # For pi-periodic chis the target shifted by pi (sin/cos negated) is
    # equally valid; score against both and keep the smaller error.
    shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
    true_chi_shifted = shifted_mask * true_chi
    sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
    sq_chi_error_shifted = torch.sum(
        (true_chi_shifted - pred_angles) ** 2, dim=-1
    )
    sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

    # The ol' switcheroo: move the leading dimension (added by the
    # true_chi[None] broadcast — presumably an iteration/block axis; TODO
    # confirm against the caller) behind the batch dims so masked_mean
    # can reduce over it together with the residue/angle dims.
    sq_chi_error = sq_chi_error.permute(
        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
    )

    sq_chi_loss = masked_mean(
        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
    )

    loss = chi_weight * sq_chi_loss

    # Penalize unnormalized sin/cos vectors that deviate from unit norm.
    angle_norm = torch.sqrt(
        torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
    )
    norm_error = torch.abs(angle_norm - 1.0)
    # Same leading-dimension shuffle as above before masked averaging.
    norm_error = norm_error.permute(
        *range(len(norm_error.shape))[1:-2], 0, -2, -1
    )
    angle_norm_loss = masked_mean(
        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
    )

    loss = loss + angle_norm_weight * angle_norm_loss

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
372
+
373
+
374
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
    """Convert per-residue lDDT-bin logits to pLDDT on a 0-100 scale.

    Args:
        logits: [*, no_bins] unnormalized lDDT-head outputs; bins are
            assumed to partition (0, 1) into equal widths
    Returns:
        [*] expected lDDT, scaled to percent
    """
    no_bins = logits.shape[-1]
    width = 1.0 / no_bins
    # Centers of the equal-width bins on (0, 1).
    centers = torch.arange(
        start=0.5 * width, end=1.0, step=width, device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Broadcast centers over all leading dims and take the expectation.
    centers = centers.view(*((1,) * len(probs.shape[:-1])), *centers.shape)
    expected_lddt = torch.sum(probs * centers, dim=-1)
    return expected_lddt * 100
386
+
387
+
388
def lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT: fraction of preserved pairwise distances within `cutoff`.

    Args:
        all_atom_pred_pos: [*, N, 3] predicted positions
        all_atom_positions: [*, N, 3] ground-truth positions
        all_atom_mask: [*, N, 1] position validity mask
        cutoff: only pairs whose TRUE distance is below this are scored
        eps: numerical stability constant
        per_residue: if True return [*, N] scores, else a [*] average
    """
    def _pairwise_dists(coords):
        deltas = coords[..., None, :] - coords[..., None, :, :]
        return torch.sqrt(eps + torch.sum(deltas ** 2, dim=-1))

    n = all_atom_mask.shape[-2]
    dmat_true = _pairwise_dists(all_atom_positions)
    dmat_pred = _pairwise_dists(all_atom_pred_pos)

    # Score a pair iff: true distance under cutoff, both atoms present,
    # and i != j (the eye term zeroes the diagonal).
    pair_mask = (
        (dmat_true < cutoff)
        * all_atom_mask
        * all_atom_mask.transpose(-2, -1)
        * (1.0 - torch.eye(n, device=all_atom_mask.device))
    )

    dist_l1 = torch.abs(dmat_true - dmat_pred)

    # Fraction of the four lDDT tolerance thresholds each pair satisfies.
    score = 0.25 * (
        (dist_l1 < 0.5).type(dist_l1.dtype)
        + (dist_l1 < 1.0).type(dist_l1.dtype)
        + (dist_l1 < 2.0).type(dist_l1.dtype)
        + (dist_l1 < 4.0).type(dist_l1.dtype)
    )

    reduce_dims = (-1,) if per_residue else (-2, -1)
    norm = 1.0 / (eps + torch.sum(pair_mask, dim=reduce_dims))
    return norm * (eps + torch.sum(pair_mask * score, dim=reduce_dims))
442
+
443
+
444
def lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """lDDT restricted to the C-alpha atoms of an atom37 representation."""
    ca_idx = residue_constants.atom_order["CA"]
    return lddt(
        all_atom_pred_pos[..., ca_idx, :],
        all_atom_positions[..., ca_idx, :],
        # Slicing (not indexing) keeps the trailing atom dimension.
        all_atom_mask[..., ca_idx : (ca_idx + 1)],
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )
465
+
466
+
467
def lddt_loss(
    logits: torch.Tensor,
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    cutoff: float = 15.0,
    no_bins: int = 50,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps: float = 1e-10,
    **kwargs,
) -> torch.Tensor:
    """Cross-entropy loss for the per-residue lDDT prediction head.

    Computes the true C-alpha lDDT of the prediction, discretizes it into
    `no_bins` classes, and scores the head's logits against those labels.
    Examples whose resolution falls outside
    [min_resolution, max_resolution] contribute zero.

    Args:
        logits: [*, N, no_bins] lDDT-head output
        all_atom_pred_pos: [*, N, 37, 3] predicted atom positions
        all_atom_positions: [*, N, 37, 3] ground-truth atom positions
        all_atom_mask: [*, N, 37] atom existence mask
        resolution: per-example resolution used to gate the loss
        cutoff: lDDT inclusion radius
        no_bins: number of lDDT classes
        eps: numerical stability constant
    Returns:
        Scalar loss averaged over the batch dimension
    """
    # (Fixed: removed an unused local `n = all_atom_mask.shape[-2]`.)
    ca_pos = residue_constants.atom_order["CA"]
    all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    score = lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps
    )

    # The lDDT target is a label, not a differentiable quantity.
    score = score.detach()

    # Discretize [0, 1] lDDT into equal-width classes; clamp so a perfect
    # score of 1.0 lands in the last bin rather than overflowing.
    bin_index = torch.floor(score * no_bins).long()
    bin_index = torch.clamp(bin_index, max=(no_bins - 1))
    lddt_ca_one_hot = torch.nn.functional.one_hot(
        bin_index, num_classes=no_bins
    )

    errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
    all_atom_mask = all_atom_mask.squeeze(-1)
    loss = torch.sum(errors * all_atom_mask, dim=-1) / (
        eps + torch.sum(all_atom_mask, dim=-1)
    )

    # Zero out examples whose resolution is outside the trusted range.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension
    loss = torch.mean(loss)

    return loss
517
+
518
+
519
def distogram_loss(
    logits,
    pseudo_beta,
    pseudo_beta_mask,
    min_bin=2.3125,
    max_bin=21.6875,
    no_bins=64,
    eps=1e-6,
    **kwargs,
):
    """Cross-entropy over binned pairwise pseudo-beta distances.

    Args:
        logits: [*, N, N, no_bins] distogram-head output
        pseudo_beta: [*, N, 3] pseudo-beta coordinates
        pseudo_beta_mask: [*, N] validity mask
    Returns:
        Scalar loss averaged over the batch dimensions
    """
    # Bin boundaries are squared so comparisons happen in squared-distance
    # space, avoiding a sqrt.
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    ) ** 2

    deltas = pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]
    sq_dists = torch.sum(deltas ** 2, dim=-1, keepdims=True)

    # Count how many boundaries each distance exceeds -> bin index.
    true_bins = torch.sum(sq_dists > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits,
        torch.nn.functional.one_hot(true_bins, no_bins),
    )

    square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

    # FP16-friendly two-stage reduction: divide by the full normalizer
    # after the first sum so intermediates stay small. Roughly equal to
    # sum(errors * mask, (-1, -2)) / (eps + sum(mask, (-1, -2))).
    denom = eps + torch.sum(square_mask, dim=(-1, -2))
    mean = torch.sum(errors * square_mask, dim=-1) / denom[..., None]
    mean = torch.sum(mean, dim=-1)

    # Average over the batch dimensions
    return torch.mean(mean)
565
+
566
+
567
+ def _calculate_bin_centers(boundaries: torch.Tensor):
568
+ step = boundaries[1] - boundaries[0]
569
+ bin_centers = boundaries + step / 2
570
+ bin_centers = torch.cat(
571
+ [bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
572
+ )
573
+ return bin_centers
574
+
575
+
576
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: torch.Tensor,
    aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expected aligned error per pair, and the largest representable error.

    Returns:
        (expectation over bin centers, last bin center)
    """
    centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = torch.sum(aligned_distance_error_probs * centers, dim=-1)
    return expected_error, centers[-1]
585
+
586
+
587
def compute_predicted_aligned_error(
    logits: torch.Tensor,
    max_bin: int = 31,
    no_bins: int = 64,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes aligned confidence metrics from logits.

    Args:
        logits: [*, num_res, num_res, num_bins] the logits output from
            PredictedAlignedErrorHead.
        max_bin: Maximum bin value
        no_bins: Number of bins
    Returns:
        aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
            aligned error probabilities over bins for each residue pair.
        predicted_aligned_error: [*, num_res, num_res] the expected aligned
            distance error for each pair of residues.
        max_predicted_aligned_error: [*] the maximum predicted error possible.
    """
    breaks = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    probs = torch.nn.functional.softmax(logits, dim=-1)

    expected_error, max_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=breaks,
        aligned_distance_error_probs=probs,
    )

    return {
        "aligned_confidence_probs": probs,
        "predicted_aligned_error": expected_error,
        "max_predicted_aligned_error": max_error,
    }
625
+
626
+
627
def compute_tm(
    logits: torch.Tensor,
    residue_weights: Optional[torch.Tensor] = None,
    max_bin: int = 31,
    no_bins: int = 64,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Predicted TM-score from pairwise aligned-error logits.

    Args:
        logits: [*, N, N, no_bins] pAE-head output
        residue_weights: [N] optional per-residue weights; defaults to ones
        max_bin: maximum error bin boundary
        no_bins: number of error bins
        eps: numerical stability constant
    Returns:
        Scalar predicted TM-score taken from the highest-weighted alignment
    """
    if residue_weights is None:
        residue_weights = logits.new_ones(logits.shape[-2])

    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    bin_centers = _calculate_bin_centers(boundaries)
    # (Fixed: removed a dead `torch.sum(residue_weights)` whose result
    # was discarded.)

    n = logits.shape[-2]
    # d0 normalization (Zhang & Skolnick, 2004); clamping n keeps the
    # cube-root argument positive for short sequences.
    clipped_n = max(n, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # Expected TM contribution of each residue pair.
    tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
    predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

    normed_residue_mask = residue_weights / (eps + residue_weights.sum())
    per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)

    # Return the score of the best (weight-adjusted) alignment anchor.
    weighted = per_alignment * residue_weights
    argmax = (weighted == torch.max(weighted)).nonzero()[0]
    return per_alignment[tuple(argmax)]
659
+
660
+
661
def tm_loss(
    logits,
    final_affine_tensor,
    backbone_rigid_tensor,
    backbone_rigid_mask,
    resolution,
    max_bin=31,
    no_bins=64,
    min_resolution: float = 0.1,
    max_resolution: float = 3.0,
    eps=1e-8,
    **kwargs,
):
    """Cross-entropy loss for the predicted-aligned-error (TM) head.

    Bins the squared frame-aligned distance between predicted and
    ground-truth backbone frames and scores the head's logits against
    those binned targets. Examples whose resolution is outside
    [min_resolution, max_resolution] contribute zero.

    Args:
        logits: pAE-head output over residue pairs
        final_affine_tensor: predicted frames in tensor_7 form
        backbone_rigid_tensor: ground-truth frames in 4x4 form
        backbone_rigid_mask: frame validity mask
        resolution: per-example resolution used to gate the loss
    Returns:
        Scalar loss averaged over the batch
    """
    pred_affine = Rigid.from_tensor_7(final_affine_tensor)
    backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)

    def _points(affine):
        # Every frame's translation expressed in every frame's local
        # coordinates: pairwise aligned positions.
        pts = affine.get_trans()[..., None, :, :]
        return affine.invert()[..., None].apply(pts)

    sq_diff = torch.sum(
        (_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
    )

    # Targets are labels; no gradient flows through the structure here.
    sq_diff = sq_diff.detach()

    # Squared boundaries: bin assignment happens in squared-distance space.
    boundaries = torch.linspace(
        0, max_bin, steps=(no_bins - 1), device=logits.device
    )
    boundaries = boundaries ** 2
    true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)

    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_bins, no_bins)
    )

    square_mask = (
        backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
    )

    # Scaled two-stage masked average (scale cancels overall; it only
    # keeps FP16 intermediates in range).
    loss = torch.sum(errors * square_mask, dim=-1)
    scale = 0.5  # hack to help FP16 training along
    denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    # Zero out examples outside the trusted resolution range.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the loss dimension
    loss = torch.mean(loss)

    return loss
716
+
717
+
718
def between_residue_bond_loss(
    pred_atom_positions: torch.Tensor,  # (*, N, 37/14, 3)
    pred_atom_mask: torch.Tensor,  # (*, N, 37/14)
    residue_index: torch.Tensor,  # (*, N)
    aatype: torch.Tensor,  # (*, N)
    tolerance_factor_soft=12.0,
    tolerance_factor_hard=12.0,
    eps=1e-6,
) -> Dict[str, torch.Tensor]:
    """Flat-bottom loss to penalize structural violations between residues.

    This is a loss penalizing any violation of the geometry around the peptide
    bond between consecutive amino acids. This loss corresponds to
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.

    Args:
        pred_atom_positions: Atom positions in atom37/14 representation
        pred_atom_mask: Atom mask in atom37/14 representation
        residue_index: Residue index for given amino acid, this is assumed to be
            monotonically increasing.
        aatype: Amino acid type of given residue
        tolerance_factor_soft: soft tolerance factor measured in standard deviations
            of pdb distributions
        tolerance_factor_hard: hard tolerance factor measured in standard deviations
            of pdb distributions

    Returns:
        Dict containing:
            * 'c_n_loss_mean': Loss for peptide bond length violations
            * 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
                by CA, C, N
            * 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
                by C, N, CA
            * 'per_residue_loss_sum': sum of all losses for each residue
            * 'per_residue_violation_mask': mask denoting all residues with violation
                present.
    """
    # Get the positions of the relevant backbone atoms.
    # Atom indices: 0 = N, 1 = CA, 2 = C (per the variable names below).
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    this_c_pos = pred_atom_positions[..., :-1, 2, :]
    this_c_mask = pred_atom_mask[..., :-1, 2]
    next_n_pos = pred_atom_positions[..., 1:, 0, :]
    next_n_mask = pred_atom_mask[..., 1:, 0]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    # Only score residue pairs that are actually sequence-adjacent.
    has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0

    # Compute loss for the C--N bond.
    c_n_bond_length = torch.sqrt(
        eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
    )

    # The C-N bond to proline has slightly different length because of the ring.
    next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
    gt_length = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_c_n[
        1
    ]
    gt_stddev = (
        ~next_is_proline
    ) * residue_constants.between_res_bond_length_stddev_c_n[
        0
    ] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
        1
    ]
    c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
    # Flat-bottom: no penalty within tolerance_factor_soft stddevs.
    c_n_loss_per_residue = torch.nn.functional.relu(
        c_n_bond_length_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * has_no_gap_mask
    c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_violation_mask = mask * (
        c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute loss for the angles.
    ca_c_bond_length = torch.sqrt(
        eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
    )
    n_ca_bond_length = torch.sqrt(
        eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
    )

    c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
    c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
    n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]

    # Angles are compared in cosine space via unit-vector dot products.
    ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
    # NOTE(review): this reuses the C--N bond-length stddev as the angle
    # stddev; confirm upstream whether between_res_cos_angles_ca_c_n[1]
    # was intended here.
    gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
    ca_c_n_cos_angle_error = torch.sqrt(
        eps + (ca_c_n_cos_angle - gt_angle) ** 2
    )
    ca_c_n_loss_per_residue = torch.nn.functional.relu(
        ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
    ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    ca_c_n_violation_mask = mask * (
        ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
    gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
    gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
    c_n_ca_cos_angle_error = torch.sqrt(
        eps + torch.square(c_n_ca_cos_angle - gt_angle)
    )
    c_n_ca_loss_per_residue = torch.nn.functional.relu(
        c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
    )
    mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
    c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
        torch.sum(mask, dim=-1) + eps
    )
    c_n_ca_violation_mask = mask * (
        c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
    )

    # Compute a per residue loss (equally distribute the loss to both
    # neighbouring residues).
    per_residue_loss_sum = (
        c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
    )
    # Padding maps each inter-residue term onto both of its residues.
    per_residue_loss_sum = 0.5 * (
        torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
        + torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
    )

    # Compute hard violations.
    violation_mask = torch.max(
        torch.stack(
            [c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
            dim=-2,
        ),
        dim=-2,
    )[0]
    # A residue is violating if either of its flanking bonds violates.
    violation_mask = torch.maximum(
        torch.nn.functional.pad(violation_mask, (0, 1)),
        torch.nn.functional.pad(violation_mask, (1, 0)),
    )

    return {
        "c_n_loss_mean": c_n_loss,
        "ca_c_n_loss_mean": ca_c_n_loss,
        "c_n_ca_loss_mean": c_n_ca_loss,
        "per_residue_loss_sum": per_residue_loss_sum,
        "per_residue_violation_mask": violation_mask,
    }
875
+
876
+
877
def between_residue_clash_loss(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_atom_radius: torch.Tensor,
    residue_index: torch.Tensor,
    overlap_tolerance_soft=1.5,
    overlap_tolerance_hard=1.5,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes between residues.

    This is a loss penalizing any steric clashes due to non bonded atoms in
    different peptides coming too close. This loss corresponds to the part with
    different residues of
    Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.

    Args:
        atom14_pred_positions: Predicted positions of atoms in
            global prediction frame
        atom14_atom_exists: Mask denoting whether atom at positions exists for given
            amino acid type
        atom14_atom_radius: Van der Waals radius for each atom.
        residue_index: Residue index for given amino acid.
        overlap_tolerance_soft: Soft tolerance factor.
        overlap_tolerance_hard: Hard tolerance factor.

    Returns:
        Dict containing:
            * 'mean_loss': average clash loss
            * 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
            * 'per_atom_clash_mask': mask whether atom clashes with any other atom
                shape (N, 14)
    """
    fp_type = atom14_pred_positions.dtype

    # Create the distance matrix.
    # (N, N, 14, 14)
    dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., :, None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Create the mask for valid distances.
    # shape (N, N, 14, 14)
    dists_mask = (
        atom14_atom_exists[..., :, None, :, None]
        * atom14_atom_exists[..., None, :, None, :]
    ).type(fp_type)

    # Mask out all the duplicate entries in the lower triangular matrix.
    # Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
    # are handled separately.
    dists_mask = dists_mask * (
        residue_index[..., :, None, None, None]
        < residue_index[..., None, :, None, None]
    )

    # Backbone C--N bond between subsequent residues is no clash.
    # One-hot selectors for atom14 slots 2 (C) and 0 (N).
    c_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(2), num_classes=14
    )
    c_one_hot = c_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
    )
    c_one_hot = c_one_hot.type(fp_type)
    n_one_hot = torch.nn.functional.one_hot(
        residue_index.new_tensor(0), num_classes=14
    )
    n_one_hot = n_one_hot.reshape(
        *((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
    )
    n_one_hot = n_one_hot.type(fp_type)

    neighbour_mask = (
        residue_index[..., :, None, None, None] + 1
    ) == residue_index[..., None, :, None, None]
    c_n_bonds = (
        neighbour_mask
        * c_one_hot[..., None, None, :, None]
        * n_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - c_n_bonds)

    # Disulfide bridge between two cysteines is no clash.
    cys = residue_constants.restype_name_to_atom14_names["CYS"]
    cys_sg_idx = cys.index("SG")
    cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
    cys_sg_idx = cys_sg_idx.reshape(
        *((1,) * len(residue_index.shape[:-1])), 1
    ).squeeze(-1)
    cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
    disulfide_bonds = (
        cys_sg_one_hot[..., None, None, :, None]
        * cys_sg_one_hot[..., None, None, None, :]
    )
    dists_mask = dists_mask * (1.0 - disulfide_bonds)

    # Compute the lower bound for the allowed distances.
    # shape (N, N, 14, 14)
    dists_lower_bound = dists_mask * (
        atom14_atom_radius[..., :, None, :, None]
        + atom14_atom_radius[..., None, :, None, :]
    )

    # Compute the error.
    # shape (N, N, 14, 14)
    # Flat-bottom: only penalize overlap beyond the soft tolerance.
    dists_to_low_error = dists_mask * torch.nn.functional.relu(
        dists_lower_bound - overlap_tolerance_soft - dists
    )

    # Compute the mean loss.
    # shape ()
    mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))

    # Compute the per atom loss sum.
    # shape (N, 14)
    # Each pair contributes to both of its atoms (row + column sums).
    per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
        dists_to_low_error, axis=(-3, -1)
    )

    # Compute the hard clash mask.
    # shape (N, N, 14, 14)
    clash_mask = dists_mask * (
        dists < (dists_lower_bound - overlap_tolerance_hard)
    )

    # Compute the per atom clash.
    # shape (N, 14)
    per_atom_clash_mask = torch.maximum(
        torch.amax(clash_mask, axis=(-4, -2)),
        torch.amax(clash_mask, axis=(-3, -1)),
    )

    return {
        "mean_loss": mean_loss,  # shape ()
        "per_atom_loss_sum": per_atom_loss_sum,  # shape (N, 14)
        "per_atom_clash_mask": per_atom_clash_mask,  # shape (N, 14)
    }
1022
+
1023
+
1024
def within_residue_violations(
    atom14_pred_positions: torch.Tensor,
    atom14_atom_exists: torch.Tensor,
    atom14_dists_lower_bound: torch.Tensor,
    atom14_dists_upper_bound: torch.Tensor,
    tighten_bounds_for_loss=0.0,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """Loss to penalize steric clashes within residues.

    Penalizes intra-residue atom pairs whose distance falls outside the
    provided per-pair bounds (the same-residue part of Jumper et al.
    (2021) Suppl. Sec. 1.9.11, eq 46).

    Args:
        atom14_pred_positions ([*, N, 14, 3]):
            Predicted atom positions in the global frame
        atom14_atom_exists ([*, N, 14]):
            Mask of atoms that exist for the residue's amino acid type
        atom14_dists_lower_bound:
            Lower bound on allowed pairwise distances
        atom14_dists_upper_bound:
            Upper bound on allowed pairwise distances
        tighten_bounds_for_loss ([*, N]):
            Extra factor tightening both bounds for the loss only

    Returns:
        Dict containing:
            * 'per_atom_loss_sum' ([*, N, 14]): summed bound-violation loss per atom
            * 'per_atom_violations' ([*, N, 14]): hard violation mask per atom
    """
    # Pair mask: both atoms exist and the pair is off-diagonal (i != j).
    off_diag = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
    off_diag = off_diag.reshape(
        *((1,) * len(atom14_atom_exists.shape[:-2])), *off_diag.shape
    )
    pair_mask = (
        atom14_atom_exists[..., :, :, None]
        * atom14_atom_exists[..., :, None, :]
        * off_diag
    )

    # Intra-residue pairwise distances.
    deltas = (
        atom14_pred_positions[..., :, :, None, :]
        - atom14_pred_positions[..., :, None, :, :]
    )
    dists = torch.sqrt(eps + torch.sum(deltas ** 2, dim=-1))

    # Flat-bottom penalties below the (tightened) lower bound and above
    # the (tightened) upper bound.
    below_error = torch.nn.functional.relu(
        atom14_dists_lower_bound + tighten_bounds_for_loss - dists
    )
    above_error = torch.nn.functional.relu(
        dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
    )
    pair_loss = pair_mask * (below_error + above_error)

    # Each pair contributes to both of its atoms (row + column sums).
    per_atom_loss_sum = torch.sum(pair_loss, dim=-2) + torch.sum(
        pair_loss, dim=-1
    )

    # Hard violations use the untightened bounds.
    out_of_bounds = pair_mask * (
        (dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
    )
    per_atom_violations = torch.maximum(
        torch.max(out_of_bounds, dim=-2)[0],
        torch.max(out_of_bounds, dim=-1)[0],
    )

    return {
        "per_atom_loss_sum": per_atom_loss_sum,
        "per_atom_violations": per_atom_violations,
    }
1109
+
1110
+
1111
def find_structural_violations(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    violation_tolerance_factor: float,
    clash_overlap_tolerance: float,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """Computes several checks for structural violations.

    Args:
        batch: Feature dict; reads "atom14_atom_exists", "residue_index",
            "aatype" and "residx_atom14_to_atom37".
        atom14_pred_positions: Predicted atom positions in atom14
            representation.  # assumes shape [*, N, 14, 3] -- TODO confirm
        violation_tolerance_factor: Tolerance factor (in standard
            deviations) for bond length/angle violations; used for both
            the soft and the hard tolerance here.
        clash_overlap_tolerance: Allowed overlap before two atoms count
            as clashing; used for both the soft and the hard tolerance.
        **kwargs: Ignored; lets callers splat a whole config dict.

    Returns:
        Nested dict with "between_residues", "within_residues" and
        "total_per_residue_violations_mask" entries (exact keys and
        shapes annotated on the return statement below).
    """

    # Compute between residue backbone violations of bonds and angles.
    connection_violations = between_residue_bond_loss(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
        aatype=batch["aatype"],
        tolerance_factor_soft=violation_tolerance_factor,
        tolerance_factor_hard=violation_tolerance_factor,
    )

    # Compute the Van der Waals radius for every atom
    # (the first letter of the atom name is the element type).
    # Shape: (N, 14).
    atomtype_radius = [
        residue_constants.van_der_waals_radius[name[0]]
        for name in residue_constants.atom_types
    ]
    # new_tensor puts the radii on the same device/dtype as the predictions.
    atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
    atom14_atom_radius = (
        batch["atom14_atom_exists"]
        * atomtype_radius[batch["residx_atom14_to_atom37"]]
    )

    # Compute the between residue clash loss.
    between_residue_clashes = between_residue_clash_loss(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_atom_radius=atom14_atom_radius,
        residue_index=batch["residue_index"],
        overlap_tolerance_soft=clash_overlap_tolerance,
        overlap_tolerance_hard=clash_overlap_tolerance,
    )

    # Compute all within-residue violations (clashes,
    # bond length and angle violations).
    restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
        overlap_tolerance=clash_overlap_tolerance,
        bond_length_tolerance_factor=violation_tolerance_factor,
    )
    atom14_atom_exists = batch["atom14_atom_exists"]
    # Per-residue-type bounds, gathered per residue via the aatype index.
    atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["lower_bound"]
    )[batch["aatype"]]
    atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
        restype_atom14_bounds["upper_bound"]
    )[batch["aatype"]]
    residue_violations = within_residue_violations(
        atom14_pred_positions=atom14_pred_positions,
        atom14_atom_exists=batch["atom14_atom_exists"],
        atom14_dists_lower_bound=atom14_dists_lower_bound,
        atom14_dists_upper_bound=atom14_dists_upper_bound,
        tighten_bounds_for_loss=0.0,
    )

    # Combine them to a single per-residue violation mask (used later for LDDT).
    per_residue_violations_mask = torch.max(
        torch.stack(
            [
                connection_violations["per_residue_violation_mask"],
                torch.max(
                    between_residue_clashes["per_atom_clash_mask"], dim=-1
                )[0],
                torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
            ],
            dim=-1,
        ),
        dim=-1,
    )[0]

    return {
        "between_residues": {
            "bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"],  # ()
            "angles_ca_c_n_loss_mean": connection_violations[
                "ca_c_n_loss_mean"
            ],  # ()
            "angles_c_n_ca_loss_mean": connection_violations[
                "c_n_ca_loss_mean"
            ],  # ()
            "connections_per_residue_loss_sum": connection_violations[
                "per_residue_loss_sum"
            ],  # (N)
            "connections_per_residue_violation_mask": connection_violations[
                "per_residue_violation_mask"
            ],  # (N)
            "clashes_mean_loss": between_residue_clashes["mean_loss"],  # ()
            "clashes_per_atom_loss_sum": between_residue_clashes[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "clashes_per_atom_clash_mask": between_residue_clashes[
                "per_atom_clash_mask"
            ],  # (N, 14)
        },
        "within_residues": {
            "per_atom_loss_sum": residue_violations[
                "per_atom_loss_sum"
            ],  # (N, 14)
            "per_atom_violations": residue_violations[
                "per_atom_violations"
            ],  # (N, 14),
        },
        "total_per_residue_violations_mask": per_residue_violations_mask,  # (N)
    }
1222
+
1223
+
1224
def find_structural_violations_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
    """NumPy wrapper around find_structural_violations.

    Converts the numpy inputs to torch tensors, delegates to the torch
    implementation (splatting ``config`` as its keyword arguments), and
    converts the nested output dict back to numpy arrays.
    """
    to_tensor = lambda x: torch.tensor(x)
    batch = tree_map(to_tensor, batch, np.ndarray)
    atom14_pred_positions = to_tensor(atom14_pred_positions)

    out = find_structural_violations(batch, atom14_pred_positions, **config)

    to_np = lambda x: np.array(x)
    np_out = tensor_tree_map(to_np, out)

    return np_out
1239
+
1240
+
1241
def extreme_ca_ca_distance_violations(
    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
    pred_atom_mask: torch.Tensor,  # (N, 37(14))
    residue_index: torch.Tensor,  # (N)
    max_angstrom_tolerance=1.5,
    eps=1e-6,
) -> torch.Tensor:
    """Counts residues whose Ca is a large distance from its neighbour.

    Measures the fraction of CA-CA pairs between consecutive amino acids that are
    more than 'max_angstrom_tolerance' apart.

    Args:
      pred_atom_positions: Atom positions in atom37/14 representation
      pred_atom_mask: Atom mask in atom37/14 representation
      residue_index: Residue index for given amino acid, this is assumed to be
        monotonically increasing.
      max_angstrom_tolerance: Maximum distance allowed to not count as violation.
    Returns:
      Fraction of consecutive CA-CA pairs with violation.
    """
    # CA is atom index 1 in both atom14 and atom37 orderings.
    this_ca_pos = pred_atom_positions[..., :-1, 1, :]
    this_ca_mask = pred_atom_mask[..., :-1, 1]
    next_ca_pos = pred_atom_positions[..., 1:, 1, :]
    next_ca_mask = pred_atom_mask[..., 1:, 1]
    # Pairs with a chain break (residue index gap != 1) are excluded.
    has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
    # eps keeps sqrt differentiable at zero distance.
    ca_ca_distance = torch.sqrt(
        eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
    )
    # Violation when the distance exceeds the ideal CA-CA length by more
    # than the tolerance.
    violations = (
        ca_ca_distance - residue_constants.ca_ca
    ) > max_angstrom_tolerance
    mask = this_ca_mask * next_ca_mask * has_no_gap_mask
    mean = masked_mean(mask, violations, -1)
    return mean
1276
+
1277
+
1278
def compute_violation_metrics(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,  # (N, 14, 3)
    violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """Compute several metrics to assess the structural violations.

    Args:
        batch: Feature dict; reads "atom14_atom_exists", "residue_index"
            and "seq_mask".
        atom14_pred_positions: Predicted positions in atom14 representation.
        violations: Output of find_structural_violations.

    Returns:
        Dict of scalar-per-batch metrics (fractions of residues with each
        kind of violation, masked by "seq_mask").
    """
    ret = {}
    # Fraction of consecutive CA-CA pairs that are unreasonably far apart.
    extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
        pred_atom_positions=atom14_pred_positions,
        pred_atom_mask=batch["atom14_atom_exists"],
        residue_index=batch["residue_index"],
    )
    ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
    # Fraction of residues with a backbone bond/angle violation.
    ret["violations_between_residue_bond"] = masked_mean(
        batch["seq_mask"],
        violations["between_residues"][
            "connections_per_residue_violation_mask"
        ],
        dim=-1,
    )
    # Fraction of residues with at least one inter-residue clash.
    ret["violations_between_residue_clash"] = masked_mean(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["between_residues"]["clashes_per_atom_clash_mask"],
            dim=-1,
        )[0],
        dim=-1,
    )
    # Fraction of residues with an internal (within-residue) violation.
    ret["violations_within_residue"] = masked_mean(
        mask=batch["seq_mask"],
        value=torch.max(
            violations["within_residues"]["per_atom_violations"], dim=-1
        )[0],
        dim=-1,
    )
    # Fraction of residues with any violation at all.
    ret["violations_per_residue"] = masked_mean(
        mask=batch["seq_mask"],
        value=violations["total_per_residue_violations_mask"],
        dim=-1,
    )
    return ret
1319
+
1320
+
1321
def compute_violation_metrics_np(
    batch: Dict[str, np.ndarray],
    atom14_pred_positions: np.ndarray,
    violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
    """NumPy wrapper around compute_violation_metrics.

    Converts all numpy inputs (including the nested ``violations`` dict)
    to torch tensors, delegates, then converts the result back to numpy.
    """
    to_tensor = lambda x: torch.tensor(x)
    batch = tree_map(to_tensor, batch, np.ndarray)
    atom14_pred_positions = to_tensor(atom14_pred_positions)
    violations = tree_map(to_tensor, violations, np.ndarray)

    out = compute_violation_metrics(batch, atom14_pred_positions, violations)

    to_np = lambda x: np.array(x)
    return tree_map(to_np, out, torch.Tensor)
1335
+
1336
+
1337
def violation_loss(
    violations: Dict[str, torch.Tensor],
    atom14_atom_exists: torch.Tensor,
    eps=1e-6,
    **kwargs,
) -> torch.Tensor:
    """Aggregate structural-violation penalty.

    Sums the backbone bond/angle penalties with the clash penalty, the
    latter averaged over all existing atoms.

    Args:
        violations: Output of find_structural_violations.
        atom14_atom_exists: [*, N, 14] mask of atoms present per residue.
        eps: Small constant guarding against division by zero.
        **kwargs: Ignored; lets callers splat a whole feature/config dict.

    Returns:
        Scalar violation loss.
    """
    between = violations["between_residues"]
    within = violations["within_residues"]

    # Normalize the summed clash penalties by the number of real atoms.
    total_atoms = torch.sum(atom14_atom_exists)
    clash_total = torch.sum(
        between["clashes_per_atom_loss_sum"] + within["per_atom_loss_sum"]
    )
    mean_clash = clash_total / (eps + total_atoms)

    return (
        between["bonds_c_n_loss_mean"]
        + between["angles_ca_c_n_loss_mean"]
        + between["angles_c_n_ca_loss_mean"]
        + mean_clash
    )
1357
+
1358
+
1359
def compute_renamed_ground_truth(
    batch: Dict[str, torch.Tensor],
    atom14_pred_positions: torch.Tensor,
    eps=1e-10,
) -> Dict[str, torch.Tensor]:
    """
    Find optimal renaming of ground truth based on the predicted positions.

    Alg. 26 "renameSymmetricGroundTruthAtoms"

    This renamed ground truth is then used for all losses,
    such that each loss moves the atoms in the same direction.

    Args:
      batch: Dictionary containing:
        * atom14_gt_positions: Ground truth positions.
        * atom14_alt_gt_positions: Ground truth positions with renaming swaps.
        * atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
            renaming swaps.
        * atom14_gt_exists: Mask for which atoms exist in ground truth.
        * atom14_alt_gt_exists: Mask for which atoms exist in ground truth
            after renaming.
        * atom14_atom_exists: Mask for whether each atom is part of the given
            amino acid type.
      atom14_pred_positions: Array of atom positions in global frame with shape
    Returns:
      Dictionary containing:
        alt_naming_is_better: Array with 1.0 where alternative swap is better.
        renamed_atom14_gt_positions: Array of optimal ground truth positions
          after renaming swaps are performed.
        renamed_atom14_gt_exists: Mask after renaming swap is performed.
    """

    # All-pairs atom distances within the prediction; eps keeps sqrt
    # differentiable at zero.
    pred_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_pred_positions[..., None, :, None, :]
                - atom14_pred_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # Same all-pairs distances for the ground truth ...
    atom14_gt_positions = batch["atom14_gt_positions"]
    gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_gt_positions[..., None, :, None, :]
                - atom14_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # ... and for the ground truth with the symmetric atoms swapped.
    atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
    alt_gt_dists = torch.sqrt(
        eps
        + torch.sum(
            (
                atom14_alt_gt_positions[..., None, :, None, :]
                - atom14_alt_gt_positions[..., None, :, None, :, :]
            )
            ** 2,
            dim=-1,
        )
    )

    # NOTE(review): despite the name, these are absolute per-pair distance
    # errors (|pred - gt|), not an lDDT score.
    lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
    alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)

    # Only compare ambiguous atoms against unambiguous ones that exist in
    # the ground truth.
    atom14_gt_exists = batch["atom14_gt_exists"]
    atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
    mask = (
        atom14_gt_exists[..., None, :, None]
        * atom14_atom_is_ambiguous[..., None, :, None]
        * atom14_gt_exists[..., None, :, None, :]
        * (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
    )

    per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
    alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))

    # Choose, per residue, whichever naming yields the smaller total error.
    fp_type = atom14_pred_positions.dtype
    alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)

    # Blend positions/masks with the 0/1 selector (differentiable-friendly
    # alternative to torch.where).
    renamed_atom14_gt_positions = (
        1.0 - alt_naming_is_better[..., None, None]
    ) * atom14_gt_positions + alt_naming_is_better[
        ..., None, None
    ] * atom14_alt_gt_positions

    renamed_atom14_gt_mask = (
        1.0 - alt_naming_is_better[..., None]
    ) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
        "atom14_alt_gt_exists"
    ]

    return {
        "alt_naming_is_better": alt_naming_is_better,
        "renamed_atom14_gt_positions": renamed_atom14_gt_positions,
        "renamed_atom14_gt_exists": renamed_atom14_gt_mask,
    }
1465
+
1466
+
1467
def experimentally_resolved_loss(
    logits: torch.Tensor,
    atom37_atom_exists: torch.Tensor,
    all_atom_mask: torch.Tensor,
    resolution: torch.Tensor,
    min_resolution: float,
    max_resolution: float,
    eps: float = 1e-8,
    **kwargs,
) -> torch.Tensor:
    """Loss for the "experimentally resolved" per-atom prediction head.

    Sigmoid cross-entropy between predicted resolved-ness logits and the
    ground-truth atom mask, averaged over existing atoms, and zeroed out
    for structures whose resolution lies outside
    [min_resolution, max_resolution].
    """
    errors = sigmoid_cross_entropy(logits, all_atom_mask)
    # Only count atoms that exist for the residue type.
    loss = torch.sum(errors * atom37_atom_exists, dim=-1)
    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
    loss = torch.sum(loss, dim=-1)

    # Gate by resolution: the comparison yields a 0/1 factor per example.
    loss = loss * (
        (resolution >= min_resolution) & (resolution <= max_resolution)
    )

    # Average over the batch dimension(s).
    loss = torch.mean(loss)

    return loss
1489
+
1490
+
1491
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
    """
    Computes BERT-style masked MSA loss. Implements subsection 1.9.9.

    Args:
        logits: [*, N_seq, N_res, 23] predicted residue distribution
        true_msa: [*, N_seq, N_res] true MSA
        bert_mask: [*, N_seq, N_res] MSA mask
    Returns:
        Masked MSA loss
    """
    errors = softmax_cross_entropy(
        logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
    )

    # FP16-friendly averaging. Equivalent to:
    # loss = (
    #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
    #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
    # )
    # The scale factor keeps intermediate sums small enough for half
    # precision; it cancels out between numerator and denominator.
    loss = errors * bert_mask
    loss = torch.sum(loss, dim=-1)
    scale = 0.5
    denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
    loss = loss / denom[..., None]
    loss = torch.sum(loss, dim=-1)
    loss = loss * scale

    # Average over the batch dimension(s).
    loss = torch.mean(loss)

    return loss
1522
+
1523
+
1524
class AlphaFoldLoss(nn.Module):
    """Aggregation of the various losses described in the supplement"""
    def __init__(self, config):
        # config: nested config with one sub-config per loss term, each
        # carrying at least a `weight` entry.
        super(AlphaFoldLoss, self).__init__()
        self.config = config

    def forward(self, out, batch, _return_breakdown=False):
        """Compute the weighted sum of all enabled loss terms.

        Args:
            out: Model output dict (structure-module positions, logits of
                the auxiliary heads, etc.).
            batch: Feature dict; also mutated here to carry the renamed
                ground truth.
            _return_breakdown: If True, also return a dict mapping each
                loss name to its unweighted, detached value.

        Returns:
            The scaled cumulative loss, or (loss, breakdown) when
            _return_breakdown is set.
        """
        # Lazily compute violations from the last structure-module layer
        # if the caller has not already done so.
        if "violation" not in out.keys():
            out["violation"] = find_structural_violations(
                batch,
                out["sm"]["positions"][-1],
                **self.config.violation,
            )

        # Resolve symmetric side-chain naming against the prediction once,
        # so all subsequent losses pull atoms in the same direction.
        if "renamed_atom14_gt_positions" not in out.keys():
            batch.update(
                compute_renamed_ground_truth(
                    batch,
                    out["sm"]["positions"][-1],
                )
            )

        # Each entry is a thunk so a disabled/failed loss costs nothing
        # until it is actually invoked below.
        loss_fns = {
            "distogram": lambda: distogram_loss(
                logits=out["distogram_logits"],
                **{**batch, **self.config.distogram},
            ),
            "experimentally_resolved": lambda: experimentally_resolved_loss(
                logits=out["experimentally_resolved_logits"],
                **{**batch, **self.config.experimentally_resolved},
            ),
            "fape": lambda: fape_loss(
                out,
                batch,
                self.config.fape,
            ),
            "lddt": lambda: lddt_loss(
                logits=out["lddt_logits"],
                all_atom_pred_pos=out["final_atom_positions"],
                **{**batch, **self.config.lddt},
            ),
            "masked_msa": lambda: masked_msa_loss(
                logits=out["masked_msa_logits"],
                **{**batch, **self.config.masked_msa},
            ),
            "supervised_chi": lambda: supervised_chi_loss(
                out["sm"]["angles"],
                out["sm"]["unnormalized_angles"],
                **{**batch, **self.config.supervised_chi},
            ),
            "violation": lambda: violation_loss(
                out["violation"],
                **batch,
            ),
        }

        if(self.config.tm.enabled):
            loss_fns["tm"] = lambda: tm_loss(
                logits=out["tm_logits"],
                **{**batch, **out, **self.config.tm},
            )

        cum_loss = 0.
        losses = {}
        for loss_name, loss_fn in loss_fns.items():
            weight = self.config[loss_name].weight
            loss = loss_fn()
            # A NaN/inf loss is zeroed (with grad) rather than poisoning
            # the whole training step.
            if(torch.isnan(loss) or torch.isinf(loss)):
                #for k,v in batch.items():
                #    if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
                #        logging.warning(f"{k}: is nan")
                #logging.warning(f"{loss_name}: {loss}")
                logging.warning(f"{loss_name} loss is NaN. Skipping...")
                loss = loss.new_tensor(0., requires_grad=True)
            cum_loss = cum_loss + weight * loss
            losses[loss_name] = loss.detach().clone()

        losses["unscaled_loss"] = cum_loss.detach().clone()

        # Scale the loss by the square root of the minimum of the crop size and
        # the (average) sequence length. See subsection 1.9.
        seq_len = torch.mean(batch["seq_length"].float())
        crop_len = batch["aatype"].shape[-1]
        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

        losses["loss"] = cum_loss.detach().clone()

        if(not _return_breakdown):
            return cum_loss

        return cum_loss, losses
openfold/utils/rigid_utils.py ADDED
@@ -0,0 +1,1367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+ from typing import Tuple, Any, Sequence, Callable, Optional
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+
23
def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
    Performs matrix multiplication of two rotation matrix tensors. The
    products are expanded element-by-element (rather than via matmul) to
    avoid AMP downcasting.

    Args:
        a: [*, 3, 3] left multiplicand
        b: [*, 3, 3] right multiplicand
    Returns:
        The product ab
    """
    rows = []
    for row_idx in range(3):
        # Entry (row_idx, col_idx) of the product, written out explicitly.
        entries = [
            a[..., row_idx, 0] * b[..., 0, col_idx]
            + a[..., row_idx, 1] * b[..., 1, col_idx]
            + a[..., row_idx, 2] * b[..., 2, col_idx]
            for col_idx in range(3)
        ]
        rows.append(torch.stack(entries, dim=-1))

    return torch.stack(rows, dim=-2)
61
+
62
+
63
def rot_vec_mul(
    r: torch.Tensor,
    t: torch.Tensor
) -> torch.Tensor:
    """
    Applies a rotation to a vector. The matrix-vector product is written
    out component-by-component to avoid AMP downcasting.

    Args:
        r: [*, 3, 3] rotation matrices
        t: [*, 3] coordinate tensors
    Returns:
        [*, 3] rotated coordinates
    """
    components = torch.unbind(t, dim=-1)
    rotated = [
        r[..., i, 0] * components[0]
        + r[..., i, 1] * components[1]
        + r[..., i, 2] * components[2]
        for i in range(3)
    ]
    return torch.stack(rotated, dim=-1)
86
+
87
+
88
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Build a batch of 3x3 identity rotation matrices.

    The identity is allocated once and broadcast (via expand, no copy)
    over the requested batch dimensions.
    """
    eye = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    broadcast_shape = (1,) * len(batch_dims) + (3, 3)
    return eye.view(broadcast_shape).expand(*batch_dims, -1, -1)
101
+
102
+
103
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Build a batch of identity (all-zero) translation vectors of shape
    [*batch_dims, 3]."""
    shape = tuple(batch_dims) + (3,)
    return torch.zeros(
        shape,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
116
+
117
+
118
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """Build a batch of identity quaternions (1, 0, 0, 0) of shape
    [*batch_dims, 4]."""
    shape = (*batch_dims, 4)
    quats = torch.zeros(
        shape, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # Set the real component in-place without recording it on the graph.
    with torch.no_grad():
        quats[..., 0] = 1

    return quats
135
+
136
+
137
# Quaternion component labels and a lookup from each of the 16 pairwise
# component products ("aa", "ab", ..., "dd") to its flat index. Used by
# _to_mat below to build the quaternion-to-rotation coefficient tensor.
_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
140
+
141
+
142
def _to_mat(pairs):
    """Build a 4x4 coefficient matrix from (component-pair, value) pairs.

    Each key (e.g. "bc") selects one cell of the matrix via
    _qtr_ind_dict; unspecified cells stay zero.
    """
    mat = np.zeros((4, 4))
    for pair in pairs:
        key, value = pair
        ind = _qtr_ind_dict[key]
        mat[ind // 4][ind % 4] = value

    return mat
150
+
151
+
152
# [4, 4, 3, 3] coefficient tensor for quaternion-to-rotation conversion:
# entry [i, j, r, c] is the coefficient of the quaternion product q_i*q_j
# in rotation-matrix element (r, c). Consumed by quat_to_rot.
_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
162
+
163
+
164
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
    Converts a quaternion to a rotation matrix.

    Args:
        quat: [*, 4] quaternions
    Returns:
        [*, 3, 3] rotation matrices
    """
    # Outer product of the quaternion with itself.
    # [*, 4, 4]
    quat = quat[..., None] * quat[..., None, :]

    # Coefficients of each q_i*q_j term per matrix element.
    # [4, 4, 3, 3]
    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)

    # Weight the coefficient tensor by the quaternion products.
    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # Sum over the two quaternion-product dims to get the matrices.
    # [*, 3, 3]
    return torch.sum(quat, dim=(-3, -4))
185
+
186
+
187
def rot_to_quat(
    rot: torch.Tensor,
):
    """Convert [*, 3, 3] rotation matrices to [*, 4] quaternions.

    Builds the symmetric 4x4 matrix K whose dominant eigenvector is the
    quaternion, then extracts it with an eigendecomposition. The sign of
    the returned quaternion is arbitrary (q and -q encode one rotation).
    """
    if(rot.shape[-2:] != (3, 3)):
        raise ValueError("Input rotation is incorrectly shaped")

    xx, xy, xz = rot[..., 0, 0], rot[..., 0, 1], rot[..., 0, 2]
    yx, yy, yz = rot[..., 1, 0], rot[..., 1, 1], rot[..., 1, 2]
    zx, zy, zz = rot[..., 2, 0], rot[..., 2, 1], rot[..., 2, 2]

    rows = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]
    k = torch.stack(
        [torch.stack(row, dim=-1) for row in rows], dim=-2
    ) * (1. / 3.)

    # eigh returns eigenvalues ascending, so the last eigenvector belongs
    # to the largest eigenvalue.
    _, vectors = torch.linalg.eigh(k)
    return vectors[..., -1]
207
+
208
+
209
# [4, 4, 4] structure-constant tensor for Hamilton quaternion products:
# entry [i, j, k] is the coefficient of q1_i * q2_j in output component k.
# Consumed by quat_multiply / quat_multiply_by_vec.
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
                           [ 0,-1, 0, 0],
                           [ 0, 0,-1, 0],
                           [ 0, 0, 0,-1]]

_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
                           [ 1, 0, 0, 0],
                           [ 0, 0, 0, 1],
                           [ 0, 0,-1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
                           [ 0, 0, 0,-1],
                           [ 1, 0, 0, 0],
                           [ 0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
                           [ 0, 0, 1, 0],
                           [ 0,-1, 0, 0],
                           [ 1, 0, 0, 0]]

# Rows 1: of the second index drop the real part of the second operand,
# i.e. multiplication by a pure-vector quaternion.
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
231
+
232
+
233
def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion.

    Args:
        quat1: [*, 4] left operand
        quat2: [*, 4] right operand
    Returns:
        [*, 4] Hamilton product quat1 * quat2
    """
    mat = quat1.new_tensor(_QUAT_MULTIPLY)
    # Broadcast the structure constants over the batch dims, then contract
    # with both operands.
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat1[..., :, None, None] *
        quat2[..., None, :, None],
        dim=(-3, -2)
    )
243
+
244
+
245
def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion.

    Args:
        quat: [*, 4] quaternion
        vec: [*, 3] vector part of a pure quaternion (real part 0)
    Returns:
        [*, 4] Hamilton product quat * (0, vec)
    """
    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    # Same contraction as quat_multiply, with the real row of the second
    # operand dropped.
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat[..., :, None, None] *
        vec[..., None, :, None],
        dim=(-3, -2)
    )
255
+
256
+
257
def invert_rot_mat(rot_mat: torch.Tensor):
    """Invert a rotation matrix: for orthonormal matrices the inverse is
    the transpose of the last two dimensions."""
    return torch.transpose(rot_mat, -1, -2)
259
+
260
+
261
def invert_quat(quat: torch.Tensor):
    """Invert a quaternion: conjugate divided by the squared norm
    (reduces to the conjugate for unit quaternions)."""
    conjugate = torch.cat([quat[..., :1], -quat[..., 1:]], dim=-1)
    squared_norm = torch.sum(quat ** 2, dim=-1, keepdim=True)
    return conjugate / squared_norm
266
+
267
+
268
+ class Rotation:
269
+ """
270
+ A 3D rotation. Depending on how the object is initialized, the
271
+ rotation is represented by either a rotation matrix or a
272
+ quaternion, though both formats are made available by helper functions.
273
+ To simplify gradient computation, the underlying format of the
274
+ rotation cannot be changed in-place. Like Rigid, the class is designed
275
+ to mimic the behavior of a torch Tensor, almost as if each Rotation
276
+ object were a tensor of rotations, in one format or another.
277
+ """
278
+ def __init__(self,
279
+ rot_mats: Optional[torch.Tensor] = None,
280
+ quats: Optional[torch.Tensor] = None,
281
+ normalize_quats: bool = True,
282
+ ):
283
+ """
284
+ Args:
285
+ rot_mats:
286
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
287
+ quats
288
+ quats:
289
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If
290
+ normalize_quats is not True, must be a unit quaternion
291
+ normalize_quats:
292
+ If quats is specified, whether to normalize quats
293
+ """
294
+ if((rot_mats is None and quats is None) or
295
+ (rot_mats is not None and quats is not None)):
296
+ raise ValueError("Exactly one input argument must be specified")
297
+
298
+ if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
299
+ (quats is not None and quats.shape[-1] != 4)):
300
+ raise ValueError(
301
+ "Incorrectly shaped rotation matrix or quaternion"
302
+ )
303
+
304
+ # Force full-precision
305
+ if(quats is not None):
306
+ quats = quats.to(dtype=torch.float32)
307
+ if(rot_mats is not None):
308
+ rot_mats = rot_mats.to(dtype=torch.float32)
309
+
310
+ if(quats is not None and normalize_quats):
311
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
312
+
313
+ self._rot_mats = rot_mats
314
+ self._quats = quats
315
+
316
+ @staticmethod
317
+ def identity(
318
+ shape,
319
+ dtype: Optional[torch.dtype] = None,
320
+ device: Optional[torch.device] = None,
321
+ requires_grad: bool = True,
322
+ fmt: str = "quat",
323
+ ) -> Rotation:
324
+ """
325
+ Returns an identity Rotation.
326
+
327
+ Args:
328
+ shape:
329
+ The "shape" of the resulting Rotation object. See documentation
330
+ for the shape property
331
+ dtype:
332
+ The torch dtype for the rotation
333
+ device:
334
+ The torch device for the new rotation
335
+ requires_grad:
336
+ Whether the underlying tensors in the new rotation object
337
+ should require gradient computation
338
+ fmt:
339
+ One of "quat" or "rot_mat". Determines the underlying format
340
+ of the new object's rotation
341
+ Returns:
342
+ A new identity rotation
343
+ """
344
+ if(fmt == "rot_mat"):
345
+ rot_mats = identity_rot_mats(
346
+ shape, dtype, device, requires_grad,
347
+ )
348
+ return Rotation(rot_mats=rot_mats, quats=None)
349
+ elif(fmt == "quat"):
350
+ quats = identity_quats(shape, dtype, device, requires_grad)
351
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
352
+ else:
353
+ raise ValueError(f"Invalid format: f{fmt}")
354
+
355
+ # Magic methods
356
+
357
+ def __getitem__(self, index: Any) -> Rotation:
358
+ """
359
+ Allows torch-style indexing over the virtual shape of the rotation
360
+ object. See documentation for the shape property.
361
+
362
+ Args:
363
+ index:
364
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
365
+ Returns:
366
+ The indexed rotation
367
+ """
368
+ if type(index) != tuple:
369
+ index = (index,)
370
+
371
+ if(self._rot_mats is not None):
372
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
373
+ return Rotation(rot_mats=rot_mats)
374
+ elif(self._quats is not None):
375
+ quats = self._quats[index + (slice(None),)]
376
+ return Rotation(quats=quats, normalize_quats=False)
377
+ else:
378
+ raise ValueError("Both rotations are None")
379
+
380
+ def __mul__(self,
381
+ right: torch.Tensor,
382
+ ) -> Rotation:
383
+ """
384
+ Pointwise left multiplication of the rotation with a tensor. Can be
385
+ used to e.g. mask the Rotation.
386
+
387
+ Args:
388
+ right:
389
+ The tensor multiplicand
390
+ Returns:
391
+ The product
392
+ """
393
+ if not(isinstance(right, torch.Tensor)):
394
+ raise TypeError("The other multiplicand must be a Tensor")
395
+
396
+ if(self._rot_mats is not None):
397
+ rot_mats = self._rot_mats * right[..., None, None]
398
+ return Rotation(rot_mats=rot_mats, quats=None)
399
+ elif(self._quats is not None):
400
+ quats = self._quats * right[..., None]
401
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
402
+ else:
403
+ raise ValueError("Both rotations are None")
404
+
405
+ def __rmul__(self,
406
+ left: torch.Tensor,
407
+ ) -> Rotation:
408
+ """
409
+ Reverse pointwise multiplication of the rotation with a tensor.
410
+
411
+ Args:
412
+ left:
413
+ The left multiplicand
414
+ Returns:
415
+ The product
416
+ """
417
+ return self.__mul__(left)
418
+
419
+ # Properties
420
+
421
+ @property
422
+ def shape(self) -> torch.Size:
423
+ """
424
+ Returns the virtual shape of the rotation object. This shape is
425
+ defined as the batch dimensions of the underlying rotation matrix
426
+ or quaternion. If the Rotation was initialized with a [10, 3, 3]
427
+ rotation matrix tensor, for example, the resulting shape would be
428
+ [10].
429
+
430
+ Returns:
431
+ The virtual shape of the rotation object
432
+ """
433
+ s = None
434
+ if(self._quats is not None):
435
+ s = self._quats.shape[:-1]
436
+ else:
437
+ s = self._rot_mats.shape[:-2]
438
+
439
+ return s
440
+
441
+ @property
442
+ def dtype(self) -> torch.dtype:
443
+ """
444
+ Returns the dtype of the underlying rotation.
445
+
446
+ Returns:
447
+ The dtype of the underlying rotation
448
+ """
449
+ if(self._rot_mats is not None):
450
+ return self._rot_mats.dtype
451
+ elif(self._quats is not None):
452
+ return self._quats.dtype
453
+ else:
454
+ raise ValueError("Both rotations are None")
455
+
456
+ @property
457
+ def device(self) -> torch.device:
458
+ """
459
+ The device of the underlying rotation
460
+
461
+ Returns:
462
+ The device of the underlying rotation
463
+ """
464
+ if(self._rot_mats is not None):
465
+ return self._rot_mats.device
466
+ elif(self._quats is not None):
467
+ return self._quats.device
468
+ else:
469
+ raise ValueError("Both rotations are None")
470
+
471
+ @property
472
+ def requires_grad(self) -> bool:
473
+ """
474
+ Returns the requires_grad property of the underlying rotation
475
+
476
+ Returns:
477
+ The requires_grad property of the underlying tensor
478
+ """
479
+ if(self._rot_mats is not None):
480
+ return self._rot_mats.requires_grad
481
+ elif(self._quats is not None):
482
+ return self._quats.requires_grad
483
+ else:
484
+ raise ValueError("Both rotations are None")
485
+
486
+ def get_rot_mats(self) -> torch.Tensor:
487
+ """
488
+ Returns the underlying rotation as a rotation matrix tensor.
489
+
490
+ Returns:
491
+ The rotation as a rotation matrix tensor
492
+ """
493
+ rot_mats = self._rot_mats
494
+ if(rot_mats is None):
495
+ if(self._quats is None):
496
+ raise ValueError("Both rotations are None")
497
+ else:
498
+ rot_mats = quat_to_rot(self._quats)
499
+
500
+ return rot_mats
501
+
502
+ def get_quats(self) -> torch.Tensor:
503
+ """
504
+ Returns the underlying rotation as a quaternion tensor.
505
+
506
+ Depending on whether the Rotation was initialized with a
507
+ quaternion, this function may call torch.linalg.eigh.
508
+
509
+ Returns:
510
+ The rotation as a quaternion tensor.
511
+ """
512
+ quats = self._quats
513
+ if(quats is None):
514
+ if(self._rot_mats is None):
515
+ raise ValueError("Both rotations are None")
516
+ else:
517
+ quats = rot_to_quat(self._rot_mats)
518
+
519
+ return quats
520
+
521
+ def get_cur_rot(self) -> torch.Tensor:
522
+ """
523
+ Return the underlying rotation in its current form
524
+
525
+ Returns:
526
+ The stored rotation
527
+ """
528
+ if(self._rot_mats is not None):
529
+ return self._rot_mats
530
+ elif(self._quats is not None):
531
+ return self._quats
532
+ else:
533
+ raise ValueError("Both rotations are None")
534
+
535
+ # Rotation functions
536
+
537
+ def compose_q_update_vec(self,
538
+ q_update_vec: torch.Tensor,
539
+ normalize_quats: bool = True
540
+ ) -> Rotation:
541
+ """
542
+ Returns a new quaternion Rotation after updating the current
543
+ object's underlying rotation with a quaternion update, formatted
544
+ as a [*, 3] tensor whose final three columns represent x, y, z such
545
+ that (1, x, y, z) is the desired (not necessarily unit) quaternion
546
+ update.
547
+
548
+ Args:
549
+ q_update_vec:
550
+ A [*, 3] quaternion update tensor
551
+ normalize_quats:
552
+ Whether to normalize the output quaternion
553
+ Returns:
554
+ An updated Rotation
555
+ """
556
+ quats = self.get_quats()
557
+ new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
558
+ return Rotation(
559
+ rot_mats=None,
560
+ quats=new_quats,
561
+ normalize_quats=normalize_quats,
562
+ )
563
+
564
+ def compose_r(self, r: Rotation) -> Rotation:
565
+ """
566
+ Compose the rotation matrices of the current Rotation object with
567
+ those of another.
568
+
569
+ Args:
570
+ r:
571
+ An update rotation object
572
+ Returns:
573
+ An updated rotation object
574
+ """
575
+ r1 = self.get_rot_mats()
576
+ r2 = r.get_rot_mats()
577
+ new_rot_mats = rot_matmul(r1, r2)
578
+ return Rotation(rot_mats=new_rot_mats, quats=None)
579
+
580
+ def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
581
+ """
582
+ Compose the quaternions of the current Rotation object with those
583
+ of another.
584
+
585
+ Depending on whether either Rotation was initialized with
586
+ quaternions, this function may call torch.linalg.eigh.
587
+
588
+ Args:
589
+ r:
590
+ An update rotation object
591
+ Returns:
592
+ An updated rotation object
593
+ """
594
+ q1 = self.get_quats()
595
+ q2 = r.get_quats()
596
+ new_quats = quat_multiply(q1, q2)
597
+ return Rotation(
598
+ rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
599
+ )
600
+
601
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
602
+ """
603
+ Apply the current Rotation as a rotation matrix to a set of 3D
604
+ coordinates.
605
+
606
+ Args:
607
+ pts:
608
+ A [*, 3] set of points
609
+ Returns:
610
+ [*, 3] rotated points
611
+ """
612
+ rot_mats = self.get_rot_mats()
613
+ return rot_vec_mul(rot_mats, pts)
614
+
615
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
616
+ """
617
+ The inverse of the apply() method.
618
+
619
+ Args:
620
+ pts:
621
+ A [*, 3] set of points
622
+ Returns:
623
+ [*, 3] inverse-rotated points
624
+ """
625
+ rot_mats = self.get_rot_mats()
626
+ inv_rot_mats = invert_rot_mat(rot_mats)
627
+ return rot_vec_mul(inv_rot_mats, pts)
628
+
629
+ def invert(self) -> Rotation:
630
+ """
631
+ Returns the inverse of the current Rotation.
632
+
633
+ Returns:
634
+ The inverse of the current Rotation
635
+ """
636
+ if(self._rot_mats is not None):
637
+ return Rotation(
638
+ rot_mats=invert_rot_mat(self._rot_mats),
639
+ quats=None
640
+ )
641
+ elif(self._quats is not None):
642
+ return Rotation(
643
+ rot_mats=None,
644
+ quats=invert_quat(self._quats),
645
+ normalize_quats=False,
646
+ )
647
+ else:
648
+ raise ValueError("Both rotations are None")
649
+
650
+ # "Tensor" stuff
651
+
652
+ def unsqueeze(self,
653
+ dim: int,
654
+ ) -> Rigid:
655
+ """
656
+ Analogous to torch.unsqueeze. The dimension is relative to the
657
+ shape of the Rotation object.
658
+
659
+ Args:
660
+ dim: A positive or negative dimension index.
661
+ Returns:
662
+ The unsqueezed Rotation.
663
+ """
664
+ if dim >= len(self.shape):
665
+ raise ValueError("Invalid dimension")
666
+
667
+ if(self._rot_mats is not None):
668
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
669
+ return Rotation(rot_mats=rot_mats, quats=None)
670
+ elif(self._quats is not None):
671
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
672
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
673
+ else:
674
+ raise ValueError("Both rotations are None")
675
+
676
+ @staticmethod
677
+ def cat(
678
+ rs: Sequence[Rotation],
679
+ dim: int,
680
+ ) -> Rigid:
681
+ """
682
+ Concatenates rotations along one of the batch dimensions. Analogous
683
+ to torch.cat().
684
+
685
+ Note that the output of this operation is always a rotation matrix,
686
+ regardless of the format of input rotations.
687
+
688
+ Args:
689
+ rs:
690
+ A list of rotation objects
691
+ dim:
692
+ The dimension along which the rotations should be
693
+ concatenated
694
+ Returns:
695
+ A concatenated Rotation object in rotation matrix format
696
+ """
697
+ rot_mats = [r.get_rot_mats() for r in rs]
698
+ rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
699
+
700
+ return Rotation(rot_mats=rot_mats, quats=None)
701
+
702
+ def map_tensor_fn(self,
703
+ fn: Callable[torch.Tensor, torch.Tensor]
704
+ ) -> Rotation:
705
+ """
706
+ Apply a Tensor -> Tensor function to underlying rotation tensors,
707
+ mapping over the rotation dimension(s). Can be used e.g. to sum out
708
+ a one-hot batch dimension.
709
+
710
+ Args:
711
+ fn:
712
+ A Tensor -> Tensor function to be mapped over the Rotation
713
+ Returns:
714
+ The transformed Rotation object
715
+ """
716
+ if(self._rot_mats is not None):
717
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
718
+ rot_mats = torch.stack(
719
+ list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
720
+ )
721
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
722
+ return Rotation(rot_mats=rot_mats, quats=None)
723
+ elif(self._quats is not None):
724
+ quats = torch.stack(
725
+ list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
726
+ )
727
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
728
+ else:
729
+ raise ValueError("Both rotations are None")
730
+
731
+ def cuda(self) -> Rotation:
732
+ """
733
+ Analogous to the cuda() method of torch Tensors
734
+
735
+ Returns:
736
+ A copy of the Rotation in CUDA memory
737
+ """
738
+ if(self._rot_mats is not None):
739
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
740
+ elif(self._quats is not None):
741
+ return Rotation(
742
+ rot_mats=None,
743
+ quats=self._quats.cuda(),
744
+ normalize_quats=False
745
+ )
746
+ else:
747
+ raise ValueError("Both rotations are None")
748
+
749
+ def to(self,
750
+ device: Optional[torch.device],
751
+ dtype: Optional[torch.dtype]
752
+ ) -> Rotation:
753
+ """
754
+ Analogous to the to() method of torch Tensors
755
+
756
+ Args:
757
+ device:
758
+ A torch device
759
+ dtype:
760
+ A torch dtype
761
+ Returns:
762
+ A copy of the Rotation using the new device and dtype
763
+ """
764
+ if(self._rot_mats is not None):
765
+ return Rotation(
766
+ rot_mats=self._rot_mats.to(device=device, dtype=dtype),
767
+ quats=None,
768
+ )
769
+ elif(self._quats is not None):
770
+ return Rotation(
771
+ rot_mats=None,
772
+ quats=self._quats.to(device=device, dtype=dtype),
773
+ normalize_quats=False,
774
+ )
775
+ else:
776
+ raise ValueError("Both rotations are None")
777
+
778
+ def detach(self) -> Rotation:
779
+ """
780
+ Returns a copy of the Rotation whose underlying Tensor has been
781
+ detached from its torch graph.
782
+
783
+ Returns:
784
+ A copy of the Rotation whose underlying Tensor has been detached
785
+ from its torch graph
786
+ """
787
+ if(self._rot_mats is not None):
788
+ return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
789
+ elif(self._quats is not None):
790
+ return Rotation(
791
+ rot_mats=None,
792
+ quats=self._quats.detach(),
793
+ normalize_quats=False,
794
+ )
795
+ else:
796
+ raise ValueError("Both rotations are None")
797
+
798
+
799
class Rigid:
    """
    A class representing a rigid transformation: a Rotation object paired
    with a [*, 3] translation. Designed to behave approximately like a
    single torch tensor with the shape of the shared batch dimensions of
    its component parts.
    """
    def __init__(self,
        rots: Optional[Rotation],
        trans: Optional[torch.Tensor],
    ):
        """
        Args:
            rots: A Rotation object of batch shape [*] (or None)
            trans: A corresponding [*, 3] translation tensor (or None)
        """
        # Derive batch shape/dtype/device/grad from whichever component is
        # given; the missing one is filled with an identity below.
        if trans is not None:
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif rots is not None:
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        if rots is None:
            rots = Rotation.identity(
                batch_dims, dtype, device, requires_grad,
            )
        elif trans is None:
            trans = identity_trans(
                batch_dims, dtype, device, requires_grad,
            )

        if rots.shape != trans.shape[:-1] or rots.device != trans.device:
            raise ValueError("Rots and trans incompatible")

        # Force full precision. Happens to the rotations automatically.
        trans = trans.to(dtype=torch.float32)

        self._rots = rots
        self._trans = trans
849
+
850
+ @staticmethod
851
+ def identity(
852
+ shape: Tuple[int],
853
+ dtype: Optional[torch.dtype] = None,
854
+ device: Optional[torch.device] = None,
855
+ requires_grad: bool = True,
856
+ fmt: str = "quat",
857
+ ) -> Rigid:
858
+ """
859
+ Constructs an identity transformation.
860
+
861
+ Args:
862
+ shape:
863
+ The desired shape
864
+ dtype:
865
+ The dtype of both internal tensors
866
+ device:
867
+ The device of both internal tensors
868
+ requires_grad:
869
+ Whether grad should be enabled for the internal tensors
870
+ Returns:
871
+ The identity transformation
872
+ """
873
+ return Rigid(
874
+ Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
875
+ identity_trans(shape, dtype, device, requires_grad),
876
+ )
877
+
878
+ def __getitem__(self,
879
+ index: Any,
880
+ ) -> Rigid:
881
+ """
882
+ Indexes the affine transformation with PyTorch-style indices.
883
+ The index is applied to the shared dimensions of both the rotation
884
+ and the translation.
885
+
886
+ E.g.::
887
+
888
+ r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
889
+ t = Rigid(r, torch.rand(10, 10, 3))
890
+ indexed = t[3, 4:6]
891
+ assert(indexed.shape == (2,))
892
+ assert(indexed.get_rots().shape == (2,))
893
+ assert(indexed.get_trans().shape == (2, 3))
894
+
895
+ Args:
896
+ index: A standard torch tensor index. E.g. 8, (10, None, 3),
897
+ or (3, slice(0, 1, None))
898
+ Returns:
899
+ The indexed tensor
900
+ """
901
+ if type(index) != tuple:
902
+ index = (index,)
903
+
904
+ return Rigid(
905
+ self._rots[index],
906
+ self._trans[index + (slice(None),)],
907
+ )
908
+
909
+ def __mul__(self,
910
+ right: torch.Tensor,
911
+ ) -> Rigid:
912
+ """
913
+ Pointwise left multiplication of the transformation with a tensor.
914
+ Can be used to e.g. mask the Rigid.
915
+
916
+ Args:
917
+ right:
918
+ The tensor multiplicand
919
+ Returns:
920
+ The product
921
+ """
922
+ if not(isinstance(right, torch.Tensor)):
923
+ raise TypeError("The other multiplicand must be a Tensor")
924
+
925
+ new_rots = self._rots * right
926
+ new_trans = self._trans * right[..., None]
927
+
928
+ return Rigid(new_rots, new_trans)
929
+
930
+ def __rmul__(self,
931
+ left: torch.Tensor,
932
+ ) -> Rigid:
933
+ """
934
+ Reverse pointwise multiplication of the transformation with a
935
+ tensor.
936
+
937
+ Args:
938
+ left:
939
+ The left multiplicand
940
+ Returns:
941
+ The product
942
+ """
943
+ return self.__mul__(left)
944
+
945
+ @property
946
+ def shape(self) -> torch.Size:
947
+ """
948
+ Returns the shape of the shared dimensions of the rotation and
949
+ the translation.
950
+
951
+ Returns:
952
+ The shape of the transformation
953
+ """
954
+ s = self._trans.shape[:-1]
955
+ return s
956
+
957
+ @property
958
+ def device(self) -> torch.device:
959
+ """
960
+ Returns the device on which the Rigid's tensors are located.
961
+
962
+ Returns:
963
+ The device on which the Rigid's tensors are located
964
+ """
965
+ return self._trans.device
966
+
967
+ def get_rots(self) -> Rotation:
968
+ """
969
+ Getter for the rotation.
970
+
971
+ Returns:
972
+ The rotation object
973
+ """
974
+ return self._rots
975
+
976
+ def get_trans(self) -> torch.Tensor:
977
+ """
978
+ Getter for the translation.
979
+
980
+ Returns:
981
+ The stored translation
982
+ """
983
+ return self._trans
984
+
985
+ def compose_q_update_vec(self,
986
+ q_update_vec: torch.Tensor,
987
+ ) -> Rigid:
988
+ """
989
+ Composes the transformation with a quaternion update vector of
990
+ shape [*, 6], where the final 6 columns represent the x, y, and
991
+ z values of a quaternion of form (1, x, y, z) followed by a 3D
992
+ translation.
993
+
994
+ Args:
995
+ q_vec: The quaternion update vector.
996
+ Returns:
997
+ The composed transformation.
998
+ """
999
+ q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
1000
+ new_rots = self._rots.compose_q_update_vec(q_vec)
1001
+
1002
+ trans_update = self._rots.apply(t_vec)
1003
+ new_translation = self._trans + trans_update
1004
+
1005
+ return Rigid(new_rots, new_translation)
1006
+
1007
+ def compose(self,
1008
+ r: Rigid,
1009
+ ) -> Rigid:
1010
+ """
1011
+ Composes the current rigid object with another.
1012
+
1013
+ Args:
1014
+ r:
1015
+ Another Rigid object
1016
+ Returns:
1017
+ The composition of the two transformations
1018
+ """
1019
+ new_rot = self._rots.compose_r(r._rots)
1020
+ new_trans = self._rots.apply(r._trans) + self._trans
1021
+ return Rigid(new_rot, new_trans)
1022
+
1023
+ def apply(self,
1024
+ pts: torch.Tensor,
1025
+ ) -> torch.Tensor:
1026
+ """
1027
+ Applies the transformation to a coordinate tensor.
1028
+
1029
+ Args:
1030
+ pts: A [*, 3] coordinate tensor.
1031
+ Returns:
1032
+ The transformed points.
1033
+ """
1034
+ rotated = self._rots.apply(pts)
1035
+ return rotated + self._trans
1036
+
1037
+ def invert_apply(self,
1038
+ pts: torch.Tensor
1039
+ ) -> torch.Tensor:
1040
+ """
1041
+ Applies the inverse of the transformation to a coordinate tensor.
1042
+
1043
+ Args:
1044
+ pts: A [*, 3] coordinate tensor
1045
+ Returns:
1046
+ The transformed points.
1047
+ """
1048
+ pts = pts - self._trans
1049
+ return self._rots.invert_apply(pts)
1050
+
1051
+ def invert(self) -> Rigid:
1052
+ """
1053
+ Inverts the transformation.
1054
+
1055
+ Returns:
1056
+ The inverse transformation.
1057
+ """
1058
+ rot_inv = self._rots.invert()
1059
+ trn_inv = rot_inv.apply(self._trans)
1060
+
1061
+ return Rigid(rot_inv, -1 * trn_inv)
1062
+
1063
+ def map_tensor_fn(self,
1064
+ fn: Callable[torch.Tensor, torch.Tensor]
1065
+ ) -> Rigid:
1066
+ """
1067
+ Apply a Tensor -> Tensor function to underlying translation and
1068
+ rotation tensors, mapping over the translation/rotation dimensions
1069
+ respectively.
1070
+
1071
+ Args:
1072
+ fn:
1073
+ A Tensor -> Tensor function to be mapped over the Rigid
1074
+ Returns:
1075
+ The transformed Rigid object
1076
+ """
1077
+ new_rots = self._rots.map_tensor_fn(fn)
1078
+ new_trans = torch.stack(
1079
+ list(map(fn, torch.unbind(self._trans, dim=-1))),
1080
+ dim=-1
1081
+ )
1082
+
1083
+ return Rigid(new_rots, new_trans)
1084
+
1085
+ def to_tensor_4x4(self) -> torch.Tensor:
1086
+ """
1087
+ Converts a transformation to a homogenous transformation tensor.
1088
+
1089
+ Returns:
1090
+ A [*, 4, 4] homogenous transformation tensor
1091
+ """
1092
+ tensor = self._trans.new_zeros((*self.shape, 4, 4))
1093
+ tensor[..., :3, :3] = self._rots.get_rot_mats()
1094
+ tensor[..., :3, 3] = self._trans
1095
+ tensor[..., 3, 3] = 1
1096
+ return tensor
1097
+
1098
+ @staticmethod
1099
+ def from_tensor_4x4(
1100
+ t: torch.Tensor
1101
+ ) -> Rigid:
1102
+ """
1103
+ Constructs a transformation from a homogenous transformation
1104
+ tensor.
1105
+
1106
+ Args:
1107
+ t: [*, 4, 4] homogenous transformation tensor
1108
+ Returns:
1109
+ T object with shape [*]
1110
+ """
1111
+ if(t.shape[-2:] != (4, 4)):
1112
+ raise ValueError("Incorrectly shaped input tensor")
1113
+
1114
+ rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
1115
+ trans = t[..., :3, 3]
1116
+
1117
+ return Rigid(rots, trans)
1118
+
1119
+ def to_tensor_7(self) -> torch.Tensor:
1120
+ """
1121
+ Converts a transformation to a tensor with 7 final columns, four
1122
+ for the quaternion followed by three for the translation.
1123
+
1124
+ Returns:
1125
+ A [*, 7] tensor representation of the transformation
1126
+ """
1127
+ tensor = self._trans.new_zeros((*self.shape, 7))
1128
+ tensor[..., :4] = self._rots.get_quats()
1129
+ tensor[..., 4:] = self._trans
1130
+
1131
+ return tensor
1132
+
1133
+ @staticmethod
1134
+ def from_tensor_7(
1135
+ t: torch.Tensor,
1136
+ normalize_quats: bool = False,
1137
+ ) -> Rigid:
1138
+ if(t.shape[-1] != 7):
1139
+ raise ValueError("Incorrectly shaped input tensor")
1140
+
1141
+ quats, trans = t[..., :4], t[..., 4:]
1142
+
1143
+ rots = Rotation(
1144
+ rot_mats=None,
1145
+ quats=quats,
1146
+ normalize_quats=normalize_quats
1147
+ )
1148
+
1149
+ return Rigid(rots, trans)
1150
+
1151
+ @staticmethod
1152
+ def from_3_points(
1153
+ p_neg_x_axis: torch.Tensor,
1154
+ origin: torch.Tensor,
1155
+ p_xy_plane: torch.Tensor,
1156
+ eps: float = 1e-8
1157
+ ) -> Rigid:
1158
+ """
1159
+ Implements algorithm 21. Constructs transformations from sets of 3
1160
+ points using the Gram-Schmidt algorithm.
1161
+
1162
+ Args:
1163
+ p_neg_x_axis: [*, 3] coordinates
1164
+ origin: [*, 3] coordinates used as frame origins
1165
+ p_xy_plane: [*, 3] coordinates
1166
+ eps: Small epsilon value
1167
+ Returns:
1168
+ A transformation object of shape [*]
1169
+ """
1170
+ p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
1171
+ origin = torch.unbind(origin, dim=-1)
1172
+ p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
1173
+
1174
+ e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
1175
+ e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
1176
+
1177
+ denom = torch.sqrt(sum((c * c for c in e0)) + eps)
1178
+ e0 = [c / denom for c in e0]
1179
+ dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
1180
+ e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
1181
+ denom = torch.sqrt(sum((c * c for c in e1)) + eps)
1182
+ e1 = [c / denom for c in e1]
1183
+ e2 = [
1184
+ e0[1] * e1[2] - e0[2] * e1[1],
1185
+ e0[2] * e1[0] - e0[0] * e1[2],
1186
+ e0[0] * e1[1] - e0[1] * e1[0],
1187
+ ]
1188
+
1189
+ rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
1190
+ rots = rots.reshape(rots.shape[:-1] + (3, 3))
1191
+
1192
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1193
+
1194
+ return Rigid(rot_obj, torch.stack(origin, dim=-1))
1195
+
1196
+ def unsqueeze(self,
1197
+ dim: int,
1198
+ ) -> Rigid:
1199
+ """
1200
+ Analogous to torch.unsqueeze. The dimension is relative to the
1201
+ shared dimensions of the rotation/translation.
1202
+
1203
+ Args:
1204
+ dim: A positive or negative dimension index.
1205
+ Returns:
1206
+ The unsqueezed transformation.
1207
+ """
1208
+ if dim >= len(self.shape):
1209
+ raise ValueError("Invalid dimension")
1210
+ rots = self._rots.unsqueeze(dim)
1211
+ trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
1212
+
1213
+ return Rigid(rots, trans)
1214
+
1215
+ @staticmethod
1216
+ def cat(
1217
+ ts: Sequence[Rigid],
1218
+ dim: int,
1219
+ ) -> Rigid:
1220
+ """
1221
+ Concatenates transformations along a new dimension.
1222
+
1223
+ Args:
1224
+ ts:
1225
+ A list of T objects
1226
+ dim:
1227
+ The dimension along which the transformations should be
1228
+ concatenated
1229
+ Returns:
1230
+ A concatenated transformation object
1231
+ """
1232
+ rots = Rotation.cat([t._rots for t in ts], dim)
1233
+ trans = torch.cat(
1234
+ [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
1235
+ )
1236
+
1237
+ return Rigid(rots, trans)
1238
+
1239
+ def apply_rot_fn(self, fn: Callable[Rotation, Rotation]) -> Rigid:
1240
+ """
1241
+ Applies a Rotation -> Rotation function to the stored rotation
1242
+ object.
1243
+
1244
+ Args:
1245
+ fn: A function of type Rotation -> Rotation
1246
+ Returns:
1247
+ A transformation object with a transformed rotation.
1248
+ """
1249
+ return Rigid(fn(self._rots), self._trans)
1250
+
1251
+ def apply_trans_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rigid:
1252
+ """
1253
+ Applies a Tensor -> Tensor function to the stored translation.
1254
+
1255
+ Args:
1256
+ fn:
1257
+ A function of type Tensor -> Tensor to be applied to the
1258
+ translation
1259
+ Returns:
1260
+ A transformation object with a transformed translation.
1261
+ """
1262
+ return Rigid(self._rots, fn(self._trans))
1263
+
1264
+ def scale_translation(self, trans_scale_factor: float) -> Rigid:
1265
+ """
1266
+ Scales the translation by a constant factor.
1267
+
1268
+ Args:
1269
+ trans_scale_factor:
1270
+ The constant factor
1271
+ Returns:
1272
+ A transformation object with a scaled translation.
1273
+ """
1274
+ fn = lambda t: t * trans_scale_factor
1275
+ return self.apply_trans_fn(fn)
1276
+
1277
+ def stop_rot_gradient(self) -> Rigid:
1278
+ """
1279
+ Detaches the underlying rotation object
1280
+
1281
+ Returns:
1282
+ A transformation object with detached rotations
1283
+ """
1284
+ fn = lambda r: r.detach()
1285
+ return self.apply_rot_fn(fn)
1286
+
1287
+ @staticmethod
1288
+ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
1289
+ """
1290
+ Returns a transformation object from reference coordinates.
1291
+
1292
+ Note that this method does not take care of symmetries. If you
1293
+ provide the atom positions in the non-standard way, the N atom will
1294
+ end up not at [-0.527250, 1.359329, 0.0] but instead at
1295
+ [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
1296
+ your code.
1297
+
1298
+ Args:
1299
+ n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
1300
+ ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
1301
+ c_xyz: A [*, 3] tensor of carbon xyz coordinates.
1302
+ Returns:
1303
+ A transformation object. After applying the translation and
1304
+ rotation to the reference backbone, the coordinates will
1305
+ approximately equal to the input coordinates.
1306
+ """
1307
+ translation = -1 * ca_xyz
1308
+ n_xyz = n_xyz + translation
1309
+ c_xyz = c_xyz + translation
1310
+
1311
+ c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
1312
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
1313
+ sin_c1 = -c_y / norm
1314
+ cos_c1 = c_x / norm
1315
+ zeros = sin_c1.new_zeros(sin_c1.shape)
1316
+ ones = sin_c1.new_ones(sin_c1.shape)
1317
+
1318
+ c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
1319
+ c1_rots[..., 0, 0] = cos_c1
1320
+ c1_rots[..., 0, 1] = -1 * sin_c1
1321
+ c1_rots[..., 1, 0] = sin_c1
1322
+ c1_rots[..., 1, 1] = cos_c1
1323
+ c1_rots[..., 2, 2] = 1
1324
+
1325
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
1326
+ sin_c2 = c_z / norm
1327
+ cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
1328
+
1329
+ c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1330
+ c2_rots[..., 0, 0] = cos_c2
1331
+ c2_rots[..., 0, 2] = sin_c2
1332
+ c2_rots[..., 1, 1] = 1
1333
+ c2_rots[..., 2, 0] = -1 * sin_c2
1334
+ c2_rots[..., 2, 2] = cos_c2
1335
+
1336
+ c_rots = rot_matmul(c2_rots, c1_rots)
1337
+ n_xyz = rot_vec_mul(c_rots, n_xyz)
1338
+
1339
+ _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
1340
+ norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
1341
+ sin_n = -n_z / norm
1342
+ cos_n = n_y / norm
1343
+
1344
+ n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1345
+ n_rots[..., 0, 0] = 1
1346
+ n_rots[..., 1, 1] = cos_n
1347
+ n_rots[..., 1, 2] = -1 * sin_n
1348
+ n_rots[..., 2, 1] = sin_n
1349
+ n_rots[..., 2, 2] = cos_n
1350
+
1351
+ rots = rot_matmul(n_rots, c_rots)
1352
+
1353
+ rots = rots.transpose(-1, -2)
1354
+ translation = -1 * translation
1355
+
1356
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1357
+
1358
+ return Rigid(rot_obj, translation)
1359
+
1360
+ def cuda(self) -> Rigid:
1361
+ """
1362
+ Moves the transformation object to GPU memory
1363
+
1364
+ Returns:
1365
+ A version of the transformation on GPU
1366
+ """
1367
+ return Rigid(self._rots.cuda(), self._trans.cuda())
openfold/utils/tensor_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ import logging
18
+ from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
def add(m1, m2, inplace):
    """Add m2 to m1, in place when requested; returns the sum either way."""
    # The first operation in a checkpoint can't be in-place, but it's
    # nice to have in-place addition during inference. Thus...
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
33
+
34
+
35
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute the last len(inds) dimensions of `tensor` according to `inds`,
    leaving the leading (batch) dimensions in place."""
    n_moved = len(inds)
    batch_inds = list(range(tensor.dim() - n_moved))
    # inds are given relative to the moved suffix; shift them to negative axes
    return tensor.permute(batch_inds + [i - n_moved for i in inds])


def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Flatten the last `no_dims` dimensions of `t` into a single one."""
    return t.reshape(*t.shape[:-no_dims], -1)
43
+
44
+
45
def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of `value` over `dim`, weighted by `mask`; `eps` guards against
    division by zero when the mask is empty."""
    mask = mask.expand(*value.shape)
    numerator = torch.sum(mask * value, dim=dim)
    denominator = eps + torch.sum(mask, dim=dim)
    return numerator / denominator
48
+
49
+
50
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bucketize pairwise Euclidean distances between points in `pts`
    ([*, N, 3]) into `no_bins` distance bins."""
    boundaries = torch.linspace(
        min_bin, max_bin, no_bins - 1, device=pts.device
    )
    diffs = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    dists = torch.sqrt((diffs ** 2).sum(dim=-1))
    return torch.bucketize(dists, boundaries)
58
+
59
+
60
def dict_multimap(fn, dicts):
    """Apply `fn` to the list of corresponding leaves across a list of
    identically-structured nested dicts."""
    template = dicts[0]
    out = {}
    for k, v in template.items():
        gathered = [d[k] for d in dicts]
        # Recurse into sub-dicts; `type(...) is dict` matches the original
        # exact-type check (subclasses are treated as leaves).
        if type(v) is dict:
            out[k] = dict_multimap(fn, gathered)
        else:
            out[k] = fn(gathered)

    return out
71
+
72
+
73
def one_hot(x, v_bins):
    """One-hot encode each value in `x` by its nearest bin center in
    `v_bins`; returns a float tensor of shape [*x.shape, len(v_bins)]."""
    bins = v_bins.view((1,) * x.dim() + (len(v_bins),))
    nearest = torch.argmin(torch.abs(x[..., None] - bins), dim=-1)
    return nn.functional.one_hot(nearest, num_classes=len(v_bins)).float()
78
+
79
+
80
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather along `dim` of `data` using `inds`, treating the first
    `no_batch_dims` dimensions as independent batch dimensions."""
    # Build broadcastable arange indices for each batch dimension
    batch_ranges = []
    for i, size in enumerate(data.shape[:no_batch_dims]):
        shape = (1,) * i + (-1,) + (1,) * (len(inds.shape) - i - 1)
        batch_ranges.append(torch.arange(size).view(*shape))

    trailing = [slice(None)] * (len(data.shape) - no_batch_dims)
    trailing[dim - no_batch_dims if dim >= 0 else dim] = inds
    return data[batch_ranges + trailing]
93
+
94
+
95
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):
    """Recursively apply `fn` to every `leaf_type` leaf of a nested dict."""
    return {
        k: dict_map(fn, v, leaf_type) if type(v) is dict
        else tree_map(fn, v, leaf_type)
        for k, v in dic.items()
    }


def tree_map(fn, tree, leaf_type):
    """Apply `fn` to every `leaf_type` leaf of a nested dict/list/tuple
    tree, preserving the container structure."""
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    if isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    if isinstance(tree, leaf_type):
        return fn(tree)
    print(type(tree))
    raise ValueError("Not supported")


tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ biopython==1.79
2
+ filelock==3.13.1
3
+ fsspec==2024.3.1
4
+ Jinja2==3.1.3
5
+ MarkupSafe==2.1.5
6
+ mpmath==1.3.0
7
+ networkx==3.2.1
8
+ numpy==1.23.5
9
+ nvidia-cublas-cu12==12.1.3.1
10
+ nvidia-cuda-cupti-cu12==12.1.105
11
+ nvidia-cuda-nvrtc-cu12==12.1.105
12
+ nvidia-cuda-runtime-cu12==12.1.105
13
+ nvidia-cudnn-cu12==8.9.2.26
14
+ nvidia-cufft-cu12==11.0.2.54
15
+ nvidia-curand-cu12==10.3.2.106
16
+ nvidia-cusolver-cu12==11.4.5.107
17
+ nvidia-cusparse-cu12==12.1.0.106
18
+ nvidia-nccl-cu12==2.19.3
19
+ nvidia-nvjitlink-cu12==12.4.99
20
+ nvidia-nvtx-cu12==12.1.105
21
+ ProDy==2.4.1
22
+ pyparsing==3.1.1
23
+ scipy==1.12.0
24
+ sympy==1.12
25
+ torch==2.2.1
26
+ triton==2.2.0
27
+ typing_extensions==4.10.0
28
+ ml-collections==0.1.1
29
+ dm-tree==0.1.8
run.py ADDED
@@ -0,0 +1,990 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import json
4
+ import os.path
5
+ import random
6
+ import sys
7
+
8
+ import numpy as np
9
+ import torch
10
+ from data_utils import (
11
+ alphabet,
12
+ element_dict_rev,
13
+ featurize,
14
+ get_score,
15
+ get_seq_rec,
16
+ parse_PDB,
17
+ restype_1to3,
18
+ restype_int_to_str,
19
+ restype_str_to_int,
20
+ write_full_PDB,
21
+ )
22
+ from model_utils import ProteinMPNN
23
+ from prody import writePDB
24
+ from sc_utils import Packer, pack_side_chains
25
+
26
+
27
def main(args) -> None:
    """
    Inference function: run (Ligand/Protein)MPNN sequence design.

    Workflow:
      1. Seed torch/numpy/random RNGs, create the output folder layout
         (seqs/, backbones/, packed/, optionally stats/).
      2. Load the checkpoint selected by ``args.model_type`` and, if
         ``args.pack_side_chains``, the side-chain Packer checkpoint.
      3. For each input PDB: parse it, build per-residue fixed/redesigned,
         bias and omit masks, sample ``batch_size * number_of_batches``
         sequences, and write FASTA + backbone PDB (and optionally packed
         full-atom PDB / stats) outputs.

    Args:
        args: argparse.Namespace produced by the parser defined at the
            bottom of this file; see those definitions for each option.
    """
    # --- RNG seeding (seed 0 means "draw a random seed") ---
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    # --- Output directory layout ---
    folder_for_outputs = args.out_folder
    base_folder = folder_for_outputs
    if base_folder[-1] != "/":
        base_folder = base_folder + "/"
    if not os.path.exists(base_folder):
        os.makedirs(base_folder, exist_ok=True)
    if not os.path.exists(base_folder + "seqs"):
        os.makedirs(base_folder + "seqs", exist_ok=True)
    if not os.path.exists(base_folder + "backbones"):
        os.makedirs(base_folder + "backbones", exist_ok=True)
    if not os.path.exists(base_folder + "packed"):
        os.makedirs(base_folder + "packed", exist_ok=True)
    if args.save_stats:
        if not os.path.exists(base_folder + "stats"):
            os.makedirs(base_folder + "stats", exist_ok=True)
    # --- Select checkpoint path based on the requested model flavour ---
    if args.model_type == "protein_mpnn":
        checkpoint_path = args.checkpoint_protein_mpnn
    elif args.model_type == "ligand_mpnn":
        checkpoint_path = args.checkpoint_ligand_mpnn
    elif args.model_type == "per_residue_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
    elif args.model_type == "global_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_global_label_membrane_mpnn
    elif args.model_type == "soluble_mpnn":
        checkpoint_path = args.checkpoint_soluble_mpnn
    else:
        print("Choose one of the available models")
        sys.exit()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if args.model_type == "ligand_mpnn":
        # ligand model: ligand-atom context size comes from the checkpoint
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
        k_neighbors = checkpoint["num_edges"]
    else:
        # non-ligand models use a dummy single-atom context
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
        k_neighbors = checkpoint["num_edges"]

    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # --- Optional side-chain packing model ---
    if args.pack_side_chains:
        model_sc = Packer(
            node_features=128,
            edge_features=128,
            num_positional_embeddings=16,
            num_chain_embeddings=16,
            num_rbf=16,
            hidden_dim=128,
            num_encoder_layers=3,
            num_decoder_layers=3,
            atom_context_num=16,
            lower_bound=0.0,
            upper_bound=20.0,
            top_k=32,
            dropout=0.0,
            augment_eps=0.0,
            atom37_order=False,
            device=device,
            num_mix=3,
        )

        checkpoint_sc = torch.load(args.checkpoint_path_sc, map_location=device)
        model_sc.load_state_dict(checkpoint_sc["model_state_dict"])
        model_sc.to(device)
        model_sc.eval()

    # --- Collect input PDB paths (single file or JSON list) ---
    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]

    # --- Per-PDB fixed / redesigned residue selections ---
    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
        fixed_residues_multi = {key:value.split() for key,value in fixed_residues_multi.items()}
    else:
        fixed_residues = [item for item in args.fixed_residues.split()]
        fixed_residues_multi = {}
        for pdb in pdb_paths:
            fixed_residues_multi[pdb] = fixed_residues

    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
        redesigned_residues_multi = {key:value.split() for key,value in redesigned_residues_multi.items()}
    else:
        redesigned_residues = [item for item in args.redesigned_residues.split()]
        redesigned_residues_multi = {}
        for pdb in pdb_paths:
            redesigned_residues_multi[pdb] = redesigned_residues

    # Global amino-acid bias vector (21 entries — presumably 20 AAs + X;
    # TODO confirm against `alphabet` in data_utils).
    bias_AA = torch.zeros([21], device=device, dtype=torch.float32)
    if args.bias_AA:
        tmp = [item.split(":") for item in args.bias_AA.split(",")]
        a1 = [b[0] for b in tmp]
        a2 = [float(b[1]) for b in tmp]
        for i, AA in enumerate(a1):
            bias_AA[restype_str_to_int[AA]] = a2[i]

    if args.bias_AA_per_residue_multi:
        with open(args.bias_AA_per_residue_multi, "r") as fh:
            bias_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": {"G": 1.1}}}
    else:
        if args.bias_AA_per_residue:
            with open(args.bias_AA_per_residue, "r") as fh:
                bias_AA_per_residue = json.load(fh)  # {"A12": {"G": 1.1}}
            bias_AA_per_residue_multi = {}
            for pdb in pdb_paths:
                bias_AA_per_residue_multi[pdb] = bias_AA_per_residue

    if args.omit_AA_per_residue_multi:
        with open(args.omit_AA_per_residue_multi, "r") as fh:
            omit_AA_per_residue_multi = json.load(
                fh
            )  # {"pdb_path" : {"A12": "PQR", "A13": "QS"}}
    else:
        if args.omit_AA_per_residue:
            with open(args.omit_AA_per_residue, "r") as fh:
                omit_AA_per_residue = json.load(fh)  # {"A12": "PG"}
            omit_AA_per_residue_multi = {}
            for pdb in pdb_paths:
                omit_AA_per_residue_multi[pdb] = omit_AA_per_residue
    # 0/1 mask over the alphabet: 1.0 marks globally forbidden amino acids
    omit_AA_list = args.omit_AA
    omit_AA = torch.tensor(
        np.array([AA in omit_AA_list for AA in alphabet]).astype(np.float32),
        device=device,
    )

    if len(args.parse_these_chains_only) != 0:
        parse_these_chains_only_list = args.parse_these_chains_only.split(",")
    else:
        parse_these_chains_only_list = []


    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        # side chains are needed either as model context or to keep fixed
        # residues' rotamers during packing
        parse_all_atoms_flag = args.ligand_mpnn_use_side_chain_context or (
            args.pack_side_chains and not args.repack_everything
        )
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=parse_these_chains_only_list,
            parse_all_atoms=parse_all_atoms_flag,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy,
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = []
        for i, R_idx_item in enumerate(R_idx_list):
            tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            encoded_residues.append(tmp)
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(
            zip(list(range(len(encoded_residues))), encoded_residues)
        )

        # Per-residue bias matrix [num_residues, 21]; entries not mentioned
        # in the JSON stay 0.0 (unknown residue names are silently skipped).
        bias_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.bias_AA_per_residue_multi or args.bias_AA_per_residue:
            bias_dict = bias_AA_per_residue_multi[pdb]
            for residue_name, v1 in bias_dict.items():
                if residue_name in encoded_residues:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid, v2 in v1.items():
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            bias_AA_per_residue[i1, j1] = v2

        # Per-residue omit matrix [num_residues, 21]; 1.0 marks forbidden AAs
        omit_AA_per_residue = torch.zeros(
            [len(encoded_residues), 21], device=device, dtype=torch.float32
        )
        if args.omit_AA_per_residue_multi or args.omit_AA_per_residue:
            omit_dict = omit_AA_per_residue_multi[pdb]
            for residue_name, v1 in omit_dict.items():
                if residue_name in encoded_residues:
                    i1 = encoded_residue_dict[residue_name]
                    for amino_acid in v1:
                        if amino_acid in alphabet:
                            j1 = restype_str_to_int[amino_acid]
                            omit_AA_per_residue[i1, j1] = 1.0

        # NOTE: these masks are inverted — 0 where the residue IS in the
        # user's list, 1 elsewhere (see chain_mask combination below)
        fixed_positions = torch.tensor(
            [int(item not in fixed_residues) for item in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residues) for item in encoded_residues],
            device=device,
        )

        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = [item for item in args.transmembrane_buried.split()]
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)

        if args.transmembrane_interface:
            interface_residues = [item for item in args.transmembrane_interface.split()]
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # label encoding: 2 = buried-only, 1 = interface-only, 0 = neither
        # (residues flagged as both end up 0)
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)

        if args.model_type == "global_label_membrane_mpnn":
            # broadcast the single global label to every residue
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        if len(args.chains_to_design) != 0:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            chains_to_design_list = protein_dict["chain_letters"]

        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )

        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        # redesigned_residues takes precedence over fixed_residues
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask

        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)

        # specify which residues are linked
        if args.symmetry_residues:
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = []
            for t_list in symmetry_residues_list_of_lists:
                tmp_list = []
                for t in t_list:
                    tmp_list.append(encoded_residue_dict[t])
                remapped_symmetry_residues.append(tmp_list)
        else:
            remapped_symmetry_residues = [[]]

        # specify linking weights
        if args.symmetry_weights:
            symmetry_weights = [
                [float(item) for item in x.split(",")]
                for x in args.symmetry_weights.split("|")
            ]
        else:
            symmetry_weights = [[]]

        if args.homo_oligomer:
            # auto-build symmetry groups: tie residue i of every chain
            # together with equal weights 1/num_chains
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            chain_letters_set = list(set(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            symmetry_weights = []
            for res in residue_indices:
                tmp_list = []
                tmp_w_list = []
                for chain in chain_letters_set:
                    name = chain + res
                    tmp_list.append(encoded_residue_dict[name])
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)
                symmetry_weights.append(tmp_w_list)

        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)

        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]

        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                if "Y" in list(protein_dict):
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    print("No ligand atoms parsed")
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            B, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["temperature"] = args.temperature
            # -1e8 acts as -inf in the softmax: omitted AAs become unsampleable
            feature_dict["bias"] = (
                (-1e8 * omit_AA[None, None, :] + bias_AA).repeat([1, L, 1])
                + bias_AA_per_residue[None]
                - 1e8 * omit_AA_per_residue[None]
            )
            feature_dict["symmetry_residues"] = remapped_symmetry_residues
            feature_dict["symmetry_weights"] = symmetry_weights

            # --- Sample sequences batch by batch and accumulate outputs ---
            sampling_probs_list = []
            log_probs_list = []
            decoding_order_list = []
            S_list = []
            loss_list = []
            loss_per_residue_list = []
            loss_XY_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                output_dict = model.sample(feature_dict)

                # compute confidence scores
                loss, loss_per_residue = get_score(
                    output_dict["S"],
                    output_dict["log_probs"],
                    feature_dict["mask"] * feature_dict["chain_mask"],
                )
                if args.model_type == "ligand_mpnn":
                    combined_mask = (
                        feature_dict["mask"]
                        * feature_dict["mask_XY"]
                        * feature_dict["chain_mask"]
                    )
                else:
                    combined_mask = feature_dict["mask"] * feature_dict["chain_mask"]
                loss_XY, _ = get_score(
                    output_dict["S"], output_dict["log_probs"], combined_mask
                )
                # -----
                S_list.append(output_dict["S"])
                log_probs_list.append(output_dict["log_probs"])
                sampling_probs_list.append(output_dict["sampling_probs"])
                decoding_order_list.append(output_dict["decoding_order"])
                loss_list.append(loss)
                loss_per_residue_list.append(loss_per_residue)
                loss_XY_list.append(loss_XY)
            S_stack = torch.cat(S_list, 0)
            log_probs_stack = torch.cat(log_probs_list, 0)
            sampling_probs_stack = torch.cat(sampling_probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)
            loss_stack = torch.cat(loss_list, 0)
            loss_per_residue_stack = torch.cat(loss_per_residue_list, 0)
            loss_XY_stack = torch.cat(loss_XY_list, 0)
            rec_mask = feature_dict["mask"][:1] * feature_dict["chain_mask"][:1]
            rec_stack = get_seq_rec(feature_dict["S"][:1], S_stack, rec_mask)

            native_seq = "".join(
                [restype_int_to_str[AA] for AA in feature_dict["S"][0].cpu().numpy()]
            )
            seq_np = np.array(list(native_seq))
            seq_out_str = []
            for mask in protein_dict["mask_c"]:
                seq_out_str += list(seq_np[mask.cpu().numpy()])
                seq_out_str += [args.fasta_seq_separation]
            seq_out_str = "".join(seq_out_str)[:-1]  # drop trailing separator

            # NOTE: base_folder already ends with "/", so these produce "//"
            # in the path — harmless on POSIX filesystems
            output_fasta = base_folder + "/seqs/" + name + args.file_ending + ".fa"
            output_backbones = base_folder + "/backbones/"
            output_packed = base_folder + "/packed/"
            output_stats_path = base_folder + "stats/" + name + args.file_ending + ".pt"

            out_dict = {}
            out_dict["generated_sequences"] = S_stack.cpu()
            out_dict["sampling_probs"] = sampling_probs_stack.cpu()
            out_dict["log_probs"] = log_probs_stack.cpu()
            out_dict["decoding_order"] = decoding_order_stack.cpu()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu()
            out_dict["mask"] = feature_dict["mask"][0].cpu()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu()
            out_dict["seed"] = seed
            out_dict["temperature"] = args.temperature
            if args.save_stats:
                torch.save(out_dict, output_stats_path)

            if args.pack_side_chains:
                if args.verbose:
                    print("Packing side chains...")
                # re-featurize with packer-specific settings (always
                # ligand_mpnn features, fixed 16-atom context)
                feature_dict_ = featurize(
                    protein_dict,
                    cutoff_for_score=8.0,
                    use_atom_context=args.pack_with_ligand_context,
                    number_of_ligand_atoms=16,
                    model_type="ligand_mpnn",
                )
                sc_feature_dict = copy.deepcopy(feature_dict_)
                B = args.batch_size
                # tile every tensor feature along the batch dim; "S" is
                # replaced per-batch below; non-tensor values are skipped
                # via the except
                for k, v in sc_feature_dict.items():
                    if k != "S":
                        try:
                            num_dim = len(v.shape)
                            if num_dim == 2:
                                sc_feature_dict[k] = v.repeat(B, 1)
                            elif num_dim == 3:
                                sc_feature_dict[k] = v.repeat(B, 1, 1)
                            elif num_dim == 4:
                                sc_feature_dict[k] = v.repeat(B, 1, 1, 1)
                            elif num_dim == 5:
                                sc_feature_dict[k] = v.repeat(B, 1, 1, 1, 1)
                        except:
                            pass
                X_stack_list = []
                X_m_stack_list = []
                b_factor_stack_list = []
                for _ in range(args.number_of_packs_per_design):
                    X_list = []
                    X_m_list = []
                    b_factor_list = []
                    for c in range(args.number_of_batches):
                        sc_feature_dict["S"] = S_list[c]
                        sc_dict = pack_side_chains(
                            sc_feature_dict,
                            model_sc,
                            args.sc_num_denoising_steps,
                            args.sc_num_samples,
                            args.repack_everything,
                        )
                        X_list.append(sc_dict["X"])
                        X_m_list.append(sc_dict["X_m"])
                        b_factor_list.append(sc_dict["b_factors"])

                    X_stack = torch.cat(X_list, 0)
                    X_m_stack = torch.cat(X_m_list, 0)
                    b_factor_stack = torch.cat(b_factor_list, 0)

                    X_stack_list.append(X_stack)
                    X_m_stack_list.append(X_m_stack)
                    b_factor_stack_list.append(b_factor_stack)

            with open(output_fasta, "w") as f:
                # header line; combined_mask holds the value from the last
                # sampling batch (loop-carried variable)
                f.write(
                    ">{}, T={}, seed={}, num_res={}, num_ligand_res={}, use_ligand_context={}, ligand_cutoff_distance={}, batch_size={}, number_of_batches={}, model_path={}\n{}\n".format(
                        name,
                        args.temperature,
                        seed,
                        torch.sum(rec_mask).cpu().numpy(),
                        torch.sum(combined_mask[:1]).cpu().numpy(),
                        bool(args.ligand_mpnn_use_atom_context),
                        float(args.ligand_mpnn_cutoff_for_score),
                        args.batch_size,
                        args.number_of_batches,
                        checkpoint_path,
                        seq_out_str,
                    )
                )
                for ix in range(S_stack.shape[0]):
                    ix_suffix = ix
                    if not args.zero_indexed:
                        ix_suffix += 1
                    seq_rec_print = np.format_float_positional(
                        rec_stack[ix].cpu().numpy(), unique=False, precision=4
                    )
                    # exp(-loss) converts mean NLL into a 0..1 confidence
                    loss_np = np.format_float_positional(
                        np.exp(-loss_stack[ix].cpu().numpy()), unique=False, precision=4
                    )
                    loss_XY_np = np.format_float_positional(
                        np.exp(-loss_XY_stack[ix].cpu().numpy()),
                        unique=False,
                        precision=4,
                    )
                    seq = "".join(
                        [restype_int_to_str[AA] for AA in S_stack[ix].cpu().numpy()]
                    )

                    # write new sequences into PDB with backbone coordinates
                    # repeat(4, 1): one entry per backbone atom (N, CA, C, O
                    # presumably — TODO confirm against parse_PDB)
                    seq_prody = np.array([restype_1to3[AA] for AA in list(seq)])[
                        None,
                    ].repeat(4, 1)
                    bfactor_prody = (
                        loss_per_residue_stack[ix].cpu().numpy()[None, :].repeat(4, 1)
                    )
                    backbone.setResnames(seq_prody)
                    backbone.setBetas(
                        np.exp(-bfactor_prody)
                        * (bfactor_prody > 0.01).astype(np.float32)
                    )
                    if other_atoms:
                        writePDB(
                            output_backbones
                            + name
                            + "_"
                            + str(ix_suffix)
                            + args.file_ending
                            + ".pdb",
                            backbone + other_atoms,
                        )
                    else:
                        writePDB(
                            output_backbones
                            + name
                            + "_"
                            + str(ix_suffix)
                            + args.file_ending
                            + ".pdb",
                            backbone,
                        )

                    # write full PDB files
                    if args.pack_side_chains:
                        for c_pack in range(args.number_of_packs_per_design):
                            X_stack = X_stack_list[c_pack]
                            X_m_stack = X_m_stack_list[c_pack]
                            b_factor_stack = b_factor_stack_list[c_pack]
                            write_full_PDB(
                                output_packed
                                + name
                                + args.packed_suffix
                                + "_"
                                + str(ix_suffix)
                                + "_"
                                + str(c_pack + 1)
                                + args.file_ending
                                + ".pdb",
                                X_stack[ix].cpu().numpy(),
                                X_m_stack[ix].cpu().numpy(),
                                b_factor_stack[ix].cpu().numpy(),
                                feature_dict["R_idx_original"][0].cpu().numpy(),
                                protein_dict["chain_letters"],
                                S_stack[ix].cpu().numpy(),
                                other_atoms=other_atoms,
                                icodes=icodes,
                                force_hetatm=args.force_hetatm,
                            )
                    # -----

                    # write fasta lines
                    seq_np = np.array(list(seq))
                    seq_out_str = []
                    for mask in protein_dict["mask_c"]:
                        seq_out_str += list(seq_np[mask.cpu().numpy()])
                        seq_out_str += [args.fasta_seq_separation]
                    seq_out_str = "".join(seq_out_str)[:-1]
                    if ix == S_stack.shape[0] - 1:
                        # final 2 lines
                        f.write(
                            ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}".format(
                                name,
                                ix_suffix,
                                args.temperature,
                                seed,
                                loss_np,
                                loss_XY_np,
                                seq_rec_print,
                                seq_out_str,
                            )
                        )
                    else:
                        f.write(
                            ">{}, id={}, T={}, seed={}, overall_confidence={}, ligand_confidence={}, seq_rec={}\n{}\n".format(
                                name,
                                ix_suffix,
                                args.temperature,
                                seed,
                                loss_np,
                                loss_XY_np,
                                seq_rec_print,
                                seq_out_str,
                            )
                        )
677
if __name__ == "__main__":
    # Command-line interface for MPNN inference; parsed namespace is handed
    # straight to main().
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB exluding non-protein atoms
    # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
    # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )

    argparser.add_argument(
        "--fasta_seq_separation",
        type=str,
        default=":",
        help="Symbol to use between sequences from different chains",
    )
    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")

    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )

    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--bias_AA",
        type=str,
        default="",
        help="Bias generation of amino acids, e.g. 'A:-1.024,P:2.34,C:-12.34'",
    )
    argparser.add_argument(
        "--bias_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of bias {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}",
    )
    argparser.add_argument(
        "--bias_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of bias {'pdb_path': {'A12': {'G': -0.3, 'C': -2.0, 'H': 0.8}, 'A13': {'G': -1.3}}}",
    )

    argparser.add_argument(
        "--omit_AA",
        type=str,
        default="",
        help="Bias generation of amino acids, e.g. 'ACG'",
    )
    argparser.add_argument(
        "--omit_AA_per_residue",
        type=str,
        default="",
        help="Path to json mapping of bias {'A12': 'APQ', 'A13': 'QST'}",
    )
    argparser.add_argument(
        "--omit_AA_per_residue_multi",
        type=str,
        default="",
        help="Path to json mapping of bias {'pdb_path': {'A12': 'QSPC', 'A13': 'AGE'}}",
    )

    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )
    argparser.add_argument(
        "--symmetry_weights",
        type=str,
        default="",
        help="Add weights that match symmetry_residues, e.g. '1.01,1.0,1.0|-1.0,2.0|2.0,2.3'",
    )
    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
    )

    argparser.add_argument(
        "--out_folder",
        type=str,
        help="Path to a folder to output sequences, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    # BUGFIX: was type=str, which made any explicit CLI value truthy —
    # "--zero_indexed 0" parsed to the string "0" and behaved like 1 in
    # main()'s "if not args.zero_indexed" check. type=int restores the
    # intended 0/1 semantics while keeping the same accepted inputs.
    argparser.add_argument(
        "--zero_indexed",
        type=int,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set seed for torch, numpy, and python random.",
    )
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )
    argparser.add_argument(
        "--temperature",
        type=float,
        default=0.1,
        help="Temperature to sample sequences.",
    )
    argparser.add_argument(
        "--save_stats", type=int, default=0, help="Save output statistics"
    )

    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )
    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )
    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )
    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default="",
        help="Specify which chains to redesign, all others will be kept fixed, 'A,B,C,F'",
    )

    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'A,B,C,F'",
    )

    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )

    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )

    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )

    argparser.add_argument(
        "--pack_side_chains",
        type=int,
        default=0,
        help="1 - to run side chain packer, 0 - do not run it",
    )

    argparser.add_argument(
        "--checkpoint_path_sc",
        type=str,
        default="./model_params/ligandmpnn_sc_v_32_002_16.pt",
        help="Path to model weights.",
    )

    argparser.add_argument(
        "--number_of_packs_per_design",
        type=int,
        default=4,
        help="Number of independent side chain packing samples to return per design",
    )

    argparser.add_argument(
        "--sc_num_denoising_steps",
        type=int,
        default=3,
        help="Number of denoising/recycling steps to make for side chain packing",
    )

    argparser.add_argument(
        "--sc_num_samples",
        type=int,
        default=16,
        help="Number of samples to draw from a mixture distribution and then take a sample with the highest likelihood.",
    )

    argparser.add_argument(
        "--repack_everything",
        type=int,
        default=0,
        help="1 - repacks side chains of all residues including the fixed ones; 0 - keeps the side chains fixed for fixed residues",
    )

    argparser.add_argument(
        "--force_hetatm",
        type=int,
        default=0,
        help="To force ligand atoms to be written as HETATM to PDB file after packing.",
    )

    argparser.add_argument(
        "--packed_suffix",
        type=str,
        default="_packed",
        help="Suffix for packed PDB paths",
    )

    argparser.add_argument(
        "--pack_with_ligand_context",
        type=int,
        default=1,
        help="1-pack side chains using ligand context, 0 - do not use it.",
    )

    args = argparser.parse_args()
    main(args)
run_examples.sh ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Example invocations of run.py covering the main command line options.

#1 default: ProteinMPNN with a fixed seed
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/default"
#2 lower sampling temperature -> less sequence diversity
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --temperature 0.05 \
        --out_folder "./outputs/temperature"

#3 omit --seed to get a random seed
python run.py \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/random_seed"

#4 suppress progress printing
python run.py \
        --seed 111 \
        --verbose 0 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/verbose"

#5 also save per-sequence statistics
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/save_stats" \
        --save_stats 1

#6 keep the listed residues fixed; bias sampling towards alanine
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/fix_residues" \
        --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
        --bias_AA "A:10.0"

#7 redesign only the listed residues; bias towards alanine
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/redesign_residues" \
        --redesigned_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10" \
        --bias_AA "A:10.0"

#8 sample 3 sequences per batch, 5 batches
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/batch_size" \
        --batch_size 3 \
        --number_of_batches 5

#9 global amino-acid composition bias
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --bias_AA "W:3.0,P:3.0,C:3.0,A:-3.0" \
        --out_folder "./outputs/global_bias"

#10 per-residue bias loaded from JSON
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --bias_AA_per_residue "./inputs/bias_AA_per_residue.json" \
        --out_folder "./outputs/per_residue_bias"

#11 globally forbid the listed amino acids
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --omit_AA "CDFGHILMNPQRSTVWY" \
        --out_folder "./outputs/global_omit"

#12 per-residue omit list loaded from JSON
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --omit_AA_per_residue "./inputs/omit_AA_per_residue.json" \
        --out_folder "./outputs/per_residue_omit"

#13 tie residue groups together (symmetry) with averaging weights
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/symmetry" \
        --symmetry_residues "C1,C2,C3|C4,C5|C6,C7" \
        --symmetry_weights "0.33,0.33,0.33|0.5,0.5|0.5,0.5"

#14 design a homo-oligomer with tied chains
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/homooligomer" \
        --homo_oligomer 1 \
        --number_of_batches 2

#15 custom suffix for output file names
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/file_ending" \
        --file_ending "_xyz"

#16 zero-indexed output naming
python run.py \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/zero_indexed" \
        --zero_indexed 1 \
        --number_of_batches 2

#17 design only chains A and B
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/chains_to_design" \
        --chains_to_design "A,B"

#18 parse only chains A and B from the input PDB
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/4GYT.pdb" \
        --out_folder "./outputs/parse_these_chains_only" \
        --parse_these_chains_only "A,B"

#19 LigandMPNN with default weights
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_default"

#20 LigandMPNN with an explicitly chosen checkpoint
python run.py \
        --checkpoint_ligand_mpnn "./model_params/ligandmpnn_v_32_005_25.pt" \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_v_32_005_25"

#21 LigandMPNN without ligand atom context
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_no_context" \
        --ligand_mpnn_use_atom_context 0

#22 use side-chain atoms of the fixed residues as extra context
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/ligandmpnn_use_side_chain_atoms" \
        --ligand_mpnn_use_side_chain_context 1 \
        --fixed_residues "C1 C2 C3 C4 C5 C6 C7 C8 C9 C10"

#23 SolubleMPNN model
python run.py \
        --model_type "soluble_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/soluble_mpnn_default"

#24 membrane model with a single global transmembrane label
python run.py \
        --model_type "global_label_membrane_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/global_label_membrane_mpnn_0" \
        --global_transmembrane_label 0

#25 membrane model with per-residue buried/interface labels
python run.py \
        --model_type "per_residue_label_membrane_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/per_residue_label_membrane_mpnn_default" \
        --transmembrane_buried "C1 C2 C3 C11" \
        --transmembrane_interface "C4 C5 C6 C22"

#26 custom separator between chains in the FASTA output
python run.py \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/fasta_seq_separation" \
        --fasta_seq_separation ":"

#27 run on multiple PDBs listed in a JSON file
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --out_folder "./outputs/pdb_path_multi" \
        --seed 111

#28 per-PDB fixed residues from JSON
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --fixed_residues_multi "./inputs/fix_residues_multi.json" \
        --out_folder "./outputs/fixed_residues_multi" \
        --seed 111

#29 per-PDB redesigned residues from JSON
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --redesigned_residues_multi "./inputs/redesigned_residues_multi.json" \
        --out_folder "./outputs/redesigned_residues_multi" \
        --seed 111

#30 per-PDB per-residue omit lists from JSON
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --omit_AA_per_residue_multi "./inputs/omit_AA_per_residue_multi.json" \
        --out_folder "./outputs/omit_AA_per_residue_multi" \
        --seed 111

#31 per-PDB per-residue biases from JSON
python run.py \
        --pdb_path_multi "./inputs/pdb_ids.json" \
        --bias_AA_per_residue_multi "./inputs/bias_AA_per_residue_multi.json" \
        --out_folder "./outputs/bias_AA_per_residue_multi" \
        --seed 111

#32 distance cutoff for ligand-proximal score reporting
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --ligand_mpnn_cutoff_for_score "6.0" \
        --out_folder "./outputs/ligand_mpnn_cutoff_for_score"

#33 residues with insertion codes (e.g. B82A) are supported
python run.py \
        --seed 111 \
        --pdb_path "./inputs/2GFB.pdb" \
        --out_folder "./outputs/insertion_code" \
        --redesigned_residues "B82 B82A B82B B82C" \
        --parse_these_chains_only "B"
sc_examples.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Example invocations of run.py demonstrating the side chain packing options.

#1 design a new sequence and pack side chains (return 1 side chain packing sample - fast)
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/sc_default_fast" \
        --pack_side_chains 1 \
        --number_of_packs_per_design 0 \
        --pack_with_ligand_context 1

#2 design a new sequence and pack side chains (return 4 side chain packing samples)
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/sc_default" \
        --pack_side_chains 1 \
        --number_of_packs_per_design 4 \
        --pack_with_ligand_context 1


#3 fix specific residues for design and packing
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/sc_fixed_residues" \
        --pack_side_chains 1 \
        --number_of_packs_per_design 4 \
        --pack_with_ligand_context 1 \
        --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
        --repack_everything 0

#4 fix specific residues for sequence design but repack everything
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/sc_fixed_residues_full_repack" \
        --pack_side_chains 1 \
        --number_of_packs_per_design 4 \
        --pack_with_ligand_context 1 \
        --fixed_residues "C6 C7 C8 C9 C10 C11 C12 C13 C14 C15" \
        --repack_everything 1


#5 design a new sequence using LigandMPNN but pack side chains without considering ligand/DNA etc atoms
python run.py \
        --model_type "ligand_mpnn" \
        --seed 111 \
        --pdb_path "./inputs/1BC8.pdb" \
        --out_folder "./outputs/sc_no_context" \
        --pack_side_chains 1 \
        --number_of_packs_per_design 4 \
        --pack_with_ligand_context 0
sc_utils.py ADDED
@@ -0,0 +1,1158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.distributions as D
6
+ import torch.nn as nn
7
+ from model_utils import (
8
+ DecLayer,
9
+ DecLayerJ,
10
+ EncLayer,
11
+ PositionalEncodings,
12
+ cat_neighbors_nodes,
13
+ gather_edges,
14
+ gather_nodes,
15
+ )
16
+
17
+ from openfold.data.data_transforms import atom37_to_torsion_angles, make_atom14_masks
18
+ from openfold.np.residue_constants import (
19
+ restype_atom14_mask,
20
+ restype_atom14_rigid_group_positions,
21
+ restype_atom14_to_rigid_group,
22
+ restype_rigid_group_default_frame,
23
+ )
24
+ from openfold.utils import feats
25
+ from openfold.utils.rigid_utils import Rigid
26
+
27
# Pi as a CPU tensor; used below to scale uniform random torsion angles.
torch_pi = torch.tensor(np.pi, device="cpu")


# 21x21 one-hot permutation matrix.  Multiplying a one-hot encoded
# ProteinMPNN sequence by this matrix re-orders the residue-type axis into
# the AlphaFold2/openfold ordering; taking argmax afterwards yields AF2
# "aatype" indices (see make_torsion_features below).
# NOTE(review): the exact alphabet orderings are assumed from usage here —
# confirm against openfold.np.residue_constants.restypes.
map_mpnn_to_af2_seq = torch.tensor(
    [
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    ],
    device="cpu",
)
56
+
57
+
58
def pack_side_chains(
    feature_dict,
    model_sc,
    num_denoising_steps,
    num_samples=10,
    repack_everything=True,
    num_context_atoms=16,
):
    """Sample side-chain conformations with the Packer model via iterative denoising.

    Parameters
    ----------
    feature_dict : dict
        Tensor features; must contain at least "X" (backbone coordinates),
        "S" (ProteinMPNN sequence indices), "mask" and "chain_mask".
        Ligand context tensors "Y"/"Y_t"/"Y_m" are created as zeros if absent.
    model_sc : Packer
        Side-chain packing model providing ``encode``/``decode``.
    num_denoising_steps : int
        Number of denoising/recycling iterations; must be >= 1.
    num_samples : int
        Mixture samples drawn per step; the highest-likelihood one is kept.
    repack_everything : bool
        If False, torsions of residues fixed via "chain_mask" keep their
        input values instead of being re-sampled.
    num_context_atoms : int
        Context-atom count used for the empty "Y"/"Y_t"/"Y_m" placeholders.

    Returns
    -------
    dict
        ``feature_dict`` updated in place: packed atom14 coordinates in "X",
        per-atom "b_factors" (log-likelihood based uncertainty), and the raw
        mixture parameters / samples of the final step.
    """
    # Without at least one step, no sample/distribution exists for the
    # post-loop statistics below (would raise UnboundLocalError).
    if num_denoising_steps < 1:
        raise ValueError("num_denoising_steps must be >= 1")
    device = feature_dict["X"].device
    torsion_dict = make_torsion_features(feature_dict, repack_everything)
    feature_dict["X"] = torsion_dict["xyz14_noised"]
    feature_dict["X_m"] = torsion_dict["xyz14_m"]
    # No ligand context supplied: provide all-zero, fully masked context.
    if "Y" not in list(feature_dict):
        feature_dict["Y"] = torch.zeros(
            [
                feature_dict["X"].shape[0],
                feature_dict["X"].shape[1],
                num_context_atoms,
                3,
            ],
            device=device,
        )
        feature_dict["Y_t"] = torch.zeros(
            [feature_dict["X"].shape[0], feature_dict["X"].shape[1], num_context_atoms],
            device=device,
        )
        feature_dict["Y_m"] = torch.zeros(
            [feature_dict["X"].shape[0], feature_dict["X"].shape[1], num_context_atoms],
            device=device,
        )
    # Encode once; decoding/recycling reuses the cached node/edge embeddings.
    h_V, h_E, E_idx = model_sc.encode(feature_dict)
    feature_dict["h_V"] = h_V
    feature_dict["h_E"] = h_E
    feature_dict["E_idx"] = E_idx
    for step in range(num_denoising_steps):
        mean, concentration, mix_logits = model_sc.decode(feature_dict)
        # Mixture of von Mises distributions over the four chi angles.
        mix = D.Categorical(logits=mix_logits)
        comp = D.VonMises(mean, concentration)
        pred_dist = D.MixtureSameFamily(mix, comp)
        predicted_samples = pred_dist.sample([num_samples])
        log_probs_of_samples = pred_dist.log_prob(predicted_samples)
        # Keep, per angle, the most likely of the drawn samples.
        sample = torch.gather(
            predicted_samples, dim=0, index=torch.argmax(log_probs_of_samples, 0)[None,]
        )[0,]
        torsions_pred_unit = torch.cat(
            [torch.sin(sample[:, :, :, None]), torch.cos(sample[:, :, :, None])], -1
        )
        # Residues flagged by mask_fix_sc take predicted torsions; the rest
        # (fixed residues) keep their true input torsions.
        torsion_dict["torsions_noised"][:, :, 3:] = torsions_pred_unit * torsion_dict[
            "mask_fix_sc"
        ] + torsion_dict["torsions_true"] * (1 - torsion_dict["mask_fix_sc"])
        pred_frames = feats.torsion_angles_to_frames(
            torsion_dict["rigids"],
            torsion_dict["torsions_noised"],
            torsion_dict["aatype"],
            torch.tensor(restype_rigid_group_default_frame, device=device),
        )
        xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos(
            pred_frames,
            torsion_dict["aatype"],
            torch.tensor(restype_rigid_group_default_frame, device=device),
            torch.tensor(restype_atom14_to_rigid_group, device=device),
            torch.tensor(restype_atom14_mask, device=device),
            torch.tensor(restype_atom14_rigid_group_positions, device=device),
        )
        xyz14_noised = xyz14_noised * feature_dict["X_m"][:, :, :, None]
        # Feed the rebuilt coordinates back in for the next recycling step.
        # (Original code redundantly assigned feature_dict["X"] a second time
        # after the loop; that duplicate was removed.)
        feature_dict["X"] = xyz14_noised
    S_af2 = torsion_dict["S_af2"]

    # Fixed residues get a constant pseudo log-probability of 2.0 so their
    # B-factors are not dominated by the model's (unused) prediction.
    log_prob = pred_dist.log_prob(sample) * torsion_dict["mask_fix_sc"][
        ..., 0
    ] + 2.0 * (1 - torsion_dict["mask_fix_sc"][..., 0])

    # Map each atom14 slot to the chi angle (0-3) that places it, to spread
    # per-angle uncertainty onto per-atom pseudo B-factors. Backbone rigid
    # groups (< 4) are clamped to chi1.
    tmp_types = torch.tensor(restype_atom14_to_rigid_group, device=device)[S_af2]
    tmp_types[tmp_types < 4] = 4
    tmp_types -= 4
    atom_types_for_b_factor = torch.nn.functional.one_hot(tmp_types, 4)  # [B, L, 14, 4]

    uncertainty = log_prob[:, :, None, :] * atom_types_for_b_factor  # [B,L,14,4]
    b_factor_pred = uncertainty.sum(-1)  # [B, L, 14]
    feature_dict["b_factors"] = b_factor_pred
    feature_dict["mean"] = mean
    feature_dict["concentration"] = concentration
    feature_dict["mix_logits"] = mix_logits
    feature_dict["log_prob"] = log_prob
    feature_dict["sample"] = sample
    feature_dict["true_torsion_sin_cos"] = torsion_dict["torsions_true"]
    return feature_dict
147
+
148
+
149
def make_torsion_features(feature_dict, repack_everything=True):
    """Build torsion-angle features and randomly initialized atom14 coordinates.

    Converts the ProteinMPNN sequence to AF2 aatype indices, computes backbone
    rigid frames, replaces the four chi angles with random values (except for
    residues kept fixed via "chain_mask" when ``repack_everything`` is False),
    and rebuilds atom14 coordinates from those torsions.

    Returns the openfold torsion dict augmented with: "xyz14_m" (atom14 mask),
    "xyz14_noised" (coordinates with randomized chis), "mask_for_loss",
    "rigids", "torsions_noised", "mask_fix_sc" (1 = repack this residue),
    "torsions_true" and "S_af2".
    """
    device = feature_dict["mask"].device

    mask = feature_dict["mask"]
    B, L = mask.shape

    # Scatter the 4 backbone atoms (N, CA, C, O) into atom37 layout;
    # O lives at atom37 index 4 but at index 3 in the input "X".
    xyz37 = torch.zeros([B, L, 37, 3], device=device, dtype=torch.float32)
    xyz37[:, :, :3] = feature_dict["X"][:, :, :3]
    xyz37[:, :, 4] = feature_dict["X"][:, :, 3]

    # Re-order one-hot MPNN sequence into AF2 aatype indices.
    S_af2 = torch.argmax(
        torch.nn.functional.one_hot(feature_dict["S"], 21).float()
        @ map_mpnn_to_af2_seq.to(device).float(),
        -1,
    )
    masks14_37 = make_atom14_masks({"aatype": S_af2})
    temp_dict = {
        "aatype": S_af2,
        "all_atom_positions": xyz37,
        "all_atom_mask": masks14_37["atom37_atom_exists"],
    }
    torsion_dict = atom37_to_torsion_angles("")(temp_dict)

    # Backbone rigid frames from N/CA/C positions.
    rigids = Rigid.make_transform_from_reference(
        n_xyz=xyz37[:, :, 0, :],
        ca_xyz=xyz37[:, :, 1, :],
        c_xyz=xyz37[:, :, 2, :],
        eps=1e-9,
    )

    if not repack_everything:
        # Keep the true chi torsions of fixed residues (chain_mask == 0).
        xyz37_true = feature_dict["xyz_37"]
        temp_dict_true = {
            "aatype": S_af2,
            "all_atom_positions": xyz37_true,
            "all_atom_mask": masks14_37["atom37_atom_exists"],
        }
        torsion_dict_true = atom37_to_torsion_angles("")(temp_dict_true)
        torsions_true = torch.clone(torsion_dict_true["torsion_angles_sin_cos"])[
            :, :, 3:
        ]
        mask_fix_sc = feature_dict["chain_mask"][:, :, None, None]
    else:
        # Everything gets repacked; true torsions are irrelevant placeholders.
        torsions_true = torch.zeros([B, L, 4, 2], device=device)
        mask_fix_sc = torch.ones([B, L, 1, 1], device=device)

    # Randomize the chi angles (indices 3: of the 7 torsions) for residues
    # that will be repacked.
    random_angle = (
        2 * torch_pi * torch.rand([S_af2.shape[0], S_af2.shape[1], 4], device=device)
    )
    random_sin_cos = torch.cat(
        [torch.sin(random_angle)[..., None], torch.cos(random_angle)[..., None]], -1
    )
    torsions_noised = torch.clone(torsion_dict["torsion_angles_sin_cos"])
    torsions_noised[:, :, 3:] = random_sin_cos * mask_fix_sc + torsions_true * (
        1 - mask_fix_sc
    )
    pred_frames = feats.torsion_angles_to_frames(
        rigids,
        torsions_noised,
        S_af2,
        torch.tensor(restype_rigid_group_default_frame, device=device),
    )

    # Rebuild atom14 coordinates from the (partly randomized) torsions.
    xyz14_noised = feats.frames_and_literature_positions_to_atom14_pos(
        pred_frames,
        S_af2,
        torch.tensor(restype_rigid_group_default_frame, device=device),
        torch.tensor(restype_atom14_to_rigid_group, device=device).long(),
        torch.tensor(restype_atom14_mask, device=device),
        torch.tensor(restype_atom14_rigid_group_positions, device=device),
    )

    xyz14_m = masks14_37["atom14_atom_exists"] * mask[:, :, None]
    xyz14_noised = xyz14_noised * xyz14_m[:, :, :, None]
    torsion_dict["xyz14_m"] = xyz14_m
    torsion_dict["xyz14_noised"] = xyz14_noised
    torsion_dict["mask_for_loss"] = mask
    torsion_dict["rigids"] = rigids
    torsion_dict["torsions_noised"] = torsions_noised
    torsion_dict["mask_fix_sc"] = mask_fix_sc
    torsion_dict["torsions_true"] = torsions_true
    torsion_dict["S_af2"] = S_af2
    return torsion_dict
232
+
233
+
234
class Packer(nn.Module):
    """Graph neural network that predicts side-chain chi angles.

    Encodes the protein (plus optional ligand context) into node/edge
    embeddings, then decodes them into the parameters of a per-angle
    von Mises mixture (mean, concentration, mixture logits) that
    pack_side_chains() samples from.
    """

    def __init__(
        self,
        edge_features=128,
        node_features=128,
        num_positional_embeddings=16,
        num_chain_embeddings=16,
        num_rbf=16,
        top_k=30,
        augment_eps=0.0,
        atom37_order=False,
        device=None,
        atom_context_num=16,
        lower_bound=0.0,
        upper_bound=20.0,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        dropout=0.1,
        num_mix=3,
    ):
        super(Packer, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.num_positional_embeddings = num_positional_embeddings
        self.num_chain_embeddings = num_chain_embeddings
        self.num_rbf = num_rbf
        self.top_k = top_k
        self.augment_eps = augment_eps
        self.atom37_order = atom37_order
        self.device = device
        self.atom_context_num = atom_context_num
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.hidden_dim = hidden_dim
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        # NOTE(review): self.dropout is first set to the float rate here and
        # overwritten with an nn.Dropout module further below.
        self.dropout = dropout
        self.softplus = nn.Softplus(beta=1, threshold=20)

        # Raw geometric/chemical feature extractor (defined below in this file).
        self.features = ProteinFeatures(
            edge_features=edge_features,
            node_features=node_features,
            num_positional_embeddings=num_positional_embeddings,
            num_chain_embeddings=num_chain_embeddings,
            num_rbf=num_rbf,
            top_k=top_k,
            augment_eps=augment_eps,
            atom37_order=atom37_order,
            device=device,
            atom_context_num=atom_context_num,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_v = nn.Linear(node_features, hidden_dim, bias=True)
        self.W_f = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_v_sc = nn.Linear(node_features, hidden_dim, bias=True)
        self.linear_down = nn.Linear(2 * hidden_dim, hidden_dim, bias=True)
        # Output head: 4 chi angles x num_mix components x (mean, conc, logit).
        self.W_torsions = nn.Linear(hidden_dim, 4 * 3 * num_mix, bias=True)
        self.num_mix = num_mix

        self.dropout = nn.Dropout(dropout)

        # Encoder layers
        self.encoder_layers = nn.ModuleList(
            [
                EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
                for _ in range(num_encoder_layers)
            ]
        )

        self.W_c = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_e_context = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.W_nodes_y = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_edges_y = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # Layers mixing protein nodes with ligand-context ("Y") nodes.
        self.context_encoder_layers = nn.ModuleList(
            [DecLayer(hidden_dim, hidden_dim * 2, dropout=dropout) for _ in range(2)]
        )

        self.V_C = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.V_C_norm = nn.LayerNorm(hidden_dim)
        self.y_context_encoder_layers = nn.ModuleList(
            [DecLayerJ(hidden_dim, hidden_dim, dropout=dropout) for _ in range(2)]
        )

        self.h_V_C_dropout = nn.Dropout(dropout)

        # Decoder layers
        self.decoder_layers = nn.ModuleList(
            [
                DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        # Xavier initialization for all weight matrices.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, feature_dict):
        """Encode backbone + ligand context into node/edge embeddings.

        Returns (h_V, h_E, E_idx): node embeddings, edge embeddings and
        k-nearest-neighbor indices, cached by the caller for decoding.
        """
        mask = feature_dict["mask"]
        V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m = self.features.features_encode(
            feature_dict
        )

        h_E_context = self.W_e_context(E_context)
        h_V = self.W_v(V)
        h_E = self.W_e(E)
        # Pairwise validity mask over the kNN graph.
        mask_attend = gather_nodes(mask.unsqueeze(-1), E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Mix in ligand-context atoms: Y-node message passing followed by
        # protein-node updates conditioned on the context.
        h_V_C = self.W_c(h_V)
        Y_m_edges = Y_m[:, :, :, None] * Y_m[:, :, None, :]
        Y_nodes = self.W_nodes_y(Y_nodes)
        Y_edges = self.W_edges_y(Y_edges)
        for i in range(len(self.context_encoder_layers)):
            Y_nodes = self.y_context_encoder_layers[i](Y_nodes, Y_edges, Y_m, Y_m_edges)
            h_E_context_cat = torch.cat([h_E_context, Y_nodes], -1)
            h_V_C = self.context_encoder_layers[i](h_V_C, h_E_context_cat, mask, Y_m)

        # Residual merge of the context-aware node states.
        h_V_C = self.V_C(h_V_C)
        h_V = h_V + self.V_C_norm(self.h_V_C_dropout(h_V_C))

        return h_V, h_E, E_idx

    def decode(self, feature_dict):
        """Decode cached embeddings into von Mises mixture parameters.

        Returns (mean, concentration, mix_logits), each shaped
        [B, L, 4, num_mix] — one mixture per chi angle.
        """
        h_V = feature_dict["h_V"]
        h_E = feature_dict["h_E"]
        E_idx = feature_dict["E_idx"]
        mask = feature_dict["mask"]
        device = h_V.device
        # Side-chain-dependent features from the current coordinates in "X".
        V, F = self.features.features_decode(feature_dict)

        h_F = self.W_f(F)
        h_EF = torch.cat([h_E, h_F], -1)

        h_V_sc = self.W_v_sc(V)
        h_V_combined = torch.cat([h_V, h_V_sc], -1)
        h_V = self.linear_down(h_V_combined)

        for layer in self.decoder_layers:
            h_EV = cat_neighbors_nodes(h_V, h_EF, E_idx)
            h_V = layer(h_V, h_EV, mask)

        torsions = self.W_torsions(h_V)
        torsions = torsions.reshape(h_V.shape[0], h_V.shape[1], 4, self.num_mix, 3)
        mean = torsions[:, :, :, :, 0].float()
        # Softplus keeps concentration positive; +0.1 avoids degenerate values.
        concentration = 0.1 + self.softplus(torsions[:, :, :, :, 1]).float()
        mix_logits = torsions[:, :, :, :, 2].float()
        return mean, concentration, mix_logits
391
+
392
+
393
+ class ProteinFeatures(nn.Module):
394
+ def __init__(
395
+ self,
396
+ edge_features=128,
397
+ node_features=128,
398
+ num_positional_embeddings=16,
399
+ num_chain_embeddings=16,
400
+ num_rbf=16,
401
+ top_k=30,
402
+ augment_eps=0.0,
403
+ atom37_order=False,
404
+ device=None,
405
+ atom_context_num=16,
406
+ lower_bound=0.0,
407
+ upper_bound=20.0,
408
+ ):
409
+ """Extract protein features"""
410
+ super(ProteinFeatures, self).__init__()
411
+ self.edge_features = edge_features
412
+ self.node_features = node_features
413
+ self.num_positional_embeddings = num_positional_embeddings
414
+ self.num_chain_embeddings = num_chain_embeddings
415
+ self.num_rbf = num_rbf
416
+ self.top_k = top_k
417
+ self.augment_eps = augment_eps
418
+ self.atom37_order = atom37_order
419
+ self.device = device
420
+ self.atom_context_num = atom_context_num
421
+ self.lower_bound = lower_bound
422
+ self.upper_bound = upper_bound
423
+
424
+ # deal with oxygen index
425
+ # ------
426
+ self.N_idx = 0
427
+ self.CA_idx = 1
428
+ self.C_idx = 2
429
+
430
+ if atom37_order:
431
+ self.O_idx = 4
432
+ else:
433
+ self.O_idx = 3
434
+ # -------
435
+ self.positional_embeddings = PositionalEncodings(num_positional_embeddings)
436
+
437
+ # Features for the encoder
438
+ enc_node_in = 21 # alphabet for the sequence
439
+ enc_edge_in = (
440
+ num_positional_embeddings + num_rbf * 25
441
+ ) # positional + distance features
442
+
443
+ self.enc_node_in = enc_node_in
444
+ self.enc_edge_in = enc_edge_in
445
+
446
+ self.enc_edge_embedding = nn.Linear(enc_edge_in, edge_features, bias=False)
447
+ self.enc_norm_edges = nn.LayerNorm(edge_features)
448
+ self.enc_node_embedding = nn.Linear(enc_node_in, node_features, bias=False)
449
+ self.enc_norm_nodes = nn.LayerNorm(node_features)
450
+
451
+ # Features for the decoder
452
+ dec_node_in = 14 * atom_context_num * num_rbf
453
+ dec_edge_in = num_rbf * 14 * 14 + 42
454
+
455
+ self.dec_node_in = dec_node_in
456
+ self.dec_edge_in = dec_edge_in
457
+
458
+ self.W_XY_project_down1 = nn.Linear(num_rbf + 120, num_rbf, bias=True)
459
+ self.dec_edge_embedding1 = nn.Linear(dec_edge_in, edge_features, bias=False)
460
+ self.dec_norm_edges1 = nn.LayerNorm(edge_features)
461
+ self.dec_node_embedding1 = nn.Linear(dec_node_in, node_features, bias=False)
462
+ self.dec_norm_nodes1 = nn.LayerNorm(node_features)
463
+
464
+ self.node_project_down = nn.Linear(
465
+ 5 * num_rbf + 64 + 4, node_features, bias=True
466
+ )
467
+ self.norm_nodes = nn.LayerNorm(node_features)
468
+
469
+ self.type_linear = nn.Linear(147, 64)
470
+
471
+ self.y_nodes = nn.Linear(147, node_features, bias=False)
472
+ self.y_edges = nn.Linear(num_rbf, node_features, bias=False)
473
+
474
+ self.norm_y_edges = nn.LayerNorm(node_features)
475
+ self.norm_y_nodes = nn.LayerNorm(node_features)
476
+
477
+ self.periodic_table_features = torch.tensor(
478
+ [
479
+ [
480
+ 0,
481
+ 1,
482
+ 2,
483
+ 3,
484
+ 4,
485
+ 5,
486
+ 6,
487
+ 7,
488
+ 8,
489
+ 9,
490
+ 10,
491
+ 11,
492
+ 12,
493
+ 13,
494
+ 14,
495
+ 15,
496
+ 16,
497
+ 17,
498
+ 18,
499
+ 19,
500
+ 20,
501
+ 21,
502
+ 22,
503
+ 23,
504
+ 24,
505
+ 25,
506
+ 26,
507
+ 27,
508
+ 28,
509
+ 29,
510
+ 30,
511
+ 31,
512
+ 32,
513
+ 33,
514
+ 34,
515
+ 35,
516
+ 36,
517
+ 37,
518
+ 38,
519
+ 39,
520
+ 40,
521
+ 41,
522
+ 42,
523
+ 43,
524
+ 44,
525
+ 45,
526
+ 46,
527
+ 47,
528
+ 48,
529
+ 49,
530
+ 50,
531
+ 51,
532
+ 52,
533
+ 53,
534
+ 54,
535
+ 55,
536
+ 56,
537
+ 57,
538
+ 58,
539
+ 59,
540
+ 60,
541
+ 61,
542
+ 62,
543
+ 63,
544
+ 64,
545
+ 65,
546
+ 66,
547
+ 67,
548
+ 68,
549
+ 69,
550
+ 70,
551
+ 71,
552
+ 72,
553
+ 73,
554
+ 74,
555
+ 75,
556
+ 76,
557
+ 77,
558
+ 78,
559
+ 79,
560
+ 80,
561
+ 81,
562
+ 82,
563
+ 83,
564
+ 84,
565
+ 85,
566
+ 86,
567
+ 87,
568
+ 88,
569
+ 89,
570
+ 90,
571
+ 91,
572
+ 92,
573
+ 93,
574
+ 94,
575
+ 95,
576
+ 96,
577
+ 97,
578
+ 98,
579
+ 99,
580
+ 100,
581
+ 101,
582
+ 102,
583
+ 103,
584
+ 104,
585
+ 105,
586
+ 106,
587
+ 107,
588
+ 108,
589
+ 109,
590
+ 110,
591
+ 111,
592
+ 112,
593
+ 113,
594
+ 114,
595
+ 115,
596
+ 116,
597
+ 117,
598
+ 118,
599
+ ],
600
+ [
601
+ 0,
602
+ 1,
603
+ 18,
604
+ 1,
605
+ 2,
606
+ 13,
607
+ 14,
608
+ 15,
609
+ 16,
610
+ 17,
611
+ 18,
612
+ 1,
613
+ 2,
614
+ 13,
615
+ 14,
616
+ 15,
617
+ 16,
618
+ 17,
619
+ 18,
620
+ 1,
621
+ 2,
622
+ 3,
623
+ 4,
624
+ 5,
625
+ 6,
626
+ 7,
627
+ 8,
628
+ 9,
629
+ 10,
630
+ 11,
631
+ 12,
632
+ 13,
633
+ 14,
634
+ 15,
635
+ 16,
636
+ 17,
637
+ 18,
638
+ 1,
639
+ 2,
640
+ 3,
641
+ 4,
642
+ 5,
643
+ 6,
644
+ 7,
645
+ 8,
646
+ 9,
647
+ 10,
648
+ 11,
649
+ 12,
650
+ 13,
651
+ 14,
652
+ 15,
653
+ 16,
654
+ 17,
655
+ 18,
656
+ 1,
657
+ 2,
658
+ 3,
659
+ 3,
660
+ 3,
661
+ 3,
662
+ 3,
663
+ 3,
664
+ 3,
665
+ 3,
666
+ 3,
667
+ 3,
668
+ 3,
669
+ 3,
670
+ 3,
671
+ 3,
672
+ 3,
673
+ 4,
674
+ 5,
675
+ 6,
676
+ 7,
677
+ 8,
678
+ 9,
679
+ 10,
680
+ 11,
681
+ 12,
682
+ 13,
683
+ 14,
684
+ 15,
685
+ 16,
686
+ 17,
687
+ 18,
688
+ 1,
689
+ 2,
690
+ 3,
691
+ 3,
692
+ 3,
693
+ 3,
694
+ 3,
695
+ 3,
696
+ 3,
697
+ 3,
698
+ 3,
699
+ 3,
700
+ 3,
701
+ 3,
702
+ 3,
703
+ 3,
704
+ 3,
705
+ 4,
706
+ 5,
707
+ 6,
708
+ 7,
709
+ 8,
710
+ 9,
711
+ 10,
712
+ 11,
713
+ 12,
714
+ 13,
715
+ 14,
716
+ 15,
717
+ 16,
718
+ 17,
719
+ 18,
720
+ ],
721
+ [
722
+ 0,
723
+ 1,
724
+ 1,
725
+ 2,
726
+ 2,
727
+ 2,
728
+ 2,
729
+ 2,
730
+ 2,
731
+ 2,
732
+ 2,
733
+ 3,
734
+ 3,
735
+ 3,
736
+ 3,
737
+ 3,
738
+ 3,
739
+ 3,
740
+ 3,
741
+ 4,
742
+ 4,
743
+ 4,
744
+ 4,
745
+ 4,
746
+ 4,
747
+ 4,
748
+ 4,
749
+ 4,
750
+ 4,
751
+ 4,
752
+ 4,
753
+ 4,
754
+ 4,
755
+ 4,
756
+ 4,
757
+ 4,
758
+ 4,
759
+ 5,
760
+ 5,
761
+ 5,
762
+ 5,
763
+ 5,
764
+ 5,
765
+ 5,
766
+ 5,
767
+ 5,
768
+ 5,
769
+ 5,
770
+ 5,
771
+ 5,
772
+ 5,
773
+ 5,
774
+ 5,
775
+ 5,
776
+ 5,
777
+ 6,
778
+ 6,
779
+ 6,
780
+ 6,
781
+ 6,
782
+ 6,
783
+ 6,
784
+ 6,
785
+ 6,
786
+ 6,
787
+ 6,
788
+ 6,
789
+ 6,
790
+ 6,
791
+ 6,
792
+ 6,
793
+ 6,
794
+ 6,
795
+ 6,
796
+ 6,
797
+ 6,
798
+ 6,
799
+ 6,
800
+ 6,
801
+ 6,
802
+ 6,
803
+ 6,
804
+ 6,
805
+ 6,
806
+ 6,
807
+ 6,
808
+ 6,
809
+ 7,
810
+ 7,
811
+ 7,
812
+ 7,
813
+ 7,
814
+ 7,
815
+ 7,
816
+ 7,
817
+ 7,
818
+ 7,
819
+ 7,
820
+ 7,
821
+ 7,
822
+ 7,
823
+ 7,
824
+ 7,
825
+ 7,
826
+ 7,
827
+ 7,
828
+ 7,
829
+ 7,
830
+ 7,
831
+ 7,
832
+ 7,
833
+ 7,
834
+ 7,
835
+ 7,
836
+ 7,
837
+ 7,
838
+ 7,
839
+ 7,
840
+ 7,
841
+ ],
842
+ ],
843
+ dtype=torch.long,
844
+ device=device,
845
+ )
846
+
847
+ def _dist(self, X, mask, eps=1e-6):
848
+ mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)
849
+ dX = torch.unsqueeze(X, 1) - torch.unsqueeze(X, 2)
850
+ D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + eps)
851
+ D_max, _ = torch.max(D, -1, keepdim=True)
852
+ D_adjust = D + (1.0 - mask_2D) * D_max
853
+ sampled_top_k = self.top_k
854
+ D_neighbors, E_idx = torch.topk(
855
+ D_adjust, np.minimum(self.top_k, X.shape[1]), dim=-1, largest=False
856
+ )
857
+ return D_neighbors, E_idx
858
+
859
+ def _make_angle_features(self, A, B, C, Y):
860
+ v1 = A - B
861
+ v2 = C - B
862
+ e1 = torch.nn.functional.normalize(v1, dim=-1)
863
+ e1_v2_dot = torch.einsum("bli, bli -> bl", e1, v2)[..., None]
864
+ u2 = v2 - e1 * e1_v2_dot
865
+ e2 = torch.nn.functional.normalize(u2, dim=-1)
866
+ e3 = torch.cross(e1, e2, dim=-1)
867
+ R_residue = torch.cat(
868
+ (e1[:, :, :, None], e2[:, :, :, None], e3[:, :, :, None]), dim=-1
869
+ )
870
+
871
+ local_vectors = torch.einsum(
872
+ "blqp, blyq -> blyp", R_residue, Y - B[:, :, None, :]
873
+ )
874
+
875
+ rxy = torch.sqrt(local_vectors[..., 0] ** 2 + local_vectors[..., 1] ** 2 + 1e-8)
876
+ f1 = local_vectors[..., 0] / rxy
877
+ f2 = local_vectors[..., 1] / rxy
878
+ rxyz = torch.norm(local_vectors, dim=-1) + 1e-8
879
+ f3 = rxy / rxyz
880
+ f4 = local_vectors[..., 2] / rxyz
881
+
882
+ f = torch.cat([f1[..., None], f2[..., None], f3[..., None], f4[..., None]], -1)
883
+ return f
884
+
885
+ def _rbf(
886
+ self,
887
+ D,
888
+ D_mu_shape=[1, 1, 1, -1],
889
+ lower_bound=0.0,
890
+ upper_bound=20.0,
891
+ num_bins=16,
892
+ ):
893
+ device = D.device
894
+ D_min, D_max, D_count = lower_bound, upper_bound, num_bins
895
+ D_mu = torch.linspace(D_min, D_max, D_count, device=device)
896
+ D_mu = D_mu.view(D_mu_shape)
897
+ D_sigma = (D_max - D_min) / D_count
898
+ D_expand = torch.unsqueeze(D, -1)
899
+ RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
900
+ return RBF
901
+
902
def _get_rbf(
    self,
    A,
    B,
    E_idx,
    D_mu_shape=[1, 1, 1, -1],
    lower_bound=2.0,
    upper_bound=22.0,
    num_bins=16,
):
    """RBF-encode the distances between atom sets A and B along graph edges.

    Args:
        A, B: [B, L, 3] atom coordinates (one atom per residue each).
        E_idx: [B, L, K] neighbor indices selecting which pairs to keep.
        D_mu_shape / lower_bound / upper_bound / num_bins: forwarded to
            ``self._rbf``.

    Returns:
        [B, L, K, num_bins] RBF features of the A->B neighbor distances.
    """
    # All-pairs A->B distances; the epsilon keeps sqrt stable at zero.
    pair_dist = torch.sqrt(
        torch.sum((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6
    )  # [B, L, L]
    # Keep only the K neighbor distances selected by E_idx.
    edge_dist = gather_edges(pair_dist[:, :, :, None], E_idx)[
        :, :, :, 0
    ]  # [B, L, K]
    return self._rbf(
        edge_dist,
        D_mu_shape=D_mu_shape,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
        num_bins=num_bins,
    )
926
+
927
def features_encode(self, features):
    """Build the residue k-NN graph and encode backbone + ligand-context features.

    Args:
        features: dict read for the keys below:
            S: integer residue-type labels (one-hot encoded over self.enc_node_in).
            X: backbone coordinates, atoms selected via self.N_idx/CA_idx/C_idx/O_idx
               along dim 2 — assumes [B, L, n_atoms, 3]; TODO confirm.
            Y, Y_m, Y_t: context-atom coordinates, mask, and element numbers
               (Y_t indexes self.periodic_table_features and is one-hot over 120).
            mask: per-residue validity mask.
            R_idx: residue indices used for relative positional offsets.
            chain_labels: integer chain id per residue.

    Returns:
        Tuple (V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m):
            V: encoded node features; E: encoded edge features for neighbors
            E_idx; Y_nodes / Y_edges: context-atom node and pair encodings;
            E_context: residue-to-context-atom features; Y_m: context mask,
            passed through unchanged.
    """
    S = features["S"]
    X = features["X"]
    Y = features["Y"]
    Y_m = features["Y_m"]
    Y_t = features["Y_t"]
    mask = features["mask"]
    R_idx = features["R_idx"]
    chain_labels = features["chain_labels"]

    # Gaussian coordinate noise for augmentation — training mode only.
    if self.training and self.augment_eps > 0:
        X = X + self.augment_eps * torch.randn_like(X)

    # Slice out the four backbone atoms.
    Ca = X[:, :, self.CA_idx, :]
    N = X[:, :, self.N_idx, :]
    C = X[:, :, self.C_idx, :]
    O = X[:, :, self.O_idx, :]

    # Virtual C-beta placed from the backbone frame with fixed empirical weights.
    b = Ca - N
    c = C - Ca
    a = torch.cross(b, c, dim=-1)
    Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca  # shift from CA

    # k-nearest-neighbor graph over C-alpha positions.
    _, E_idx = self._dist(Ca, mask)

    backbone_coords_list = [N, Ca, C, O, Cb]

    # RBF-encode all 5x5 backbone atom-pair distances along graph edges.
    RBF_all = []
    for atom_1 in backbone_coords_list:
        for atom_2 in backbone_coords_list:
            RBF_all.append(
                self._get_rbf(
                    atom_1,
                    atom_2,
                    E_idx,
                    D_mu_shape=[1, 1, 1, -1],
                    lower_bound=self.lower_bound,
                    upper_bound=self.upper_bound,
                    num_bins=self.num_rbf,
                )
            )
    RBF_all = torch.cat(tuple(RBF_all), dim=-1)

    # Relative sequence separation for each edge.
    offset = R_idx[:, :, None] - R_idx[:, None, :]
    offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

    # 1 where both endpoints are on the same chain, 0 across chains.
    d_chains = (
        (chain_labels[:, :, None] - chain_labels[:, None, :]) == 0
    ).long()  # find self vs non-self interaction
    E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
    E_positional = self.positional_embeddings(offset.long(), E_chains)
    # Edge embedding = positional encoding + backbone RBFs, then normalize.
    E = torch.cat((E_positional, RBF_all), -1)
    E = self.enc_edge_embedding(E)
    E = self.enc_norm_edges(E)

    # Node embedding from one-hot residue identity.
    V = torch.nn.functional.one_hot(S, self.enc_node_in).float()
    V = self.enc_node_embedding(V)
    V = self.enc_norm_nodes(V)

    # Element-type features: raw element id plus periodic-table group/period.
    Y_t = Y_t.long()
    Y_t_g = self.periodic_table_features[1][Y_t]  # group; 19 categories including 0
    Y_t_p = self.periodic_table_features[2][Y_t]  # period; 8 categories including 0

    Y_t_g_1hot_ = torch.nn.functional.one_hot(Y_t_g, 19)  # [B, L, M, 19]
    Y_t_p_1hot_ = torch.nn.functional.one_hot(Y_t_p, 8)  # [B, L, M, 8]
    Y_t_1hot_ = torch.nn.functional.one_hot(Y_t, 120)  # [B, L, M, 120]

    Y_t_1hot_ = torch.cat(
        [Y_t_1hot_, Y_t_g_1hot_, Y_t_p_1hot_], -1
    )  # [B, L, M, 147]
    Y_t_1hot = self.type_linear(Y_t_1hot_.float())

    # RBF distances from each backbone atom (and virtual Cb) to context atoms.
    D_N_Y = torch.sqrt(
        torch.sum((N[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_N_Y = self._rbf(
        D_N_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_Ca_Y = torch.sqrt(
        torch.sum((Ca[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_Ca_Y = self._rbf(
        D_Ca_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_C_Y = torch.sqrt(
        torch.sum((C[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_C_Y = self._rbf(
        D_C_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_O_Y = torch.sqrt(
        torch.sum((O[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_O_Y = self._rbf(
        D_O_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    D_Cb_Y = torch.sqrt(
        torch.sum((Cb[:, :, None, :] - Y) ** 2, -1) + 1e-6
    )  # [B, L, M, num_bins]
    D_Cb_Y = self._rbf(
        D_Cb_Y,
        D_mu_shape=[1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )

    # Angular position of each context atom in the residue-local frame.
    f_angles = self._make_angle_features(N, Ca, C, Y)

    D_all = torch.cat(
        (D_N_Y, D_Ca_Y, D_C_Y, D_O_Y, D_Cb_Y, Y_t_1hot, f_angles), dim=-1
    )  # [B,L,M,5*num_bins+5]
    E_context = self.node_project_down(D_all)  # [B, L, M, node_features]
    E_context = self.norm_nodes(E_context)

    # Context-atom pair distances, RBF-encoded with _rbf's default bounds.
    Y_edges = self._rbf(
        torch.sqrt(
            torch.sum((Y[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
        )
    )  # [B, L, M, M, num_bins]

    Y_edges = self.y_edges(Y_edges)
    Y_nodes = self.y_nodes(Y_t_1hot_.float())

    Y_edges = self.norm_y_edges(Y_edges)
    Y_nodes = self.norm_y_nodes(Y_nodes)

    return V, E, E_idx, Y_nodes, Y_edges, E_context, Y_m
1078
+
1079
def features_decode(self, features):
    """Make features for decoding: explicit side-chain and context-atom distances.

    Fix: removed the unused local ``device = S.device``.

    Args:
        features: dict read for the keys S, X, X_m, mask, E_idx, Y, Y_m, Y_t.
            X carries 14 atoms per residue (indexed 0..13 below); Y/Y_m/Y_t
            are context-atom coordinates/mask/element numbers, truncated to
            self.atom_context_num atoms per residue.

    Returns:
        Tuple (V, F): decoder node features V and decoder edge features F.
    """
    S = features["S"]
    X = features["X"]
    X_m = features["X_m"]
    mask = features["mask"]
    E_idx = features["E_idx"]

    # Keep only the first atom_context_num context atoms per residue.
    Y = features["Y"][:, :, : self.atom_context_num]
    Y_m = features["Y_m"][:, :, : self.atom_context_num]
    Y_t = features["Y_t"][:, :, : self.atom_context_num]

    # Zero out atom masks of padded/invalid residues.
    X_m = X_m * mask[:, :, None]

    B, L, _, _ = X.shape

    RBF_sidechain = []
    X_m_gathered = gather_nodes(X_m, E_idx)  # [B, L, K, 14]

    # RBF features for all 14x14 atom pairs between a residue and each of its
    # K neighbors, masked so that absent atoms contribute zeros.
    for i in range(14):
        for j in range(14):
            rbf_features = self._get_rbf(
                X[:, :, i, :],
                X[:, :, j, :],
                E_idx,
                D_mu_shape=[1, 1, 1, -1],
                lower_bound=self.lower_bound,
                upper_bound=self.upper_bound,
                num_bins=self.num_rbf,
            )
            rbf_features = (
                rbf_features
                * X_m[:, :, i, None, None]
                * X_m_gathered[:, :, :, j, None]
            )
            RBF_sidechain.append(rbf_features)

    # Distances from every protein atom to every context atom.
    D_XY = torch.sqrt(
        torch.sum((X[:, :, :, None, :] - Y[:, :, None, :, :]) ** 2, -1) + 1e-6
    )  # [B, L, 14, atom_context_num]
    XY_features = self._rbf(
        D_XY,
        D_mu_shape=[1, 1, 1, 1, -1],
        lower_bound=self.lower_bound,
        upper_bound=self.upper_bound,
        num_bins=self.num_rbf,
    )  # [B, L, 14, atom_context_num, num_rbf]
    XY_features = XY_features * X_m[:, :, :, None, None] * Y_m[:, :, None, :, None]

    # Append one-hot element types of the context atoms, then project down.
    Y_t_1hot = torch.nn.functional.one_hot(
        Y_t.long(), 120
    ).float()  # [B, L, atom_context_num, 120]
    XY_Y_t = torch.cat(
        [XY_features, Y_t_1hot[:, :, None, :, :].repeat(1, 1, 14, 1, 1)], -1
    )  # [B, L, 14, atom_context_num, num_rbf+120]
    XY_Y_t = self.W_XY_project_down1(
        XY_Y_t
    )  # [B, L, 14, atom_context_num, num_rbf]
    XY_features = XY_Y_t.view([B, L, -1])

    V = self.dec_node_embedding1(XY_features)
    V = self.dec_norm_nodes1(V)

    # One-hot residue identity for the residue itself and for each neighbor.
    S_1h = torch.nn.functional.one_hot(S, self.enc_node_in).float()
    S_1h_gathered = gather_nodes(S_1h, E_idx)  # [B, L, K, 21]
    S_features = torch.cat(
        [S_1h[:, :, None, :].repeat(1, 1, E_idx.shape[2], 1), S_1h_gathered], -1
    )  # [B, L, K, 42]

    # Concatenate all pairwise RBFs with the sequence features and embed.
    F = torch.cat(
        tuple(RBF_sidechain), dim=-1
    )  # [B, L, K, 14*14*num_rbf]
    F = torch.cat([F, S_features], -1)
    F = self.dec_edge_embedding1(F)
    F = self.dec_norm_edges1(F)
    return V, F
score.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os.path
4
+ import random
5
+ import sys
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from data_utils import (
11
+ element_dict_rev,
12
+ alphabet,
13
+ restype_int_to_str,
14
+ featurize,
15
+ parse_PDB,
16
+ )
17
+ from model_utils import ProteinMPNN
18
+
19
+
20
def main(args) -> None:
    """Score sequences with the selected MPNN model and save statistics.

    Loads the checkpoint chosen by --model_type, parses each input PDB,
    builds fixed/redesigned/membrane masks, runs either the autoregressive
    or the single-amino-acid scoring function --number_of_batches times, and
    writes per-residue probability statistics to
    ``<out_folder>/<pdb_name><file_ending>.pt`` via ``torch.save``.
    """
    # NOTE(review): a falsy --seed (0, which is the default) falls through
    # to a randomly drawn seed — seed 0 itself cannot be requested.
    if args.seed:
        seed = args.seed
    else:
        seed = int(np.random.randint(0, high=99999, size=1, dtype=int)[0])
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
    # Normalize the output folder to end with "/" and create it if missing.
    folder_for_outputs = args.out_folder
    base_folder = folder_for_outputs
    if base_folder[-1] != "/":
        base_folder = base_folder + "/"
    if not os.path.exists(base_folder):
        os.makedirs(base_folder, exist_ok=True)
    # Map the model type to its checkpoint path.
    if args.model_type == "protein_mpnn":
        checkpoint_path = args.checkpoint_protein_mpnn
    elif args.model_type == "ligand_mpnn":
        checkpoint_path = args.checkpoint_ligand_mpnn
    elif args.model_type == "per_residue_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_per_residue_label_membrane_mpnn
    elif args.model_type == "global_label_membrane_mpnn":
        checkpoint_path = args.checkpoint_global_label_membrane_mpnn
    elif args.model_type == "soluble_mpnn":
        checkpoint_path = args.checkpoint_soluble_mpnn
    else:
        print("Choose one of the available models")
        sys.exit()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Only ligand_mpnn checkpoints carry atom context / side-chain settings.
    if args.model_type == "ligand_mpnn":
        atom_context_num = checkpoint["atom_context_num"]
        ligand_mpnn_use_side_chain_context = args.ligand_mpnn_use_side_chain_context
        k_neighbors = checkpoint["num_edges"]
    else:
        atom_context_num = 1
        ligand_mpnn_use_side_chain_context = 0
        k_neighbors = checkpoint["num_edges"]

    model = ProteinMPNN(
        node_features=128,
        edge_features=128,
        hidden_dim=128,
        num_encoder_layers=3,
        num_decoder_layers=3,
        k_neighbors=k_neighbors,
        device=device,
        atom_context_num=atom_context_num,
        model_type=args.model_type,
        ligand_mpnn_use_side_chain_context=ligand_mpnn_use_side_chain_context,
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # Input PDBs: either a JSON file listing paths (keys used) or one path.
    if args.pdb_path_multi:
        with open(args.pdb_path_multi, "r") as fh:
            pdb_paths = list(json.load(fh))
    else:
        pdb_paths = [args.pdb_path]

    # Fixed residues: per-PDB JSON mapping, or one list applied to all PDBs.
    if args.fixed_residues_multi:
        with open(args.fixed_residues_multi, "r") as fh:
            fixed_residues_multi = json.load(fh)
    else:
        fixed_residues = [item for item in args.fixed_residues.split()]
        fixed_residues_multi = {}
        for pdb in pdb_paths:
            fixed_residues_multi[pdb] = fixed_residues

    # Redesigned residues: same scheme as fixed residues.
    if args.redesigned_residues_multi:
        with open(args.redesigned_residues_multi, "r") as fh:
            redesigned_residues_multi = json.load(fh)
    else:
        redesigned_residues = [item for item in args.redesigned_residues.split()]
        redesigned_residues_multi = {}
        for pdb in pdb_paths:
            redesigned_residues_multi[pdb] = redesigned_residues

    # loop over PDB paths
    for pdb in pdb_paths:
        if args.verbose:
            print("Designing protein from this path:", pdb)
        fixed_residues = fixed_residues_multi[pdb]
        redesigned_residues = redesigned_residues_multi[pdb]
        protein_dict, backbone, other_atoms, icodes, _ = parse_PDB(
            pdb,
            device=device,
            chains=args.parse_these_chains_only,
            parse_all_atoms=args.ligand_mpnn_use_side_chain_context,
            parse_atoms_with_zero_occupancy=args.parse_atoms_with_zero_occupancy
        )
        # make chain_letter + residue_idx + insertion_code mapping to integers
        R_idx_list = list(protein_dict["R_idx"].cpu().numpy())  # residue indices
        chain_letters_list = list(protein_dict["chain_letters"])  # chain letters
        encoded_residues = []
        for i, R_idx_item in enumerate(R_idx_list):
            # e.g. chain "A" + residue 12 + insertion code -> "A12"
            tmp = str(chain_letters_list[i]) + str(R_idx_item) + icodes[i]
            encoded_residues.append(tmp)
        encoded_residue_dict = dict(zip(encoded_residues, range(len(encoded_residues))))
        encoded_residue_dict_rev = dict(
            zip(list(range(len(encoded_residues))), encoded_residues)
        )

        # 0 where the residue is listed as fixed/redesigned, 1 elsewhere.
        fixed_positions = torch.tensor(
            [int(item not in fixed_residues) for item in encoded_residues],
            device=device,
        )
        redesigned_positions = torch.tensor(
            [int(item not in redesigned_residues) for item in encoded_residues],
            device=device,
        )

        # specify which residues are buried for checkpoint_per_residue_label_membrane_mpnn model
        if args.transmembrane_buried:
            buried_residues = [item for item in args.transmembrane_buried.split()]
            buried_positions = torch.tensor(
                [int(item in buried_residues) for item in encoded_residues],
                device=device,
            )
        else:
            buried_positions = torch.zeros_like(fixed_positions)

        if args.transmembrane_interface:
            interface_residues = [item for item in args.transmembrane_interface.split()]
            interface_positions = torch.tensor(
                [int(item in interface_residues) for item in encoded_residues],
                device=device,
            )
        else:
            interface_positions = torch.zeros_like(fixed_positions)
        # Per-residue label: 2 = buried only, 1 = interface only, 0 = neither.
        protein_dict["membrane_per_residue_labels"] = 2 * buried_positions * (
            1 - interface_positions
        ) + 1 * interface_positions * (1 - buried_positions)

        # For the global-label model every residue carries the same label.
        if args.model_type == "global_label_membrane_mpnn":
            protein_dict["membrane_per_residue_labels"] = (
                args.global_transmembrane_label + 0 * fixed_positions
            )
        if type(args.chains_to_design) == str:
            chains_to_design_list = args.chains_to_design.split(",")
        else:
            # Default (None): treat every parsed chain as designable.
            chains_to_design_list = protein_dict["chain_letters"]
        chain_mask = torch.tensor(
            np.array(
                [
                    item in chains_to_design_list
                    for item in protein_dict["chain_letters"]
                ],
                dtype=np.int32,
            ),
            device=device,
        )

        # create chain_mask to notify which residues are fixed (0) and which need to be designed (1)
        if redesigned_residues:
            protein_dict["chain_mask"] = chain_mask * (1 - redesigned_positions)
        elif fixed_residues:
            protein_dict["chain_mask"] = chain_mask * fixed_positions
        else:
            protein_dict["chain_mask"] = chain_mask

        if args.verbose:
            PDB_residues_to_be_redesigned = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 1
            ]
            PDB_residues_to_be_fixed = [
                encoded_residue_dict_rev[item]
                for item in range(protein_dict["chain_mask"].shape[0])
                if protein_dict["chain_mask"][item] == 0
            ]
            print("These residues will be redesigned: ", PDB_residues_to_be_redesigned)
            print("These residues will be fixed: ", PDB_residues_to_be_fixed)

        # specify which residues are linked
        if args.symmetry_residues:
            # "A12,A13|B2,B3" -> [[idx(A12), idx(A13)], [idx(B2), idx(B3)]]
            symmetry_residues_list_of_lists = [
                x.split(",") for x in args.symmetry_residues.split("|")
            ]
            remapped_symmetry_residues = []
            for t_list in symmetry_residues_list_of_lists:
                tmp_list = []
                for t in t_list:
                    tmp_list.append(encoded_residue_dict[t])
                remapped_symmetry_residues.append(tmp_list)
        else:
            remapped_symmetry_residues = [[]]

        # Homo-oligomer mode: tie residue i across all chains together.
        if args.homo_oligomer:
            if args.verbose:
                print("Designing HOMO-OLIGOMER")
            chain_letters_set = list(set(chain_letters_list))
            reference_chain = chain_letters_set[0]
            lc = len(reference_chain)
            residue_indices = [
                item[lc:] for item in encoded_residues if item[:lc] == reference_chain
            ]
            remapped_symmetry_residues = []
            for res in residue_indices:
                tmp_list = []
                # NOTE(review): tmp_w_list (equal per-chain weights) is built
                # but never used in this script — confirm whether the scoring
                # path was meant to consume symmetry weights.
                tmp_w_list = []
                for chain in chain_letters_set:
                    name = chain + res
                    tmp_list.append(encoded_residue_dict[name])
                    tmp_w_list.append(1 / len(chain_letters_set))
                remapped_symmetry_residues.append(tmp_list)

        # set other atom bfactors to 0.0
        if other_atoms:
            other_bfactors = other_atoms.getBetas()
            other_atoms.setBetas(other_bfactors * 0.0)

        # adjust input PDB name by dropping .pdb if it does exist
        name = pdb[pdb.rfind("/") + 1 :]
        if name[-4:] == ".pdb":
            name = name[:-4]

        with torch.no_grad():
            # run featurize to remap R_idx and add batch dimension
            if args.verbose:
                if "Y" in list(protein_dict):
                    atom_coords = protein_dict["Y"].cpu().numpy()
                    atom_types = list(protein_dict["Y_t"].cpu().numpy())
                    atom_mask = list(protein_dict["Y_m"].cpu().numpy())
                    number_of_atoms_parsed = np.sum(atom_mask)
                else:
                    print("No ligand atoms parsed")
                    number_of_atoms_parsed = 0
                    atom_types = ""
                    atom_coords = []
                if number_of_atoms_parsed == 0:
                    print("No ligand atoms parsed")
                elif args.model_type == "ligand_mpnn":
                    print(
                        f"The number of ligand atoms parsed is equal to: {number_of_atoms_parsed}"
                    )
                    for i, atom_type in enumerate(atom_types):
                        print(
                            f"Type: {element_dict_rev[atom_type]}, Coords {atom_coords[i]}, Mask {atom_mask[i]}"
                        )
            feature_dict = featurize(
                protein_dict,
                cutoff_for_score=args.ligand_mpnn_cutoff_for_score,
                use_atom_context=args.ligand_mpnn_use_atom_context,
                number_of_ligand_atoms=atom_context_num,
                model_type=args.model_type,
            )
            feature_dict["batch_size"] = args.batch_size
            B, L, _, _ = feature_dict["X"].shape  # batch size should be 1 for now.
            # add additional keys to the feature dictionary
            feature_dict["symmetry_residues"] = remapped_symmetry_residues

            # Accumulate scores over number_of_batches independent passes
            # (only the random decoding order differs between passes).
            logits_list = []
            probs_list = []
            log_probs_list = []
            decoding_order_list = []
            for _ in range(args.number_of_batches):
                feature_dict["randn"] = torch.randn(
                    [feature_dict["batch_size"], feature_dict["mask"].shape[1]],
                    device=device,
                )
                if args.autoregressive_score:
                    score_dict = model.score(feature_dict, use_sequence=args.use_sequence)
                elif args.single_aa_score:
                    score_dict = model.single_aa_score(feature_dict, use_sequence=args.use_sequence)
                else:
                    print("Set either autoregressive_score or single_aa_score to True")
                    sys.exit()
                logits_list.append(score_dict["logits"])
                log_probs_list.append(score_dict["log_probs"])
                probs_list.append(torch.exp(score_dict["log_probs"]))
                decoding_order_list.append(score_dict["decoding_order"])
            log_probs_stack = torch.cat(log_probs_list, 0)
            logits_stack = torch.cat(logits_list, 0)
            probs_stack = torch.cat(probs_list, 0)
            decoding_order_stack = torch.cat(decoding_order_list, 0)

            # Collect everything into one dict and torch.save it.
            output_stats_path = base_folder + name + args.file_ending + ".pt"
            out_dict = {}
            out_dict["logits"] = logits_stack.cpu().numpy()
            out_dict["probs"] = probs_stack.cpu().numpy()
            out_dict["log_probs"] = log_probs_stack.cpu().numpy()
            out_dict["decoding_order"] = decoding_order_stack.cpu().numpy()
            out_dict["native_sequence"] = feature_dict["S"][0].cpu().numpy()
            out_dict["mask"] = feature_dict["mask"][0].cpu().numpy()
            out_dict["chain_mask"] = feature_dict["chain_mask"][0].cpu().numpy()  # this affects decoding order
            out_dict["seed"] = seed
            out_dict["alphabet"] = alphabet
            out_dict["residue_names"] = encoded_residue_dict_rev

            # Per-residue mean/std of probabilities across the batch passes,
            # keyed by the human-readable residue names.
            mean_probs = np.mean(out_dict["probs"], 0)
            std_probs = np.std(out_dict["probs"], 0)
            sequence = [restype_int_to_str[AA] for AA in out_dict["native_sequence"]]
            mean_dict = {}
            std_dict = {}
            for residue in range(L):
                mean_dict_ = dict(zip(alphabet, mean_probs[residue]))
                mean_dict[encoded_residue_dict_rev[residue]] = mean_dict_
                std_dict_ = dict(zip(alphabet, std_probs[residue]))
                std_dict[encoded_residue_dict_rev[residue]] = std_dict_

            out_dict["sequence"] = sequence
            out_dict["mean_of_probs"] = mean_dict
            out_dict["std_of_probs"] = std_dict
            torch.save(out_dict, output_stats_path)
330
+
331
+
332
+
333
if __name__ == "__main__":
    # CLI entry point: every flag below maps 1:1 onto an attribute that
    # main(args) reads.
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- model selection and checkpoints ---------------------------------
    argparser.add_argument(
        "--model_type",
        type=str,
        default="protein_mpnn",
        help="Choose your model: protein_mpnn, ligand_mpnn, per_residue_label_membrane_mpnn, global_label_membrane_mpnn, soluble_mpnn",
    )
    # protein_mpnn - original ProteinMPNN trained on the whole PDB excluding non-protein atoms
    # ligand_mpnn - atomic context aware model trained with small molecules, nucleotides, metals etc on the whole PDB
    # per_residue_label_membrane_mpnn - ProteinMPNN model trained with addition label per residue specifying if that residue is buried or exposed
    # global_label_membrane_mpnn - ProteinMPNN model trained with global label per PDB id to specify if protein is transmembrane
    # soluble_mpnn - ProteinMPNN trained only on soluble PDB ids
    argparser.add_argument(
        "--checkpoint_protein_mpnn",
        type=str,
        default="./model_params/proteinmpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_ligand_mpnn",
        type=str,
        default="./model_params/ligandmpnn_v_32_010_25.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_per_residue_label_membrane_mpnn",
        type=str,
        default="./model_params/per_residue_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_global_label_membrane_mpnn",
        type=str,
        default="./model_params/global_label_membrane_mpnn_v_48_020.pt",
        help="Path to model weights.",
    )
    argparser.add_argument(
        "--checkpoint_soluble_mpnn",
        type=str,
        default="./model_params/solublempnn_v_48_020.pt",
        help="Path to model weights.",
    )

    argparser.add_argument("--verbose", type=int, default=1, help="Print stuff")

    # --- inputs ----------------------------------------------------------
    argparser.add_argument(
        "--pdb_path", type=str, default="", help="Path to the input PDB."
    )
    argparser.add_argument(
        "--pdb_path_multi",
        type=str,
        default="",
        help="Path to json listing PDB paths. {'/path/to/pdb': ''} - only keys will be used.",
    )

    # --- residue selection -----------------------------------------------
    argparser.add_argument(
        "--fixed_residues",
        type=str,
        default="",
        help="Provide fixed residues, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--fixed_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of fixed residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--redesigned_residues",
        type=str,
        default="",
        help="Provide to be redesigned residues, everything else will be fixed, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--redesigned_residues_multi",
        type=str,
        default="",
        help="Path to json mapping of redesigned residues for each pdb i.e., {'/path/to/pdb': 'A12 A13 A14 B2 B25'}",
    )

    argparser.add_argument(
        "--symmetry_residues",
        type=str,
        default="",
        help="Add list of lists for which residues need to be symmetric, e.g. 'A12,A13,A14|C2,C3|A5,B6'",
    )

    argparser.add_argument(
        "--homo_oligomer",
        type=int,
        default=0,
        help="Setting this to 1 will automatically set --symmetry_residues and --symmetry_weights to do homooligomer design with equal weighting.",
    )

    # --- output ----------------------------------------------------------
    argparser.add_argument(
        "--out_folder",
        type=str,
        help="Path to a folder to output scores, e.g. /home/out/",
    )
    argparser.add_argument(
        "--file_ending", type=str, default="", help="adding_string_to_the_end"
    )
    # NOTE(review): declared type=str with an int default, and not referenced
    # in main() above — likely meant type=int; confirm before relying on it.
    argparser.add_argument(
        "--zero_indexed",
        type=str,
        default=0,
        help="1 - to start output PDB numbering with 0",
    )
    argparser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set seed for torch, numpy, and python random.",
    )
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Number of sequence to generate per one pass.",
    )
    argparser.add_argument(
        "--number_of_batches",
        type=int,
        default=1,
        help="Number of times to design sequence using a chosen batch size.",
    )

    # --- ligand_mpnn specific options ------------------------------------
    argparser.add_argument(
        "--ligand_mpnn_use_atom_context",
        type=int,
        default=1,
        help="1 - use atom context, 0 - do not use atom context.",
    )

    argparser.add_argument(
        "--ligand_mpnn_use_side_chain_context",
        type=int,
        default=0,
        help="Flag to use side chain atoms as ligand context for the fixed residues",
    )

    argparser.add_argument(
        "--ligand_mpnn_cutoff_for_score",
        type=float,
        default=8.0,
        help="Cutoff in angstroms between protein and context atoms to select residues for reporting score.",
    )

    # --- chain parsing / design ------------------------------------------
    argparser.add_argument(
        "--chains_to_design",
        type=str,
        default=None,
        help="Specify which chains to redesign, all others will be kept fixed.",
    )

    argparser.add_argument(
        "--parse_these_chains_only",
        type=str,
        default="",
        help="Provide chains letters for parsing backbones, 'ABCF'",
    )

    # --- membrane model options ------------------------------------------
    argparser.add_argument(
        "--transmembrane_buried",
        type=str,
        default="",
        help="Provide buried residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )
    argparser.add_argument(
        "--transmembrane_interface",
        type=str,
        default="",
        help="Provide interface residues when using checkpoint_per_residue_label_membrane_mpnn model, A12 A13 A14 B2 B25",
    )

    argparser.add_argument(
        "--global_transmembrane_label",
        type=int,
        default=0,
        help="Provide global label for global_label_membrane_mpnn model. 1 - transmembrane, 0 - soluble",
    )

    argparser.add_argument(
        "--parse_atoms_with_zero_occupancy",
        type=int,
        default=0,
        help="To parse atoms with zero occupancy in the PDB input files. 0 - do not parse, 1 - parse atoms with zero occupancy",
    )

    # --- scoring mode ----------------------------------------------------
    argparser.add_argument(
        "--use_sequence",
        type=int,
        default=1,
        help="1 - get scores using amino acid sequence info; 0 - get scores using backbone info only",
    )

    argparser.add_argument(
        "--autoregressive_score",
        type=int,
        default=0,
        help="1 - run autoregressive scoring function; p(AA_1|backbone); p(AA_2|backbone, AA_1) etc, 0 - False",
    )

    argparser.add_argument(
        "--single_aa_score",
        type=int,
        default=1,
        help="1 - run single amino acid scoring function; p(AA_i|backbone, AA_{all except ith one}), 0 - False",
    )

    args = argparser.parse_args()
    main(args)
space_utils/download_weights.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
import os
import subprocess
2
+
3
def download_ligandmpnn_weights():
    """Download the LigandMPNN checkpoint into ./model_params/.

    Fixes:
    - Create the target directory first: ``wget -O`` fails if it is missing.
    - Invoke wget with an argument list (shell=False) instead of an
      interpolated shell string, avoiding quoting/injection pitfalls.

    Returns:
        0 on success; a failed download raises subprocess.CalledProcessError
        (check=True).
    """
    url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_030_25.pt"
    dest = "./model_params/ligandmpnn_v_32_030_25.pt"
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    subprocess.run(["wget", url, "-O", dest], check=True)
    return 0