lhallee committed on
Commit
fc48fbc
·
verified ·
1 Parent(s): 8f4c6ed

Upload vb_modules_encodersv2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vb_modules_encodersv2.py +565 -565
vb_modules_encodersv2.py CHANGED
@@ -1,565 +1,565 @@
1
- # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2
- from functools import partial
3
- from math import pi
4
-
5
- import torch
6
- from einops import rearrange
7
- from torch import nn
8
- from torch.nn import Linear, Module, ModuleList
9
- from torch.nn.functional import one_hot
10
-
11
- from . import vb_layers_initialize as init
12
- from .vb_layers_transition import Transition
13
- from .vb_modules_transformersv2 import AtomTransformer
14
- from .vb_modules_utils import LinearNoBias
15
-
16
-
17
- class FourierEmbedding(Module):
18
- """Algorithm 22."""
19
-
20
- def __init__(self, dim):
21
- super().__init__()
22
- self.proj = nn.Linear(1, dim)
23
- torch.nn.init.normal_(self.proj.weight, mean=0, std=1)
24
- torch.nn.init.normal_(self.proj.bias, mean=0, std=1)
25
- self.proj.requires_grad_(False)
26
-
27
- def forward(
28
- self,
29
- times, # Float[' b'],
30
- ): # -> Float['b d']:
31
- times = rearrange(times, "b -> b 1")
32
- rand_proj = self.proj(times)
33
- return torch.cos(2 * pi * rand_proj)
34
-
35
-
36
- class RelativePositionEncoder(Module):
37
- """Algorithm 3."""
38
-
39
- def __init__(
40
- self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
41
- ):
42
- super().__init__()
43
- self.r_max = r_max
44
- self.s_max = s_max
45
- self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
46
- self.fix_sym_check = fix_sym_check
47
- self.cyclic_pos_enc = cyclic_pos_enc
48
-
49
- def forward(self, feats):
50
- b_same_chain = torch.eq(
51
- feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
52
- )
53
- b_same_residue = torch.eq(
54
- feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
55
- )
56
- b_same_entity = torch.eq(
57
- feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
58
- )
59
-
60
- d_residue = (
61
- feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
62
- )
63
-
64
- if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
65
- period = torch.where(
66
- feats["cyclic_period"] > 0,
67
- feats["cyclic_period"],
68
- torch.zeros_like(feats["cyclic_period"]) + 10000,
69
- )
70
- d_residue = (d_residue - period * torch.round(d_residue / period)).long()
71
-
72
- d_residue = torch.clip(
73
- d_residue + self.r_max,
74
- 0,
75
- 2 * self.r_max,
76
- )
77
- d_residue = torch.where(
78
- b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
79
- )
80
- a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)
81
-
82
- d_token = torch.clip(
83
- feats["token_index"][:, :, None]
84
- - feats["token_index"][:, None, :]
85
- + self.r_max,
86
- 0,
87
- 2 * self.r_max,
88
- )
89
- d_token = torch.where(
90
- b_same_chain & b_same_residue,
91
- d_token,
92
- torch.zeros_like(d_token) + 2 * self.r_max + 1,
93
- )
94
- a_rel_token = one_hot(d_token, 2 * self.r_max + 2)
95
-
96
- d_chain = torch.clip(
97
- feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
98
- 0,
99
- 2 * self.s_max,
100
- )
101
- d_chain = torch.where(
102
- (~b_same_entity) if self.fix_sym_check else b_same_chain,
103
- torch.zeros_like(d_chain) + 2 * self.s_max + 1,
104
- d_chain,
105
- )
106
- # Note: added | (~b_same_entity) based on observation of ProteinX manuscript
107
- a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)
108
-
109
- p = self.linear_layer(
110
- torch.cat(
111
- [
112
- a_rel_pos.float(),
113
- a_rel_token.float(),
114
- b_same_entity.unsqueeze(-1).float(),
115
- a_rel_chain.float(),
116
- ],
117
- dim=-1,
118
- )
119
- )
120
- return p
121
-
122
-
123
- class SingleConditioning(Module):
124
- """Algorithm 21."""
125
-
126
- def __init__(
127
- self,
128
- sigma_data: float,
129
- token_s: int = 384,
130
- dim_fourier: int = 256,
131
- num_transitions: int = 2,
132
- transition_expansion_factor: int = 2,
133
- eps: float = 1e-20,
134
- disable_times: bool = False,
135
- ) -> None:
136
- super().__init__()
137
- self.eps = eps
138
- self.sigma_data = sigma_data
139
- self.disable_times = disable_times
140
-
141
- self.norm_single = nn.LayerNorm(2 * token_s)
142
- self.single_embed = nn.Linear(2 * token_s, 2 * token_s)
143
- if not self.disable_times:
144
- self.fourier_embed = FourierEmbedding(dim_fourier)
145
- self.norm_fourier = nn.LayerNorm(dim_fourier)
146
- self.fourier_to_single = LinearNoBias(dim_fourier, 2 * token_s)
147
-
148
- transitions = ModuleList([])
149
- for _ in range(num_transitions):
150
- transition = Transition(
151
- dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s
152
- )
153
- transitions.append(transition)
154
-
155
- self.transitions = transitions
156
-
157
- def forward(
158
- self,
159
- times, # Float[' b'],
160
- s_trunk, # Float['b n ts'],
161
- s_inputs, # Float['b n ts'],
162
- ): # -> Float['b n 2ts']:
163
- s = torch.cat((s_trunk, s_inputs), dim=-1)
164
- s = self.single_embed(self.norm_single(s))
165
- if not self.disable_times:
166
- fourier_embed = self.fourier_embed(
167
- times
168
- ) # note: sigma rescaling done in diffusion module
169
- normed_fourier = self.norm_fourier(fourier_embed)
170
- fourier_to_single = self.fourier_to_single(normed_fourier)
171
-
172
- s = rearrange(fourier_to_single, "b d -> b 1 d") + s
173
-
174
- for transition in self.transitions:
175
- s = transition(s) + s
176
-
177
- return s, normed_fourier if not self.disable_times else None
178
-
179
-
180
- class PairwiseConditioning(Module):
181
- """Algorithm 21."""
182
-
183
- def __init__(
184
- self,
185
- token_z,
186
- dim_token_rel_pos_feats,
187
- num_transitions=2,
188
- transition_expansion_factor=2,
189
- ):
190
- super().__init__()
191
-
192
- self.dim_pairwise_init_proj = nn.Sequential(
193
- nn.LayerNorm(token_z + dim_token_rel_pos_feats),
194
- LinearNoBias(token_z + dim_token_rel_pos_feats, token_z),
195
- )
196
-
197
- transitions = ModuleList([])
198
- for _ in range(num_transitions):
199
- transition = Transition(
200
- dim=token_z, hidden=transition_expansion_factor * token_z
201
- )
202
- transitions.append(transition)
203
-
204
- self.transitions = transitions
205
-
206
- def forward(
207
- self,
208
- z_trunk, # Float['b n n tz'],
209
- token_rel_pos_feats, # Float['b n n 3'],
210
- ): # -> Float['b n n tz']:
211
- z = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
212
- z = self.dim_pairwise_init_proj(z)
213
-
214
- for transition in self.transitions:
215
- z = transition(z) + z
216
-
217
- return z
218
-
219
-
220
- def get_indexing_matrix(K, W, H, device):
221
- assert W % 2 == 0
222
- assert H % (W // 2) == 0
223
-
224
- h = H // (W // 2)
225
- assert h % 2 == 0
226
-
227
- arange = torch.arange(2 * K, device=device)
228
- index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(
229
- min=0, max=h + 1
230
- )
231
- index = index.view(K, 2, 2 * K)[:, 0, :]
232
- onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0)
233
- return onehot.reshape(2 * K, h * K).float()
234
-
235
-
236
- def single_to_keys(single, indexing_matrix, W, H):
237
- B, N, D = single.shape
238
- K = N // W
239
- single = single.view(B, 2 * K, W // 2, D)
240
- return torch.einsum("b j i d, j k -> b k i d", single, indexing_matrix).reshape(
241
- B, K, H, D
242
- ) # j = 2K, i = W//2, k = h * K
243
-
244
-
245
- class AtomEncoder(Module):
246
- def __init__(
247
- self,
248
- atom_s,
249
- atom_z,
250
- token_s,
251
- token_z,
252
- atoms_per_window_queries,
253
- atoms_per_window_keys,
254
- atom_feature_dim,
255
- structure_prediction=True,
256
- use_no_atom_char=False,
257
- use_atom_backbone_feat=False,
258
- use_residue_feats_atoms=False,
259
- ):
260
- super().__init__()
261
-
262
- self.embed_atom_features = Linear(atom_feature_dim, atom_s)
263
- self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
264
- self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
265
- self.embed_atompair_mask = LinearNoBias(1, atom_z)
266
- self.atoms_per_window_queries = atoms_per_window_queries
267
- self.atoms_per_window_keys = atoms_per_window_keys
268
- self.use_no_atom_char = use_no_atom_char
269
- self.use_atom_backbone_feat = use_atom_backbone_feat
270
- self.use_residue_feats_atoms = use_residue_feats_atoms
271
-
272
- self.structure_prediction = structure_prediction
273
- if structure_prediction:
274
- self.s_to_c_trans = nn.Sequential(
275
- nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
276
- )
277
- init.final_init_(self.s_to_c_trans[1].weight)
278
-
279
- self.z_to_p_trans = nn.Sequential(
280
- nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
281
- )
282
- init.final_init_(self.z_to_p_trans[1].weight)
283
-
284
- self.c_to_p_trans_k = nn.Sequential(
285
- nn.ReLU(),
286
- LinearNoBias(atom_s, atom_z),
287
- )
288
- init.final_init_(self.c_to_p_trans_k[1].weight)
289
-
290
- self.c_to_p_trans_q = nn.Sequential(
291
- nn.ReLU(),
292
- LinearNoBias(atom_s, atom_z),
293
- )
294
- init.final_init_(self.c_to_p_trans_q[1].weight)
295
-
296
- self.p_mlp = nn.Sequential(
297
- nn.ReLU(),
298
- LinearNoBias(atom_z, atom_z),
299
- nn.ReLU(),
300
- LinearNoBias(atom_z, atom_z),
301
- nn.ReLU(),
302
- LinearNoBias(atom_z, atom_z),
303
- )
304
- init.final_init_(self.p_mlp[5].weight)
305
-
306
- def forward(
307
- self,
308
- feats,
309
- s_trunk=None, # Float['bm n ts'],
310
- z=None, # Float['bm n n tz'],
311
- ):
312
- with torch.autocast("cuda", enabled=False):
313
- B, N, _ = feats["ref_pos"].shape
314
- atom_mask = feats["atom_pad_mask"].bool() # Bool['b m'],
315
-
316
- atom_ref_pos = feats["ref_pos"] # Float['b m 3'],
317
- atom_uid = feats["ref_space_uid"] # Long['b m'],
318
-
319
- atom_feats = [
320
- atom_ref_pos,
321
- feats["ref_charge"].unsqueeze(-1),
322
- feats["ref_element"],
323
- ]
324
- if not self.use_no_atom_char:
325
- atom_feats.append(feats["ref_atom_name_chars"].reshape(B, N, 4 * 64))
326
- if self.use_atom_backbone_feat:
327
- atom_feats.append(feats["atom_backbone_feat"])
328
- if self.use_residue_feats_atoms:
329
- res_feats = torch.cat(
330
- [
331
- feats["res_type"],
332
- feats["modified"].unsqueeze(-1),
333
- one_hot(feats["mol_type"], num_classes=4).float(),
334
- ],
335
- dim=-1,
336
- )
337
- atom_to_token = feats["atom_to_token"].float()
338
- atom_res_feats = torch.bmm(atom_to_token, res_feats)
339
- atom_feats.append(atom_res_feats)
340
-
341
- atom_feats = torch.cat(atom_feats, dim=-1)
342
-
343
- c = self.embed_atom_features(atom_feats)
344
-
345
- # note we are already creating the windows to make it more efficient
346
- W, H = self.atoms_per_window_queries, self.atoms_per_window_keys
347
- B, N = c.shape[:2]
348
- K = N // W
349
- keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device)
350
- to_keys = partial(
351
- single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H
352
- )
353
-
354
- atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3)
355
- atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3)
356
-
357
- d = atom_ref_pos_keys - atom_ref_pos_queries # Float['b k w h 3']
358
- d_norm = torch.sum(d * d, dim=-1, keepdim=True) # Float['b k w h 1']
359
- d_norm = 1 / (
360
- 1 + d_norm
361
- ) # AF3 feeds in the reciprocal of the distance norm
362
-
363
- atom_mask_queries = atom_mask.view(B, K, W, 1)
364
- atom_mask_keys = (
365
- to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
366
- )
367
- atom_uid_queries = atom_uid.view(B, K, W, 1)
368
- atom_uid_keys = (
369
- to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
370
- )
371
- v = (
372
- (
373
- atom_mask_queries
374
- & atom_mask_keys
375
- & (atom_uid_queries == atom_uid_keys)
376
- )
377
- .float()
378
- .unsqueeze(-1)
379
- ) # Bool['b k w h 1']
380
-
381
- p = self.embed_atompair_ref_pos(d) * v
382
- p = p + self.embed_atompair_ref_dist(d_norm) * v
383
- p = p + self.embed_atompair_mask(v) * v
384
-
385
- q = c
386
-
387
- if self.structure_prediction:
388
- # run only in structure model not in initial encoding
389
- atom_to_token = feats["atom_to_token"].float() # Long['b m n'],
390
-
391
- s_to_c = self.s_to_c_trans(s_trunk.float())
392
- s_to_c = torch.bmm(atom_to_token, s_to_c)
393
- c = c + s_to_c.to(c)
394
-
395
- atom_to_token_queries = atom_to_token.view(
396
- B, K, W, atom_to_token.shape[-1]
397
- )
398
- atom_to_token_keys = to_keys(atom_to_token)
399
- z_to_p = self.z_to_p_trans(z.float())
400
- z_to_p = torch.einsum(
401
- "bijd,bwki,bwlj->bwkld",
402
- z_to_p,
403
- atom_to_token_queries,
404
- atom_to_token_keys,
405
- )
406
- p = p + z_to_p.to(p)
407
-
408
- p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1]))
409
- p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1]))
410
- p = p + self.p_mlp(p)
411
- return q, c, p, to_keys
412
-
413
-
414
- class AtomAttentionEncoder(Module):
415
- def __init__(
416
- self,
417
- atom_s,
418
- token_s,
419
- atoms_per_window_queries,
420
- atoms_per_window_keys,
421
- atom_encoder_depth=3,
422
- atom_encoder_heads=4,
423
- structure_prediction=True,
424
- activation_checkpointing=False,
425
- transformer_post_layer_norm=False,
426
- ):
427
- super().__init__()
428
-
429
- self.structure_prediction = structure_prediction
430
- if structure_prediction:
431
- self.r_to_q_trans = LinearNoBias(3, atom_s)
432
- init.final_init_(self.r_to_q_trans.weight)
433
-
434
- self.atom_encoder = AtomTransformer(
435
- dim=atom_s,
436
- dim_single_cond=atom_s,
437
- attn_window_queries=atoms_per_window_queries,
438
- attn_window_keys=atoms_per_window_keys,
439
- depth=atom_encoder_depth,
440
- heads=atom_encoder_heads,
441
- activation_checkpointing=activation_checkpointing,
442
- post_layer_norm=transformer_post_layer_norm,
443
- )
444
-
445
- self.atom_to_token_trans = nn.Sequential(
446
- LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
447
- nn.ReLU(),
448
- )
449
-
450
- def forward(
451
- self,
452
- feats,
453
- q,
454
- c,
455
- atom_enc_bias,
456
- to_keys,
457
- r=None, # Float['bm m 3'],
458
- multiplicity=1,
459
- ):
460
- B, N, _ = feats["ref_pos"].shape
461
- atom_mask = feats["atom_pad_mask"].bool() # Bool['b m'],
462
-
463
- if self.structure_prediction:
464
- # only here the multiplicity kicks in because we use the different positions r
465
- q = q.repeat_interleave(multiplicity, 0)
466
- r_to_q = self.r_to_q_trans(r)
467
- q = q + r_to_q
468
-
469
- c = c.repeat_interleave(multiplicity, 0)
470
- atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
471
-
472
- q = self.atom_encoder(
473
- q=q,
474
- mask=atom_mask,
475
- c=c,
476
- bias=atom_enc_bias,
477
- multiplicity=multiplicity,
478
- to_keys=to_keys,
479
- )
480
-
481
- with torch.autocast("cuda", enabled=False):
482
- q_to_a = self.atom_to_token_trans(q).float()
483
- atom_to_token = feats["atom_to_token"].float()
484
- atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
485
- atom_to_token_mean = atom_to_token / (
486
- atom_to_token.sum(dim=1, keepdim=True) + 1e-6
487
- )
488
- a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)
489
-
490
- a = a.to(q)
491
-
492
- return a, q, c, to_keys
493
-
494
-
495
- class AtomAttentionDecoder(Module):
496
- """Algorithm 6."""
497
-
498
- def __init__(
499
- self,
500
- atom_s,
501
- token_s,
502
- attn_window_queries,
503
- attn_window_keys,
504
- atom_decoder_depth=3,
505
- atom_decoder_heads=4,
506
- activation_checkpointing=False,
507
- transformer_post_layer_norm=False,
508
- ):
509
- super().__init__()
510
-
511
- self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
512
- init.final_init_(self.a_to_q_trans.weight)
513
-
514
- self.atom_decoder = AtomTransformer(
515
- dim=atom_s,
516
- dim_single_cond=atom_s,
517
- attn_window_queries=attn_window_queries,
518
- attn_window_keys=attn_window_keys,
519
- depth=atom_decoder_depth,
520
- heads=atom_decoder_heads,
521
- activation_checkpointing=activation_checkpointing,
522
- post_layer_norm=transformer_post_layer_norm,
523
- )
524
-
525
- if transformer_post_layer_norm:
526
- self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3)
527
- init.final_init_(self.atom_feat_to_atom_pos_update.weight)
528
- else:
529
- self.atom_feat_to_atom_pos_update = nn.Sequential(
530
- nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
531
- )
532
- init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
533
-
534
- def forward(
535
- self,
536
- a, # Float['bm n 2ts'],
537
- q, # Float['bm m as'],
538
- c, # Float['bm m as'],
539
- atom_dec_bias, # Float['bm m m az'],
540
- feats,
541
- to_keys,
542
- multiplicity=1,
543
- ):
544
- with torch.autocast("cuda", enabled=False):
545
- atom_to_token = feats["atom_to_token"].float()
546
- atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
547
-
548
- a_to_q = self.a_to_q_trans(a.float())
549
- a_to_q = torch.bmm(atom_to_token, a_to_q)
550
-
551
- q = q + a_to_q.to(q)
552
- atom_mask = feats["atom_pad_mask"] # Bool['b m'],
553
- atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
554
-
555
- q = self.atom_decoder(
556
- q=q,
557
- mask=atom_mask,
558
- c=c,
559
- bias=atom_dec_bias,
560
- multiplicity=multiplicity,
561
- to_keys=to_keys,
562
- )
563
-
564
- r_update = self.atom_feat_to_atom_pos_update(q)
565
- return r_update
 
1
+ # started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
2
+ from functools import partial
3
+ from math import pi
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+ from torch.nn import Linear, Module, ModuleList
9
+ from torch.nn.functional import one_hot
10
+
11
+ from . import vb_layers_initialize as init
12
+ from .vb_layers_transition import Transition
13
+ from .vb_modules_transformersv2 import AtomTransformer
14
+ from .vb_modules_utils import LinearNoBias
15
+
16
+
17
class FourierEmbedding(Module):
    """Fixed random Fourier features of one scalar per batch element (Algorithm 22).

    A ``Linear(1, dim)`` layer whose weight and bias are re-drawn from a
    standard normal and then frozen maps each scalar to ``dim`` phases; the
    embedding is ``cos(2 * pi * phase)``.
    """

    def __init__(self, dim):
        super().__init__()
        projection = nn.Linear(1, dim)
        # Replace the default Linear initialisation with N(0, 1) draws
        # (weight first, then bias) and exclude both from training.
        for tensor in (projection.weight, projection.bias):
            torch.nn.init.normal_(tensor, mean=0, std=1)
        projection.requires_grad_(False)
        self.proj = projection

    def forward(
        self,
        times,  # Float[' b'],
    ):  # -> Float['b d']:
        """Embed a 1-D batch of scalars into ``dim`` cosine features each."""
        column = rearrange(times, "b -> b 1")
        phases = self.proj(column)
        return torch.cos(2 * pi * phases)
34
+
35
+
36
class RelativePositionEncoder(Module):
    """Relative position features for every token pair (Algorithm 3).

    For each pair ``(i, j)`` the forward pass builds, concatenates and linearly
    projects four features:

    * one-hot of the clipped residue-index offset (extra bucket for pairs in
      different chains),
    * one-hot of the clipped token-index offset (extra bucket for pairs that
      are not in the same chain and residue),
    * a same-entity flag,
    * one-hot of the clipped ``sym_id`` offset (extra bucket selected by the
      ``fix_sym_check`` rule, see ``forward``).

    Args:
        token_z: output width of the projection.
        r_max: clipping radius for residue/token offsets.
        s_max: clipping radius for ``sym_id`` offsets.
        fix_sym_check: if true, the extra chain bucket is used for pairs of
            different entities; otherwise for pairs in the same chain.
        cyclic_pos_enc: if true, residue offsets are wrapped by
            ``feats["cyclic_period"]`` wherever that period is positive.
    """

    def __init__(
        self, token_z, r_max=32, s_max=2, fix_sym_check=False, cyclic_pos_enc=False
    ):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # Input width = two one-hots of size 2*r_max+2, one of size 2*s_max+2,
        # plus the single same-entity flag.
        self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
        self.fix_sym_check = fix_sym_check
        self.cyclic_pos_enc = cyclic_pos_enc

    def forward(self, feats):
        """Return the projected pairwise relative-position features.

        ``feats`` must provide ``asym_id``, ``residue_index``, ``entity_id``,
        ``token_index`` and ``sym_id`` (each indexed as ``[batch, token]``
        below), plus ``cyclic_period`` when ``cyclic_pos_enc`` is enabled.
        """
        b_same_chain = torch.eq(
            feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
        )
        b_same_residue = torch.eq(
            feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
        )
        b_same_entity = torch.eq(
            feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
        )

        d_residue = (
            feats["residue_index"][:, :, None] - feats["residue_index"][:, None, :]
        )

        if self.cyclic_pos_enc and torch.any(feats["cyclic_period"] > 0):
            # Non-cyclic entries get a huge period so wrapping leaves them as is.
            period = torch.where(
                feats["cyclic_period"] > 0,
                feats["cyclic_period"],
                torch.zeros_like(feats["cyclic_period"]) + 10000,
            )
            # A per-token period has one dim fewer than the pairwise offsets.
            # Align it to the second token axis explicitly: implicit
            # right-aligned broadcasting gives the same result only for batch
            # size 1 and otherwise fails (or mis-aligns batch with tokens).
            if period.dim() == d_residue.dim() - 1:
                period = period.unsqueeze(1)
            d_residue = (d_residue - period * torch.round(d_residue / period)).long()

        # Shift offsets to [0, 2*r_max]; bucket 2*r_max+1 marks different chains.
        d_residue = torch.clip(
            d_residue + self.r_max,
            0,
            2 * self.r_max,
        )
        d_residue = torch.where(
            b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
        )
        a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)

        # Token offsets only count within the same chain and residue.
        d_token = torch.clip(
            feats["token_index"][:, :, None]
            - feats["token_index"][:, None, :]
            + self.r_max,
            0,
            2 * self.r_max,
        )
        d_token = torch.where(
            b_same_chain & b_same_residue,
            d_token,
            torch.zeros_like(d_token) + 2 * self.r_max + 1,
        )
        a_rel_token = one_hot(d_token, 2 * self.r_max + 2)

        d_chain = torch.clip(
            feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
            0,
            2 * self.s_max,
        )
        # Pairs matching the condition get the extra bucket 2*s_max+1: with
        # fix_sym_check that is "different entity", otherwise "same chain".
        d_chain = torch.where(
            (~b_same_entity) if self.fix_sym_check else b_same_chain,
            torch.zeros_like(d_chain) + 2 * self.s_max + 1,
            d_chain,
        )
        a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)

        p = self.linear_layer(
            torch.cat(
                [
                    a_rel_pos.float(),
                    a_rel_token.float(),
                    b_same_entity.unsqueeze(-1).float(),
                    a_rel_chain.float(),
                ],
                dim=-1,
            )
        )
        return p
121
+
122
+
123
class SingleConditioning(Module):
    """Build the per-token (single) conditioning signal (Algorithm 21).

    Trunk and input single representations are concatenated, layer-normed and
    linearly embedded.  Unless ``disable_times`` is set, a Fourier embedding of
    ``times`` is layer-normed, projected and added to every token.  The result
    then runs through ``num_transitions`` residual ``Transition`` blocks.
    """

    def __init__(
        self,
        sigma_data: float,
        token_s: int = 384,
        dim_fourier: int = 256,
        num_transitions: int = 2,
        transition_expansion_factor: int = 2,
        eps: float = 1e-20,
        disable_times: bool = False,
    ) -> None:
        super().__init__()
        # Everything operates on the concatenation of two token_s-wide inputs.
        width = 2 * token_s

        # Stored for reference; not read anywhere in this class.
        self.eps = eps
        self.sigma_data = sigma_data
        self.disable_times = disable_times

        self.norm_single = nn.LayerNorm(width)
        self.single_embed = nn.Linear(width, width)
        if not disable_times:
            self.fourier_embed = FourierEmbedding(dim_fourier)
            self.norm_fourier = nn.LayerNorm(dim_fourier)
            self.fourier_to_single = LinearNoBias(dim_fourier, width)

        self.transitions = ModuleList(
            [
                Transition(dim=width, hidden=transition_expansion_factor * width)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        times,  # Float[' b'],
        s_trunk,  # Float['b n ts'],
        s_inputs,  # Float['b n ts'],
    ):  # -> Float['b n 2ts']:
        """Return ``(s, normed_fourier)``; the second item is ``None`` when
        time conditioning is disabled."""
        joined = torch.cat((s_trunk, s_inputs), dim=-1)
        s = self.single_embed(self.norm_single(joined))

        normed_fourier = None
        if not self.disable_times:
            # note: sigma rescaling done in diffusion module
            normed_fourier = self.norm_fourier(self.fourier_embed(times))
            time_bias = self.fourier_to_single(normed_fourier)
            # Broadcast the per-sample time term over the token axis.
            s = rearrange(time_bias, "b d -> b 1 d") + s

        for block in self.transitions:
            s = block(s) + s

        return s, normed_fourier
178
+
179
+
180
class PairwiseConditioning(Module):
    """Build the pairwise conditioning signal (Algorithm 21).

    The trunk pair representation is concatenated with the relative-position
    features, layer-normed and projected back to ``token_z`` channels, then
    refined by ``num_transitions`` residual ``Transition`` blocks.
    """

    def __init__(
        self,
        token_z,
        dim_token_rel_pos_feats,
        num_transitions=2,
        transition_expansion_factor=2,
    ):
        super().__init__()

        joint_dim = token_z + dim_token_rel_pos_feats
        self.dim_pairwise_init_proj = nn.Sequential(
            nn.LayerNorm(joint_dim),
            LinearNoBias(joint_dim, token_z),
        )

        self.transitions = ModuleList(
            [
                Transition(dim=token_z, hidden=transition_expansion_factor * token_z)
                for _ in range(num_transitions)
            ]
        )

    def forward(
        self,
        z_trunk,  # Float['b n n tz'],
        token_rel_pos_feats,  # Float['b n n 3'],
    ):  # -> Float['b n n tz']:
        """Return the conditioned pair representation."""
        joined = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
        z = self.dim_pairwise_init_proj(joined)

        for block in self.transitions:
            z = block(z) + z

        return z
218
+
219
+
220
def get_indexing_matrix(K, W, H, device):
    """Build the 0/1 matrix that gathers key windows from half-window blocks.

    The sequence is treated as ``2 * K`` blocks of ``W // 2`` items.  Each of
    the ``K`` windows collects ``h = H // (W // 2)`` consecutive blocks.  The
    returned float matrix has shape ``(2 * K, h * K)``; column ``k * h + c``
    has a single 1 in row ``2 * k + c + 1 - h // 2`` when that row exists and
    is all zeros otherwise (windows reaching past either end).

    Requires ``W`` even, ``H`` divisible by ``W // 2`` and ``h`` even.
    """
    assert W % 2 == 0
    half_window = W // 2
    assert H % half_window == 0

    h = H // half_window
    assert h % 2 == 0

    blocks = torch.arange(2 * K, device=device)
    # offsets[r, j] = j - r + h // 2, clamped so that out-of-window blocks
    # land in the two sentinel classes 0 and h + 1.
    offsets = blocks[None, :] - blocks[:, None] + h // 2
    offsets = torch.clamp(offsets, min=0, max=h + 1)
    # One row per window: keep the even block rows only.
    offsets = offsets[::2]
    # Drop the sentinel classes, then put the block axis first.
    selector = one_hot(offsets, num_classes=h + 2)[..., 1:-1]
    selector = selector.transpose(1, 0)
    return selector.reshape(2 * K, h * K).float()
234
+
235
+
236
def single_to_keys(single, indexing_matrix, W, H):
    """Gather per-window key features from a flat ``(B, N, D)`` tensor.

    ``single`` is viewed as ``2 * K`` blocks of ``W // 2`` items (``K = N // W``)
    and the block axis is contracted with ``indexing_matrix``, giving ``H``
    key rows per window.  Returns a tensor of shape ``(B, K, H, D)``.
    """
    B, N, D = single.shape
    K = N // W
    blocks = single.view(B, 2 * K, W // 2, D)
    # j = 2K, i = W//2, k = h * K
    gathered = torch.einsum("b j i d, j k -> b k i d", blocks, indexing_matrix)
    return gathered.reshape(B, K, H, D)
243
+
244
+
245
class AtomEncoder(Module):
    """Embed per-atom features and windowed atom-pair features.

    ``forward`` returns the atom single embedding before and after trunk
    conditioning (``q`` and ``c``), the windowed pair embedding ``p`` and the
    ``to_keys`` gather function used to build key windows.
    """

    def __init__(
        self,
        atom_s,
        atom_z,
        token_s,
        token_z,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_feature_dim,
        structure_prediction=True,
        use_no_atom_char=False,
        use_atom_backbone_feat=False,
        use_residue_feats_atoms=False,
    ):
        super().__init__()

        # Input embeddings for atoms and atom pairs.
        self.embed_atom_features = Linear(atom_feature_dim, atom_s)
        self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
        self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
        self.embed_atompair_mask = LinearNoBias(1, atom_z)

        self.atoms_per_window_queries = atoms_per_window_queries
        self.atoms_per_window_keys = atoms_per_window_keys
        self.use_no_atom_char = use_no_atom_char
        self.use_atom_backbone_feat = use_atom_backbone_feat
        self.use_residue_feats_atoms = use_residue_feats_atoms
        self.structure_prediction = structure_prediction

        if structure_prediction:
            # Projections that inject token-level trunk features into atoms.
            self.s_to_c_trans = nn.Sequential(
                nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
            )
            init.final_init_(self.s_to_c_trans[1].weight)

            self.z_to_p_trans = nn.Sequential(
                nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
            )
            init.final_init_(self.z_to_p_trans[1].weight)

        # Single-to-pair projections for the key and query side.
        self.c_to_p_trans_k = nn.Sequential(nn.ReLU(), LinearNoBias(atom_s, atom_z))
        init.final_init_(self.c_to_p_trans_k[1].weight)

        self.c_to_p_trans_q = nn.Sequential(nn.ReLU(), LinearNoBias(atom_s, atom_z))
        init.final_init_(self.c_to_p_trans_q[1].weight)

        # Three ReLU -> linear stages; only the last linear gets final_init_.
        self.p_mlp = nn.Sequential(
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
            nn.ReLU(),
            LinearNoBias(atom_z, atom_z),
        )
        init.final_init_(self.p_mlp[5].weight)

    def forward(
        self,
        feats,
        s_trunk=None,  # Float['bm n ts'],
        z=None,  # Float['bm n n tz'],
    ):
        """Return ``(q, c, p, to_keys)`` for the atoms described by ``feats``.

        ``s_trunk`` and ``z`` are only read when ``structure_prediction`` is
        enabled.
        """
        # The whole embedding runs with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            B, N, _ = feats["ref_pos"].shape
            pad_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],
            ref_pos = feats["ref_pos"]  # Float['b m 3'],
            space_uid = feats["ref_space_uid"]  # Long['b m'],

            # Assemble the per-atom input features selected by the flags.
            per_atom = [
                ref_pos,
                feats["ref_charge"].unsqueeze(-1),
                feats["ref_element"],
            ]
            if not self.use_no_atom_char:
                per_atom.append(feats["ref_atom_name_chars"].reshape(B, N, 4 * 64))
            if self.use_atom_backbone_feat:
                per_atom.append(feats["atom_backbone_feat"])
            if self.use_residue_feats_atoms:
                token_feats = torch.cat(
                    [
                        feats["res_type"],
                        feats["modified"].unsqueeze(-1),
                        one_hot(feats["mol_type"], num_classes=4).float(),
                    ],
                    dim=-1,
                )
                # Copy each token's features onto its atoms.
                per_atom.append(torch.bmm(feats["atom_to_token"].float(), token_feats))

            c = self.embed_atom_features(torch.cat(per_atom, dim=-1))

            # note we are already creating the windows to make it more efficient
            W = self.atoms_per_window_queries
            H = self.atoms_per_window_keys
            B, N = c.shape[:2]
            K = N // W
            to_keys = partial(
                single_to_keys,
                indexing_matrix=get_indexing_matrix(K, W, H, c.device),
                W=W,
                H=H,
            )

            # Pairwise displacement between every query and key in a window.
            pos_queries = ref_pos.view(B, K, W, 1, 3)
            pos_keys = to_keys(ref_pos).view(B, K, 1, H, 3)
            d = pos_keys - pos_queries  # Float['b k w h 3']
            # Squared distance turned into 1 / (1 + d^2).
            inv_sq_dist = 1 / (
                1 + torch.sum(d * d, dim=-1, keepdim=True)
            )  # Float['b k w h 1']

            mask_queries = pad_mask.view(B, K, W, 1)
            mask_keys = to_keys(pad_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
            uid_queries = space_uid.view(B, K, W, 1)
            uid_keys = to_keys(space_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
            same_uid = uid_queries == uid_keys
            # A pair is valid when both atoms are real and share a ref_space_uid.
            v = (mask_queries & mask_keys & same_uid).float().unsqueeze(-1)

            p = (
                self.embed_atompair_ref_pos(d) * v
                + self.embed_atompair_ref_dist(inv_sq_dist) * v
                + self.embed_atompair_mask(v) * v
            )

            q = c

            if self.structure_prediction:
                # run only in structure model not in initial encoding
                atom_to_token = feats["atom_to_token"].float()

                token_single = torch.bmm(atom_to_token, self.s_to_c_trans(s_trunk.float()))
                c = c + token_single.to(c)

                a2t_queries = atom_to_token.view(B, K, W, atom_to_token.shape[-1])
                a2t_keys = to_keys(atom_to_token)
                # Look up the token-pair feature of each (query, key) atom pair.
                token_pair = torch.einsum(
                    "bijd,bwki,bwlj->bwkld",
                    self.z_to_p_trans(z.float()),
                    a2t_queries,
                    a2t_keys,
                )
                p = p + token_pair.to(p)

            feat_dim = c.shape[-1]
            p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, feat_dim))
            p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, feat_dim))
            p = p + self.p_mlp(p)
        return q, c, p, to_keys
412
+
413
+
414
class AtomAttentionEncoder(Module):
    """Run the windowed atom transformer and pool atoms into tokens.

    Atom activations are refined by an ``AtomTransformer`` and then averaged
    per token through the ``atom_to_token`` assignment matrix.
    """

    def __init__(
        self,
        atom_s,
        token_s,
        atoms_per_window_queries,
        atoms_per_window_keys,
        atom_encoder_depth=3,
        atom_encoder_heads=4,
        structure_prediction=True,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        self.structure_prediction = structure_prediction
        if structure_prediction:
            # Projects the coordinates ``r`` into the atom channel space.
            self.r_to_q_trans = LinearNoBias(3, atom_s)
            init.final_init_(self.r_to_q_trans.weight)

        self.atom_encoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=atoms_per_window_queries,
            attn_window_keys=atoms_per_window_keys,
            depth=atom_encoder_depth,
            heads=atom_encoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # Token output is twice as wide in structure-prediction mode.
        token_dim = 2 * token_s if structure_prediction else token_s
        self.atom_to_token_trans = nn.Sequential(
            LinearNoBias(atom_s, token_dim),
            nn.ReLU(),
        )

    def forward(
        self,
        feats,
        q,
        c,
        atom_enc_bias,
        to_keys,
        r=None,  # Float['bm m 3'],
        multiplicity=1,
    ):
        """Return ``(a, q, c, to_keys)`` where ``a`` is the per-token average
        of the transformed atom activations."""
        # Unpacking doubles as a check that ref_pos is three-dimensional.
        _batch, _atoms, _ = feats["ref_pos"].shape
        pad_mask = feats["atom_pad_mask"].bool()  # Bool['b m'],

        if self.structure_prediction:
            # only here the multiplicity kicks in because we use the different positions r
            q = q.repeat_interleave(multiplicity, 0)
            q = q + self.r_to_q_trans(r)

        c = c.repeat_interleave(multiplicity, 0)
        pad_mask = pad_mask.repeat_interleave(multiplicity, 0)

        q = self.atom_encoder(
            q=q,
            mask=pad_mask,
            c=c,
            bias=atom_enc_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        # Pooling runs with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            atom_out = self.atom_to_token_trans(q).float()
            assignment = feats["atom_to_token"].float()
            assignment = assignment.repeat_interleave(multiplicity, 0)
            # Normalise each token column by its atom count (eps avoids 0/0).
            atoms_per_token = assignment.sum(dim=1, keepdim=True)
            mean_weights = assignment / (atoms_per_token + 1e-6)
            a = torch.bmm(mean_weights.transpose(1, 2), atom_out)

        a = a.to(q)

        return a, q, c, to_keys
493
+
494
+
495
class AtomAttentionDecoder(Module):
    """Broadcast token activations to atoms and decode position updates (Algorithm 6)."""

    def __init__(
        self,
        atom_s,
        token_s,
        attn_window_queries,
        attn_window_keys,
        atom_decoder_depth=3,
        atom_decoder_heads=4,
        activation_checkpointing=False,
        transformer_post_layer_norm=False,
    ):
        super().__init__()

        # Projects the 2 * token_s wide token activations to atom channels.
        self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
        init.final_init_(self.a_to_q_trans.weight)

        self.atom_decoder = AtomTransformer(
            dim=atom_s,
            dim_single_cond=atom_s,
            attn_window_queries=attn_window_queries,
            attn_window_keys=attn_window_keys,
            depth=atom_decoder_depth,
            heads=atom_decoder_heads,
            activation_checkpointing=activation_checkpointing,
            post_layer_norm=transformer_post_layer_norm,
        )

        # The output head only adds its own LayerNorm when the transformer
        # was not built with post_layer_norm.
        if not transformer_post_layer_norm:
            self.atom_feat_to_atom_pos_update = nn.Sequential(
                nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
            )
            init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
        else:
            self.atom_feat_to_atom_pos_update = LinearNoBias(atom_s, 3)
            init.final_init_(self.atom_feat_to_atom_pos_update.weight)

    def forward(
        self,
        a,  # Float['bm n 2ts'],
        q,  # Float['bm m as'],
        c,  # Float['bm m as'],
        atom_dec_bias,  # Float['bm m m az'],
        feats,
        to_keys,
        multiplicity=1,
    ):
        """Return the per-atom 3-vector update produced from ``a`` and ``q``."""
        # Token-to-atom broadcast runs with CUDA autocast disabled.
        with torch.autocast("cuda", enabled=False):
            assignment = feats["atom_to_token"].float()
            assignment = assignment.repeat_interleave(multiplicity, 0)
            token_to_atom = torch.bmm(assignment, self.a_to_q_trans(a.float()))

        q = q + token_to_atom.to(q)
        pad_mask = feats["atom_pad_mask"].repeat_interleave(multiplicity, 0)

        q = self.atom_decoder(
            q=q,
            mask=pad_mask,
            c=c,
            bias=atom_dec_bias,
            multiplicity=multiplicity,
            to_keys=to_keys,
        )

        return self.atom_feat_to_atom_pos_update(q)