lhallee commited on
Commit
989ec3c
·
verified ·
1 Parent(s): 69d2b14

Upload vb_modules_confidencev2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vb_modules_confidencev2.py +498 -498
vb_modules_confidencev2.py CHANGED
@@ -1,498 +1,498 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn.functional import pad
4
-
5
- from . import vb_const as const
6
- from . import vb_layers_initialize as init
7
- from .vb_layers_confidence_utils import (
8
- compute_aggregated_metric,
9
- compute_ptms,
10
- )
11
- from .vb_layers_pairformer import PairformerModule
12
- from .vb_modules_encodersv2 import RelativePositionEncoder
13
- from .vb_modules_trunkv2 import (
14
- ContactConditioning,
15
- )
16
- from .vb_modules_utils import LinearNoBias
17
-
18
-
19
- class ConfidenceModule(nn.Module):
20
- """Algorithm 31"""
21
-
22
- def __init__(
23
- self,
24
- token_s,
25
- token_z,
26
- pairformer_args: dict,
27
- num_dist_bins=64,
28
- token_level_confidence=True,
29
- max_dist=22,
30
- add_s_to_z_prod=False,
31
- add_s_input_to_s=False,
32
- add_z_input_to_z=False,
33
- maximum_bond_distance=0,
34
- bond_type_feature=False,
35
- confidence_args: dict = None,
36
- compile_pairformer=False,
37
- fix_sym_check=False,
38
- cyclic_pos_enc=False,
39
- return_latent_feats=False,
40
- conditioning_cutoff_min=None,
41
- conditioning_cutoff_max=None,
42
- **kwargs,
43
- ):
44
- super().__init__()
45
- self.max_num_atoms_per_token = 23
46
- if "no_update_s" in pairformer_args:
47
- self.no_update_s = pairformer_args["no_update_s"]
48
- else:
49
- self.no_update_s = False
50
- boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
51
- self.register_buffer("boundaries", boundaries)
52
- self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
53
- init.gating_init_(self.dist_bin_pairwise_embed.weight)
54
- self.token_level_confidence = token_level_confidence
55
-
56
- self.s_to_z = LinearNoBias(token_s, token_z)
57
- self.s_to_z_transpose = LinearNoBias(token_s, token_z)
58
- init.gating_init_(self.s_to_z.weight)
59
- init.gating_init_(self.s_to_z_transpose.weight)
60
-
61
- self.add_s_to_z_prod = add_s_to_z_prod
62
- if add_s_to_z_prod:
63
- self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
64
- self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
65
- self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
66
- init.gating_init_(self.s_to_z_prod_out.weight)
67
-
68
- self.s_inputs_norm = nn.LayerNorm(token_s)
69
- if not self.no_update_s:
70
- self.s_norm = nn.LayerNorm(token_s)
71
- self.z_norm = nn.LayerNorm(token_z)
72
-
73
- self.add_s_input_to_s = add_s_input_to_s
74
- if add_s_input_to_s:
75
- self.s_input_to_s = LinearNoBias(token_s, token_s)
76
- init.gating_init_(self.s_input_to_s.weight)
77
-
78
- self.add_z_input_to_z = add_z_input_to_z
79
- if add_z_input_to_z:
80
- self.rel_pos = RelativePositionEncoder(
81
- token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
82
- )
83
- self.token_bonds = nn.Linear(
84
- 1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
85
- token_z,
86
- bias=False,
87
- )
88
- self.bond_type_feature = bond_type_feature
89
- if bond_type_feature:
90
- self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)
91
-
92
- self.contact_conditioning = ContactConditioning(
93
- token_z=token_z,
94
- cutoff_min=conditioning_cutoff_min,
95
- cutoff_max=conditioning_cutoff_max,
96
- )
97
- pairformer_args["v2"] = True
98
- self.pairformer_stack = PairformerModule(
99
- token_s,
100
- token_z,
101
- **pairformer_args,
102
- )
103
- self.return_latent_feats = return_latent_feats
104
-
105
- self.confidence_heads = ConfidenceHeads(
106
- token_s,
107
- token_z,
108
- token_level_confidence=token_level_confidence,
109
- **confidence_args,
110
- )
111
-
112
- def forward(
113
- self,
114
- s_inputs, # Float['b n ts']
115
- s, # Float['b n ts']
116
- z, # Float['b n n tz']
117
- x_pred, # Float['bm m 3']
118
- feats,
119
- pred_distogram_logits,
120
- multiplicity=1,
121
- run_sequentially=False,
122
- use_kernels: bool = False,
123
- ):
124
- if run_sequentially and multiplicity > 1:
125
- assert z.shape[0] == 1, "Not supported with batch size > 1"
126
- out_dicts = []
127
- for sample_idx in range(multiplicity):
128
- out_dicts.append( # noqa: PERF401
129
- self.forward(
130
- s_inputs,
131
- s,
132
- z,
133
- x_pred[sample_idx : sample_idx + 1],
134
- feats,
135
- pred_distogram_logits,
136
- multiplicity=1,
137
- run_sequentially=False,
138
- use_kernels=use_kernels,
139
- )
140
- )
141
-
142
- out_dict = {}
143
- for key in out_dicts[0]:
144
- if key != "pair_chains_iptm":
145
- out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
146
- else:
147
- pair_chains_iptm = {}
148
- for chain_idx1 in out_dicts[0][key]:
149
- chains_iptm = {}
150
- for chain_idx2 in out_dicts[0][key][chain_idx1]:
151
- chains_iptm[chain_idx2] = torch.cat(
152
- [out[key][chain_idx1][chain_idx2] for out in out_dicts],
153
- dim=0,
154
- )
155
- pair_chains_iptm[chain_idx1] = chains_iptm
156
- out_dict[key] = pair_chains_iptm
157
- return out_dict
158
-
159
- s_inputs = self.s_inputs_norm(s_inputs)
160
- if not self.no_update_s:
161
- s = self.s_norm(s)
162
-
163
- if self.add_s_input_to_s:
164
- s = s + self.s_input_to_s(s_inputs)
165
-
166
- z = self.z_norm(z)
167
-
168
- if self.add_z_input_to_z:
169
- relative_position_encoding = self.rel_pos(feats)
170
- z = z + relative_position_encoding
171
- z = z + self.token_bonds(feats["token_bonds"].float())
172
- if self.bond_type_feature:
173
- z = z + self.token_bonds_type(feats["type_bonds"].long())
174
- z = z + self.contact_conditioning(feats)
175
-
176
- s = s.repeat_interleave(multiplicity, 0)
177
-
178
- z = (
179
- z
180
- + self.s_to_z(s_inputs)[:, :, None, :]
181
- + self.s_to_z_transpose(s_inputs)[:, None, :, :]
182
- )
183
- if self.add_s_to_z_prod:
184
- z = z + self.s_to_z_prod_out(
185
- self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
186
- * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
187
- )
188
-
189
- z = z.repeat_interleave(multiplicity, 0)
190
- s_inputs = s_inputs.repeat_interleave(multiplicity, 0)
191
-
192
- token_to_rep_atom = feats["token_to_rep_atom"]
193
- token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
194
- if len(x_pred.shape) == 4:
195
- B, mult, N, _ = x_pred.shape
196
- x_pred = x_pred.reshape(B * mult, N, -1)
197
- else:
198
- BM, N, _ = x_pred.shape
199
- x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
200
- d = torch.cdist(x_pred_repr, x_pred_repr)
201
- distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
202
- distogram = self.dist_bin_pairwise_embed(distogram)
203
- z = z + distogram
204
-
205
- mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
206
- pair_mask = mask[:, :, None] * mask[:, None, :]
207
-
208
- s_t, z_t = self.pairformer_stack(
209
- s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
210
- )
211
-
212
- # AF3 has residual connections, we remove them
213
- s = s_t
214
- z = z_t
215
-
216
- out_dict = {}
217
-
218
- if self.return_latent_feats:
219
- out_dict["s_conf"] = s
220
- out_dict["z_conf"] = z
221
-
222
- # confidence heads
223
- out_dict.update(
224
- self.confidence_heads(
225
- s=s,
226
- z=z,
227
- x_pred=x_pred,
228
- d=d,
229
- feats=feats,
230
- multiplicity=multiplicity,
231
- pred_distogram_logits=pred_distogram_logits,
232
- )
233
- )
234
- return out_dict
235
-
236
-
237
- class ConfidenceHeads(nn.Module):
238
- def __init__(
239
- self,
240
- token_s,
241
- token_z,
242
- num_plddt_bins=50,
243
- num_pde_bins=64,
244
- num_pae_bins=64,
245
- token_level_confidence=True,
246
- use_separate_heads: bool = False,
247
- **kwargs,
248
- ):
249
- super().__init__()
250
- self.max_num_atoms_per_token = 23
251
- self.token_level_confidence = token_level_confidence
252
- self.use_separate_heads = use_separate_heads
253
-
254
- if self.use_separate_heads:
255
- self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
256
- self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
257
- else:
258
- self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
259
-
260
- if self.use_separate_heads:
261
- self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
262
- self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
263
- else:
264
- self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)
265
-
266
- if self.token_level_confidence:
267
- self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
268
- self.to_resolved_logits = LinearNoBias(token_s, 2)
269
- else:
270
- self.to_plddt_logits = LinearNoBias(
271
- token_s, num_plddt_bins * self.max_num_atoms_per_token
272
- )
273
- self.to_resolved_logits = LinearNoBias(
274
- token_s, 2 * self.max_num_atoms_per_token
275
- )
276
-
277
- def forward(
278
- self,
279
- s, # Float['b n ts']
280
- z, # Float['b n n tz']
281
- x_pred, # Float['bm m 3']
282
- d,
283
- feats,
284
- pred_distogram_logits,
285
- multiplicity=1,
286
- ):
287
- if self.use_separate_heads:
288
- asym_id_token = feats["asym_id"]
289
- is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2)
290
- is_different_chain = ~is_same_chain
291
-
292
- if self.use_separate_heads:
293
- pae_intra_logits = self.to_pae_intra_logits(z)
294
- pae_intra_logits = pae_intra_logits * is_same_chain.float().unsqueeze(-1)
295
-
296
- pae_inter_logits = self.to_pae_inter_logits(z)
297
- pae_inter_logits = pae_inter_logits * is_different_chain.float().unsqueeze(
298
- -1
299
- )
300
-
301
- pae_logits = pae_inter_logits + pae_intra_logits
302
- else:
303
- pae_logits = self.to_pae_logits(z)
304
-
305
- if self.use_separate_heads:
306
- pde_intra_logits = self.to_pde_intra_logits(z + z.transpose(1, 2))
307
- pde_intra_logits = pde_intra_logits * is_same_chain.float().unsqueeze(-1)
308
-
309
- pde_inter_logits = self.to_pde_inter_logits(z + z.transpose(1, 2))
310
- pde_inter_logits = pde_inter_logits * is_different_chain.float().unsqueeze(
311
- -1
312
- )
313
-
314
- pde_logits = pde_inter_logits + pde_intra_logits
315
- else:
316
- pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
317
- resolved_logits = self.to_resolved_logits(s)
318
- plddt_logits = self.to_plddt_logits(s)
319
-
320
- ligand_weight = 20
321
- non_interface_weight = 1
322
- interface_weight = 10
323
-
324
- token_type = feats["mol_type"]
325
- token_type = token_type.repeat_interleave(multiplicity, 0)
326
- is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
327
-
328
- if self.token_level_confidence:
329
- plddt = compute_aggregated_metric(plddt_logits)
330
- token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
331
- complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
332
- dim=-1
333
- )
334
-
335
- is_contact = (d < 8).float()
336
- is_different_chain = (
337
- feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
338
- ).float()
339
- is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
340
- token_interface_mask = torch.max(
341
- is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
342
- dim=-1,
343
- ).values
344
- token_non_interface_mask = (1 - token_interface_mask) * (
345
- 1 - is_ligand_token
346
- )
347
- iplddt_weight = (
348
- is_ligand_token * ligand_weight
349
- + token_interface_mask * interface_weight
350
- + token_non_interface_mask * non_interface_weight
351
- )
352
- complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
353
- dim=-1
354
- ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)
355
-
356
- else:
357
- # token to atom conversion for resolved logits
358
- B, N, _ = resolved_logits.shape
359
- resolved_logits = resolved_logits.reshape(
360
- B, N, self.max_num_atoms_per_token, 2
361
- )
362
-
363
- arange_max_num_atoms = (
364
- torch.arange(self.max_num_atoms_per_token)
365
- .reshape(1, 1, -1)
366
- .to(resolved_logits.device)
367
- )
368
- max_num_atoms_mask = (
369
- feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
370
- )
371
- resolved_logits = resolved_logits[:, max_num_atoms_mask.squeeze(0)]
372
- resolved_logits = pad(
373
- resolved_logits,
374
- (
375
- 0,
376
- 0,
377
- 0,
378
- int(
379
- feats["atom_pad_mask"].shape[1]
380
- - feats["atom_pad_mask"].sum().item()
381
- ),
382
- ),
383
- value=0,
384
- )
385
- plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
386
- plddt_logits = plddt_logits[:, max_num_atoms_mask.squeeze(0)]
387
- plddt_logits = pad(
388
- plddt_logits,
389
- (
390
- 0,
391
- 0,
392
- 0,
393
- int(
394
- feats["atom_pad_mask"].shape[1]
395
- - feats["atom_pad_mask"].sum().item()
396
- ),
397
- ),
398
- value=0,
399
- )
400
- atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
401
- plddt = compute_aggregated_metric(plddt_logits)
402
-
403
- complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
404
- dim=-1
405
- )
406
- token_type = feats["mol_type"].float()
407
- atom_to_token = feats["atom_to_token"].float()
408
- chain_id_token = feats["asym_id"].float()
409
- atom_type = torch.bmm(atom_to_token, token_type.unsqueeze(-1)).squeeze(-1)
410
- is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
411
- d_atom = torch.cdist(x_pred, x_pred)
412
- is_contact = (d_atom < 8).float()
413
- chain_id_atom = torch.bmm(
414
- atom_to_token, chain_id_token.unsqueeze(-1)
415
- ).squeeze(-1)
416
- is_different_chain = (
417
- chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
418
- ).float()
419
-
420
- atom_interface_mask = torch.max(
421
- is_contact * is_different_chain * (1 - is_ligand_atom).unsqueeze(-1),
422
- dim=-1,
423
- ).values
424
- atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
425
- iplddt_weight = (
426
- is_ligand_atom * ligand_weight
427
- + atom_interface_mask * interface_weight
428
- + atom_non_interface_mask * non_interface_weight
429
- )
430
-
431
- complex_iplddt = (plddt * feats["atom_pad_mask"] * iplddt_weight).sum(
432
- dim=-1
433
- ) / torch.sum(feats["atom_pad_mask"] * iplddt_weight, dim=-1)
434
-
435
- # Compute the gPDE and giPDE
436
- pde = compute_aggregated_metric(pde_logits, end=32)
437
- pred_distogram_prob = nn.functional.softmax(
438
- pred_distogram_logits, dim=-1
439
- ).repeat_interleave(multiplicity, 0)
440
- contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
441
- pred_distogram_prob.device
442
- )
443
- contacts[:, :, :, :20] = 1.0
444
- prob_contact = (pred_distogram_prob * contacts).sum(-1)
445
- token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
446
- token_pad_pair_mask = (
447
- token_pad_mask.unsqueeze(-1)
448
- * token_pad_mask.unsqueeze(-2)
449
- * (
450
- 1
451
- - torch.eye(
452
- token_pad_mask.shape[1], device=token_pad_mask.device
453
- ).unsqueeze(0)
454
- )
455
- )
456
- token_pair_mask = token_pad_pair_mask * prob_contact
457
- complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
458
- dim=(1, 2)
459
- )
460
- asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
461
- token_interface_pair_mask = token_pair_mask * (
462
- asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
463
- )
464
- complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
465
- token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
466
- )
467
- out_dict = dict(
468
- pde_logits=pde_logits,
469
- plddt_logits=plddt_logits,
470
- resolved_logits=resolved_logits,
471
- pde=pde,
472
- plddt=plddt,
473
- complex_plddt=complex_plddt,
474
- complex_iplddt=complex_iplddt,
475
- complex_pde=complex_pde,
476
- complex_ipde=complex_ipde,
477
- )
478
- out_dict["pae_logits"] = pae_logits
479
- out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)
480
-
481
- try:
482
- ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
483
- pae_logits, x_pred, feats, multiplicity
484
- )
485
- out_dict["ptm"] = ptm
486
- out_dict["iptm"] = iptm
487
- out_dict["ligand_iptm"] = ligand_iptm
488
- out_dict["protein_iptm"] = protein_iptm
489
- out_dict["pair_chains_iptm"] = pair_chains_iptm
490
- except Exception as e:
491
- print(f"Error in compute_ptms: {e}")
492
- out_dict["ptm"] = torch.zeros_like(complex_plddt)
493
- out_dict["iptm"] = torch.zeros_like(complex_plddt)
494
- out_dict["ligand_iptm"] = torch.zeros_like(complex_plddt)
495
- out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
496
- out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)
497
-
498
- return out_dict
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn.functional import pad
4
+
5
+ from . import vb_const as const
6
+ from . import vb_layers_initialize as init
7
+ from .vb_layers_confidence_utils import (
8
+ compute_aggregated_metric,
9
+ compute_ptms,
10
+ )
11
+ from .vb_layers_pairformer import PairformerModule
12
+ from .vb_modules_encodersv2 import RelativePositionEncoder
13
+ from .vb_modules_trunkv2 import (
14
+ ContactConditioning,
15
+ )
16
+ from .vb_modules_utils import LinearNoBias
17
+
18
+
19
+ class ConfidenceModule(nn.Module):
20
+ """Algorithm 31"""
21
+
22
+ def __init__(
23
+ self,
24
+ token_s,
25
+ token_z,
26
+ pairformer_args: dict,
27
+ num_dist_bins=64,
28
+ token_level_confidence=True,
29
+ max_dist=22,
30
+ add_s_to_z_prod=False,
31
+ add_s_input_to_s=False,
32
+ add_z_input_to_z=False,
33
+ maximum_bond_distance=0,
34
+ bond_type_feature=False,
35
+ confidence_args: dict = None,
36
+ compile_pairformer=False,
37
+ fix_sym_check=False,
38
+ cyclic_pos_enc=False,
39
+ return_latent_feats=False,
40
+ conditioning_cutoff_min=None,
41
+ conditioning_cutoff_max=None,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+ self.max_num_atoms_per_token = 23
46
+ if "no_update_s" in pairformer_args:
47
+ self.no_update_s = pairformer_args["no_update_s"]
48
+ else:
49
+ self.no_update_s = False
50
+ boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
51
+ self.register_buffer("boundaries", boundaries)
52
+ self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
53
+ init.gating_init_(self.dist_bin_pairwise_embed.weight)
54
+ self.token_level_confidence = token_level_confidence
55
+
56
+ self.s_to_z = LinearNoBias(token_s, token_z)
57
+ self.s_to_z_transpose = LinearNoBias(token_s, token_z)
58
+ init.gating_init_(self.s_to_z.weight)
59
+ init.gating_init_(self.s_to_z_transpose.weight)
60
+
61
+ self.add_s_to_z_prod = add_s_to_z_prod
62
+ if add_s_to_z_prod:
63
+ self.s_to_z_prod_in1 = LinearNoBias(token_s, token_z)
64
+ self.s_to_z_prod_in2 = LinearNoBias(token_s, token_z)
65
+ self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
66
+ init.gating_init_(self.s_to_z_prod_out.weight)
67
+
68
+ self.s_inputs_norm = nn.LayerNorm(token_s)
69
+ if not self.no_update_s:
70
+ self.s_norm = nn.LayerNorm(token_s)
71
+ self.z_norm = nn.LayerNorm(token_z)
72
+
73
+ self.add_s_input_to_s = add_s_input_to_s
74
+ if add_s_input_to_s:
75
+ self.s_input_to_s = LinearNoBias(token_s, token_s)
76
+ init.gating_init_(self.s_input_to_s.weight)
77
+
78
+ self.add_z_input_to_z = add_z_input_to_z
79
+ if add_z_input_to_z:
80
+ self.rel_pos = RelativePositionEncoder(
81
+ token_z, fix_sym_check=fix_sym_check, cyclic_pos_enc=cyclic_pos_enc
82
+ )
83
+ self.token_bonds = nn.Linear(
84
+ 1 if maximum_bond_distance == 0 else maximum_bond_distance + 2,
85
+ token_z,
86
+ bias=False,
87
+ )
88
+ self.bond_type_feature = bond_type_feature
89
+ if bond_type_feature:
90
+ self.token_bonds_type = nn.Embedding(len(const.bond_types) + 1, token_z)
91
+
92
+ self.contact_conditioning = ContactConditioning(
93
+ token_z=token_z,
94
+ cutoff_min=conditioning_cutoff_min,
95
+ cutoff_max=conditioning_cutoff_max,
96
+ )
97
+ pairformer_args["v2"] = True
98
+ self.pairformer_stack = PairformerModule(
99
+ token_s,
100
+ token_z,
101
+ **pairformer_args,
102
+ )
103
+ self.return_latent_feats = return_latent_feats
104
+
105
+ self.confidence_heads = ConfidenceHeads(
106
+ token_s,
107
+ token_z,
108
+ token_level_confidence=token_level_confidence,
109
+ **confidence_args,
110
+ )
111
+
112
+ def forward(
113
+ self,
114
+ s_inputs, # Float['b n ts']
115
+ s, # Float['b n ts']
116
+ z, # Float['b n n tz']
117
+ x_pred, # Float['bm m 3']
118
+ feats,
119
+ pred_distogram_logits,
120
+ multiplicity=1,
121
+ run_sequentially=False,
122
+ use_kernels: bool = False,
123
+ ):
124
+ if run_sequentially and multiplicity > 1:
125
+ assert z.shape[0] == 1, "Not supported with batch size > 1"
126
+ out_dicts = []
127
+ for sample_idx in range(multiplicity):
128
+ out_dicts.append( # noqa: PERF401
129
+ self.forward(
130
+ s_inputs,
131
+ s,
132
+ z,
133
+ x_pred[sample_idx : sample_idx + 1],
134
+ feats,
135
+ pred_distogram_logits,
136
+ multiplicity=1,
137
+ run_sequentially=False,
138
+ use_kernels=use_kernels,
139
+ )
140
+ )
141
+
142
+ out_dict = {}
143
+ for key in out_dicts[0]:
144
+ if key != "pair_chains_iptm":
145
+ out_dict[key] = torch.cat([out[key] for out in out_dicts], dim=0)
146
+ else:
147
+ pair_chains_iptm = {}
148
+ for chain_idx1 in out_dicts[0][key]:
149
+ chains_iptm = {}
150
+ for chain_idx2 in out_dicts[0][key][chain_idx1]:
151
+ chains_iptm[chain_idx2] = torch.cat(
152
+ [out[key][chain_idx1][chain_idx2] for out in out_dicts],
153
+ dim=0,
154
+ )
155
+ pair_chains_iptm[chain_idx1] = chains_iptm
156
+ out_dict[key] = pair_chains_iptm
157
+ return out_dict
158
+
159
+ s_inputs = self.s_inputs_norm(s_inputs)
160
+ if not self.no_update_s:
161
+ s = self.s_norm(s)
162
+
163
+ if self.add_s_input_to_s:
164
+ s = s + self.s_input_to_s(s_inputs)
165
+
166
+ z = self.z_norm(z)
167
+
168
+ if self.add_z_input_to_z:
169
+ relative_position_encoding = self.rel_pos(feats)
170
+ z = z + relative_position_encoding
171
+ z = z + self.token_bonds(feats["token_bonds"].float())
172
+ if self.bond_type_feature:
173
+ z = z + self.token_bonds_type(feats["type_bonds"].long())
174
+ z = z + self.contact_conditioning(feats)
175
+
176
+ s = s.repeat_interleave(multiplicity, 0)
177
+
178
+ z = (
179
+ z
180
+ + self.s_to_z(s_inputs)[:, :, None, :]
181
+ + self.s_to_z_transpose(s_inputs)[:, None, :, :]
182
+ )
183
+ if self.add_s_to_z_prod:
184
+ z = z + self.s_to_z_prod_out(
185
+ self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
186
+ * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
187
+ )
188
+
189
+ z = z.repeat_interleave(multiplicity, 0)
190
+ s_inputs = s_inputs.repeat_interleave(multiplicity, 0)
191
+
192
+ token_to_rep_atom = feats["token_to_rep_atom"]
193
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
194
+ if len(x_pred.shape) == 4:
195
+ B, mult, N, _ = x_pred.shape
196
+ x_pred = x_pred.reshape(B * mult, N, -1)
197
+ else:
198
+ BM, N, _ = x_pred.shape
199
+ x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
200
+ d = torch.cdist(x_pred_repr, x_pred_repr)
201
+ distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
202
+ distogram = self.dist_bin_pairwise_embed(distogram)
203
+ z = z + distogram
204
+
205
+ mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
206
+ pair_mask = mask[:, :, None] * mask[:, None, :]
207
+
208
+ s_t, z_t = self.pairformer_stack(
209
+ s, z, mask=mask, pair_mask=pair_mask, use_kernels=use_kernels
210
+ )
211
+
212
+ # AF3 has residual connections, we remove them
213
+ s = s_t
214
+ z = z_t
215
+
216
+ out_dict = {}
217
+
218
+ if self.return_latent_feats:
219
+ out_dict["s_conf"] = s
220
+ out_dict["z_conf"] = z
221
+
222
+ # confidence heads
223
+ out_dict.update(
224
+ self.confidence_heads(
225
+ s=s,
226
+ z=z,
227
+ x_pred=x_pred,
228
+ d=d,
229
+ feats=feats,
230
+ multiplicity=multiplicity,
231
+ pred_distogram_logits=pred_distogram_logits,
232
+ )
233
+ )
234
+ return out_dict
235
+
236
+
237
+ class ConfidenceHeads(nn.Module):
238
+ def __init__(
239
+ self,
240
+ token_s,
241
+ token_z,
242
+ num_plddt_bins=50,
243
+ num_pde_bins=64,
244
+ num_pae_bins=64,
245
+ token_level_confidence=True,
246
+ use_separate_heads: bool = False,
247
+ **kwargs,
248
+ ):
249
+ super().__init__()
250
+ self.max_num_atoms_per_token = 23
251
+ self.token_level_confidence = token_level_confidence
252
+ self.use_separate_heads = use_separate_heads
253
+
254
+ if self.use_separate_heads:
255
+ self.to_pae_intra_logits = LinearNoBias(token_z, num_pae_bins)
256
+ self.to_pae_inter_logits = LinearNoBias(token_z, num_pae_bins)
257
+ else:
258
+ self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
259
+
260
+ if self.use_separate_heads:
261
+ self.to_pde_intra_logits = LinearNoBias(token_z, num_pde_bins)
262
+ self.to_pde_inter_logits = LinearNoBias(token_z, num_pde_bins)
263
+ else:
264
+ self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)
265
+
266
+ if self.token_level_confidence:
267
+ self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
268
+ self.to_resolved_logits = LinearNoBias(token_s, 2)
269
+ else:
270
+ self.to_plddt_logits = LinearNoBias(
271
+ token_s, num_plddt_bins * self.max_num_atoms_per_token
272
+ )
273
+ self.to_resolved_logits = LinearNoBias(
274
+ token_s, 2 * self.max_num_atoms_per_token
275
+ )
276
+
277
+ def forward(
278
+ self,
279
+ s, # Float['b n ts']
280
+ z, # Float['b n n tz']
281
+ x_pred, # Float['bm m 3']
282
+ d,
283
+ feats,
284
+ pred_distogram_logits,
285
+ multiplicity=1,
286
+ ):
287
+ if self.use_separate_heads:
288
+ asym_id_token = feats["asym_id"]
289
+ is_same_chain = asym_id_token.unsqueeze(-1) == asym_id_token.unsqueeze(-2)
290
+ is_different_chain = ~is_same_chain
291
+
292
+ if self.use_separate_heads:
293
+ pae_intra_logits = self.to_pae_intra_logits(z)
294
+ pae_intra_logits = pae_intra_logits * is_same_chain.float().unsqueeze(-1)
295
+
296
+ pae_inter_logits = self.to_pae_inter_logits(z)
297
+ pae_inter_logits = pae_inter_logits * is_different_chain.float().unsqueeze(
298
+ -1
299
+ )
300
+
301
+ pae_logits = pae_inter_logits + pae_intra_logits
302
+ else:
303
+ pae_logits = self.to_pae_logits(z)
304
+
305
+ if self.use_separate_heads:
306
+ pde_intra_logits = self.to_pde_intra_logits(z + z.transpose(1, 2))
307
+ pde_intra_logits = pde_intra_logits * is_same_chain.float().unsqueeze(-1)
308
+
309
+ pde_inter_logits = self.to_pde_inter_logits(z + z.transpose(1, 2))
310
+ pde_inter_logits = pde_inter_logits * is_different_chain.float().unsqueeze(
311
+ -1
312
+ )
313
+
314
+ pde_logits = pde_inter_logits + pde_intra_logits
315
+ else:
316
+ pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
317
+ resolved_logits = self.to_resolved_logits(s)
318
+ plddt_logits = self.to_plddt_logits(s)
319
+
320
+ ligand_weight = 20
321
+ non_interface_weight = 1
322
+ interface_weight = 10
323
+
324
+ token_type = feats["mol_type"]
325
+ token_type = token_type.repeat_interleave(multiplicity, 0)
326
+ is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
327
+
328
+ if self.token_level_confidence:
329
+ plddt = compute_aggregated_metric(plddt_logits)
330
+ token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
331
+ complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
332
+ dim=-1
333
+ )
334
+
335
+ is_contact = (d < 8).float()
336
+ is_different_chain = (
337
+ feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
338
+ ).float()
339
+ is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
340
+ token_interface_mask = torch.max(
341
+ is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
342
+ dim=-1,
343
+ ).values
344
+ token_non_interface_mask = (1 - token_interface_mask) * (
345
+ 1 - is_ligand_token
346
+ )
347
+ iplddt_weight = (
348
+ is_ligand_token * ligand_weight
349
+ + token_interface_mask * interface_weight
350
+ + token_non_interface_mask * non_interface_weight
351
+ )
352
+ complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
353
+ dim=-1
354
+ ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)
355
+
356
+ else:
357
+ # token to atom conversion for resolved logits
358
+ B, N, _ = resolved_logits.shape
359
+ resolved_logits = resolved_logits.reshape(
360
+ B, N, self.max_num_atoms_per_token, 2
361
+ )
362
+
363
+ arange_max_num_atoms = (
364
+ torch.arange(self.max_num_atoms_per_token)
365
+ .reshape(1, 1, -1)
366
+ .to(resolved_logits.device)
367
+ )
368
+ max_num_atoms_mask = (
369
+ feats["atom_to_token"].sum(1).unsqueeze(-1) > arange_max_num_atoms
370
+ )
371
+ resolved_logits = resolved_logits[:, max_num_atoms_mask.squeeze(0)]
372
+ resolved_logits = pad(
373
+ resolved_logits,
374
+ (
375
+ 0,
376
+ 0,
377
+ 0,
378
+ int(
379
+ feats["atom_pad_mask"].shape[1]
380
+ - feats["atom_pad_mask"].sum().item()
381
+ ),
382
+ ),
383
+ value=0,
384
+ )
385
+ plddt_logits = plddt_logits.reshape(B, N, self.max_num_atoms_per_token, -1)
386
+ plddt_logits = plddt_logits[:, max_num_atoms_mask.squeeze(0)]
387
+ plddt_logits = pad(
388
+ plddt_logits,
389
+ (
390
+ 0,
391
+ 0,
392
+ 0,
393
+ int(
394
+ feats["atom_pad_mask"].shape[1]
395
+ - feats["atom_pad_mask"].sum().item()
396
+ ),
397
+ ),
398
+ value=0,
399
+ )
400
+ atom_pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)
401
+ plddt = compute_aggregated_metric(plddt_logits)
402
+
403
+ complex_plddt = (plddt * atom_pad_mask).sum(dim=-1) / atom_pad_mask.sum(
404
+ dim=-1
405
+ )
406
+ token_type = feats["mol_type"].float()
407
+ atom_to_token = feats["atom_to_token"].float()
408
+ chain_id_token = feats["asym_id"].float()
409
+ atom_type = torch.bmm(atom_to_token, token_type.unsqueeze(-1)).squeeze(-1)
410
+ is_ligand_atom = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
411
+ d_atom = torch.cdist(x_pred, x_pred)
412
+ is_contact = (d_atom < 8).float()
413
+ chain_id_atom = torch.bmm(
414
+ atom_to_token, chain_id_token.unsqueeze(-1)
415
+ ).squeeze(-1)
416
+ is_different_chain = (
417
+ chain_id_atom.unsqueeze(-1) != chain_id_atom.unsqueeze(-2)
418
+ ).float()
419
+
420
+ atom_interface_mask = torch.max(
421
+ is_contact * is_different_chain * (1 - is_ligand_atom).unsqueeze(-1),
422
+ dim=-1,
423
+ ).values
424
+ atom_non_interface_mask = (1 - atom_interface_mask) * (1 - is_ligand_atom)
425
+ iplddt_weight = (
426
+ is_ligand_atom * ligand_weight
427
+ + atom_interface_mask * interface_weight
428
+ + atom_non_interface_mask * non_interface_weight
429
+ )
430
+
431
+ complex_iplddt = (plddt * feats["atom_pad_mask"] * iplddt_weight).sum(
432
+ dim=-1
433
+ ) / torch.sum(feats["atom_pad_mask"] * iplddt_weight, dim=-1)
434
+
435
+ # Compute the gPDE and giPDE
436
+ pde = compute_aggregated_metric(pde_logits, end=32)
437
+ pred_distogram_prob = nn.functional.softmax(
438
+ pred_distogram_logits, dim=-1
439
+ ).repeat_interleave(multiplicity, 0)
440
+ contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
441
+ pred_distogram_prob.device
442
+ )
443
+ contacts[:, :, :, :20] = 1.0
444
+ prob_contact = (pred_distogram_prob * contacts).sum(-1)
445
+ token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
446
+ token_pad_pair_mask = (
447
+ token_pad_mask.unsqueeze(-1)
448
+ * token_pad_mask.unsqueeze(-2)
449
+ * (
450
+ 1
451
+ - torch.eye(
452
+ token_pad_mask.shape[1], device=token_pad_mask.device
453
+ ).unsqueeze(0)
454
+ )
455
+ )
456
+ token_pair_mask = token_pad_pair_mask * prob_contact
457
+ complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
458
+ dim=(1, 2)
459
+ )
460
+ asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
461
+ token_interface_pair_mask = token_pair_mask * (
462
+ asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
463
+ )
464
+ complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
465
+ token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
466
+ )
467
+ out_dict = dict(
468
+ pde_logits=pde_logits,
469
+ plddt_logits=plddt_logits,
470
+ resolved_logits=resolved_logits,
471
+ pde=pde,
472
+ plddt=plddt,
473
+ complex_plddt=complex_plddt,
474
+ complex_iplddt=complex_iplddt,
475
+ complex_pde=complex_pde,
476
+ complex_ipde=complex_ipde,
477
+ )
478
+ out_dict["pae_logits"] = pae_logits
479
+ out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)
480
+
481
+ try:
482
+ ptm, iptm, ligand_iptm, protein_iptm, pair_chains_iptm = compute_ptms(
483
+ pae_logits, x_pred, feats, multiplicity
484
+ )
485
+ out_dict["ptm"] = ptm
486
+ out_dict["iptm"] = iptm
487
+ out_dict["ligand_iptm"] = ligand_iptm
488
+ out_dict["protein_iptm"] = protein_iptm
489
+ out_dict["pair_chains_iptm"] = pair_chains_iptm
490
+ except Exception as e:
491
+ print(f"Error in compute_ptms: {e}")
492
+ out_dict["ptm"] = torch.zeros_like(complex_plddt)
493
+ out_dict["iptm"] = torch.zeros_like(complex_plddt)
494
+ out_dict["ligand_iptm"] = torch.zeros_like(complex_plddt)
495
+ out_dict["protein_iptm"] = torch.zeros_like(complex_plddt)
496
+ out_dict["pair_chains_iptm"] = torch.zeros_like(complex_plddt)
497
+
498
+ return out_dict