dn6 (HF Staff) committed
Commit 7e6b4cd · verified · 1 Parent(s): 288a16c

Upload transformer/model.py with huggingface_hub

Files changed (1): transformer/model.py (+1183, -0)
transformer/model.py ADDED
@@ -0,0 +1,1183 @@
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RFDiffusion3 Transformer model.

This module provides a diffusers-compatible implementation of the RFD3
architecture for protein structure prediction and generation. The module
structure matches the foundry checkpoint format for direct weight loading.
"""

import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


@dataclass
class RFDiffusionTransformerOutput:
    """Output class for RFDiffusion transformer."""

    xyz: torch.Tensor
    single: torch.Tensor
    pair: torch.Tensor
    sequence_logits: Optional[torch.Tensor] = None
    sequence_indices: Optional[torch.Tensor] = None


linearNoBias = partial(nn.Linear, bias=False)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class FourierEmbedding(nn.Module):
    """Fourier feature embedding for timesteps with learned weights."""

    def __init__(self, c: int):
        super().__init__()
        self.c = c
        self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
        self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return torch.cos(2 * math.pi * (t[..., None] * self.w + self.b))


class LinearBiasInit(nn.Linear):
    """Linear layer with custom bias initialization."""

    def __init__(self, *args, biasinit: float = -2.0, **kwargs):
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if self.bias is not None:
            self.bias.data.fill_(self.biasinit)


class RMSNormNoWeight(nn.Module):
    """RMSNorm without learnable weight (elementwise_affine=False)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms


class AdaLN(nn.Module):
    """Adaptive Layer Normalization."""

    def __init__(self, c_a: int, c_s: int):
        super().__init__()
        self.ln_a = RMSNormNoWeight(c_a)
        self.ln_s = RMSNorm(c_s)
        self.to_gain = nn.Sequential(nn.Linear(c_s, c_a), nn.Sigmoid())
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a = self.ln_a(a)
        s = self.ln_s(s)
        return self.to_gain(s) * a + self.to_bias(s)
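
# Note: AdaLN above implements AF3-style adaptive conditioning,
#   AdaLN(a, s) = sigmoid(Linear(s)) * RMSNorm_noaffine(a) + LinearNoBias(s),
# i.e. the gain and bias of the normalization are predicted from the
# conditioning signal s rather than learned as free parameters.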

class Transition(nn.Module):
    """SwiGLU-style transition block matching foundry naming."""

    def __init__(self, c: int, n: int = 4):
        super().__init__()
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = linearNoBias(c, n * c)
        self.linear_2 = linearNoBias(c, n * c)
        self.linear_3 = linearNoBias(n * c, c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm_1(x)
        return self.linear_3(F.silu(self.linear_1(x)) * self.linear_2(x))


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition with adaptive layer norm conditioning."""

    def __init__(self, c_token: int, c_s: int, n: int = 2):
        super().__init__()
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = linearNoBias(c_token, c_token * n)
        self.linear_2 = linearNoBias(c_token, c_token * n)
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(c_token * n, c_token)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a = self.ada_ln(a, s)
        b = F.silu(self.linear_1(a)) * self.linear_2(a)
        return self.linear_output_project(s) * self.linear_3(b)


class MultiDimLinear(nn.Linear):
    """Linear layer that reshapes output to a multi-dimensional shape."""

    def __init__(self, in_features: int, out_shape: Tuple[int, ...], norm: bool = False, **kwargs):
        self.out_shape = out_shape
        out_features = 1
        for d in out_shape:
            out_features *= d
        super().__init__(in_features, out_features, **kwargs)
        if norm:
            self.ln = RMSNorm(out_features)
            self.use_ln = True
        else:
            self.use_ln = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if self.use_ln:
            out = self.ln(out)
        return out.reshape(x.shape[:-1] + self.out_shape)


class AttentionPairBias(nn.Module):
    """Attention with pairwise bias for Pairformer."""

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 8,
        kq_norm: bool = False,
    ):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head

        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = RMSNorm(c_pair)
        self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        beta: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        a = self.ln_1(a)

        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        b = self.to_b(self.ln_0(z))
        if beta is not None:
            b = b + beta[..., None]
        g = self.to_g(a)

        q = q / math.sqrt(self.c)
        attn = torch.einsum("...ihd,...jhd->...ijh", q, k) + b
        attn = F.softmax(attn, dim=-2)
        out = torch.einsum("...ijh,...jhc->...ihc", attn, v)
        out = g * out
        out = out.flatten(start_dim=-2)
        out = self.to_a(out)

        return out


class LocalAttentionPairBias(nn.Module):
    """Local attention with pairwise bias for diffusion transformer blocks."""

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 16,
        kq_norm: bool = True,
    ):
        super().__init__()
        self.n_head = n_head
        self.c = c_a
        self.c_head = c_a // n_head
        self.c_s = c_s
        self.use_checkpointing = False

        self.to_q = linearNoBias(c_a, c_a)
        self.to_k = linearNoBias(c_a, c_a)
        self.to_v = linearNoBias(c_a, c_a)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(linearNoBias(c_a, c_a), nn.Sigmoid())
        self.to_o = linearNoBias(c_a, c_a)

        self.kq_norm = kq_norm
        if kq_norm:
            self.ln_q = RMSNorm(c_a)
            self.ln_k = RMSNorm(c_a)

        if c_s is not None and c_s > 0:
            self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
            self.linear_output_project = nn.Sequential(
                LinearBiasInit(c_s, c_a, biasinit=-2.0),
                nn.Sigmoid(),
            )
        else:
            self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if self.c_s is not None and self.c_s > 0:
            a = self.ada_ln_1(a, s)
        else:
            a = self.ln_1(a)

        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        g = self.to_g(a)

        if self.kq_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)

        batch_dims = a.shape[:-2]
        L = a.shape[-2]

        q = q.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        k = k.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        v = v.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        g = g.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)

        b = self.to_b(z).permute(*range(len(batch_dims)), -1, -3, -2)

        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.c_head)
        attn = attn + b
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        out = out * g
        out = out.transpose(-2, -3).contiguous()
        out = out.view(*batch_dims, L, self.c)
        out = self.to_o(out)

        if self.c_s is not None and self.c_s > 0:
            out = self.linear_output_project(s) * out

        return out


class PairformerBlock(nn.Module):
    """Pairformer block with attention and transitions."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        attention_pair_bias: dict,
        n_transition: int = 4,
        p_drop: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        self.z_transition = Transition(c=c_z, n=n_transition)

        if c_s > 0:
            self.s_transition = Transition(c=c_s, n=n_transition)
            self.attention_pair_bias = AttentionPairBias(
                c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
            )

    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        z = z + self.z_transition(z)

        if s is not None:
            beta = torch.tensor([0.0], device=z.device)
            s = s + self.attention_pair_bias(s, None, z, beta=beta)
            s = s + self.s_transition(s)

        return s, z


class StructureLocalAtomTransformerBlock(nn.Module):
    """Single block for atom/token transformer."""

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        n_head: int = 4,
        dropout: float = 0.0,
        kq_norm: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, n_head=n_head, kq_norm=kq_norm
        )
        if c_s is not None and c_s > 0:
            self.transition_block = ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
        else:
            self.transition_block = Transition(c=c_atom, n=4)

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q + self.dropout(self.attention_pair_bias(q, c, p, **kwargs))
        if self.c_s is not None and self.c_s > 0:
            q = q + self.transition_block(q, c)
        else:
            q = q + self.transition_block(q)
        return q


class GatedCrossAttention(nn.Module):
    """Gated cross attention for upcast/downcast."""

    def __init__(
        self,
        c_query: int,
        c_kv: int,
        c_model: int = 128,
        n_head: int = 4,
        kq_norm: bool = True,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.n_head = n_head
        self.scale = 1 / math.sqrt(c_model // n_head)

        self.ln_q = RMSNorm(c_query)
        self.ln_kv = RMSNorm(c_kv)

        self.to_q = linearNoBias(c_query, c_model)
        self.to_k = linearNoBias(c_kv, c_model)
        self.to_v = linearNoBias(c_kv, c_model)
        self.to_g = nn.Sequential(linearNoBias(c_query, c_model), nn.Sigmoid())
        self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))

        self.kq_norm = kq_norm
        if kq_norm:
            self.k_norm = RMSNorm(c_model)
            self.q_norm = RMSNorm(c_model)

    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        q_in = self.ln_q(q)
        kv = self.ln_kv(kv)

        q_proj = self.to_q(q_in)
        k = self.to_k(kv)
        v = self.to_v(kv)
        g = self.to_g(q_in)

        if self.kq_norm:
            k = self.k_norm(k)
            q_proj = self.q_norm(q_proj)

        B = q.shape[0]
        n_tok = q.shape[1] if q.ndim == 4 else 1
        L_q = q.shape[-2]
        L_kv = kv.shape[-2]
        c_head = q_proj.shape[-1] // self.n_head

        if q.ndim == 4:
            q_proj = q_proj.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            k = k.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            v = v.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            g = g.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
        else:
            q_proj = q_proj.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
            k = k.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            v = v.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            g = g.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)

        attn = torch.matmul(q_proj, k.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            if q.ndim == 4:
                while attn_mask.ndim < attn.ndim:
                    attn_mask = attn_mask.unsqueeze(0)
                if attn_mask.shape[1] != self.n_head and attn_mask.shape[1] != 1:
                    attn_mask = attn_mask.unsqueeze(1)
            else:
                attn_mask = attn_mask.unsqueeze(-3)
            attn = attn.masked_fill(~attn_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        out = out * g

        if q.ndim == 4:
            out = out.permute(0, 2, 3, 1, 4).contiguous()
            out = out.view(B, n_tok, L_q, -1)
        else:
            out = out.permute(0, 2, 1, 3).contiguous()
            out = out.view(B, L_q, -1)

        out = self.to_out(out)
        return out
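
# Shape contract for GatedCrossAttention (as implemented above): queries are
# [B, L_q, c_query] or grouped [B, n_tok, L_q, c_query]; keys/values follow
# the same pattern with L_kv. The optional boolean attn_mask (True = attend)
# is broadcast against the [..., n_head, L_q, L_kv] attention logits.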

class Upcast(nn.Module):
    """Upcast from token level to atom level."""

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        method: str = "cross_attention",
        cross_attention_block: Optional[dict] = None,
        n_split: int = 6,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if method == "broadcast":
            self.project = nn.Sequential(RMSNorm(c_token), linearNoBias(c_token, c_atom))
        elif method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_atom,
                c_kv=c_token // n_split,
                c_model=c_atom,
                **(cross_attention_block or {}),
            )

    def forward(self, q: torch.Tensor, a: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        if self.method == "broadcast":
            q = q + self.project(a)[..., tok_idx, :]
        elif self.method == "cross_attention":
            B, L, C = q.shape
            I = int(tok_idx.max().item()) + 1

            a_split = a.view(B, I, self.n_split, -1)

            q_grouped = self._group_atoms(q, tok_idx, I)
            valid_mask = self._build_valid_mask(tok_idx, I, q.device)

            attn_mask = torch.ones(I, q_grouped.shape[2], self.n_split, device=q.device, dtype=torch.bool)
            attn_mask[~valid_mask] = False

            q_update = self.gca(q_grouped, a_split, attn_mask=attn_mask)
            q = q + self._ungroup_atoms(q_update, valid_mask, L)

        return q

    def _group_atoms(self, q: torch.Tensor, tok_idx: torch.Tensor, I: int) -> torch.Tensor:
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)

        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                counts[t] += 1

        return grouped

    def _build_valid_mask(self, tok_idx: torch.Tensor, I: int, device: torch.device) -> torch.Tensor:
        max_atoms_per_token = 14
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=device)
        counts = torch.zeros(I, dtype=torch.long, device=device)

        for i in range(len(tok_idx)):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                valid_mask[t, counts[t]] = True
                counts[t] += 1

        return valid_mask

    def _ungroup_atoms(self, grouped: torch.Tensor, valid_mask: torch.Tensor, L: int) -> torch.Tensor:
        B, I, n_atoms, C = grouped.shape
        out = torch.zeros(B, L, C, device=grouped.device, dtype=grouped.dtype)

        idx = 0
        for t in range(I):
            for a in range(n_atoms):
                if valid_mask[t, a] and idx < L:
                    out[:, idx] = grouped[:, t, a]
                    idx += 1

        return out
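
# Note: the grouping helpers above assume atoms belonging to one token are
# contiguous (tok_idx non-decreasing) and cap occupancy at 14 atoms per token,
# which appears to follow the atom14 residue representation used by
# RoseTTAFold-family models; atoms beyond the cap are left out of the
# cross-attention update.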

class Downcast(nn.Module):
    """Downcast from atom level to token level."""

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        c_s: Optional[int] = None,
        method: str = "mean",
        cross_attention_block: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom

        if c_s is not None:
            self.process_s = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_token))
        else:
            self.process_s = None

        if method == "mean":
            self.gca = linearNoBias(c_atom, c_token)
        elif method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                c_model=c_token,
                **(cross_attention_block or {}),
            )

    def forward(
        self,
        q: torch.Tensor,
        a: Optional[torch.Tensor] = None,
        s: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if q.ndim == 2:
            q = q.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False

        B, L, _ = q.shape
        I = int(tok_idx.max().item()) + 1

        if self.method == "mean":
            projected = self.gca(q)
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            counts = torch.zeros(B, I, 1, device=q.device, dtype=q.dtype)
            for i in range(L):
                t = tok_idx[i]
                a_update[:, t] += projected[:, i]
                counts[:, t] += 1
            a_update = a_update / (counts + 1e-8)
        elif self.method == "cross_attention":
            if a is None:
                a = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            elif a.ndim == 2:
                a = a.unsqueeze(0)

            q_grouped, valid_mask = self._group_atoms(q, tok_idx, I)
            attn_mask = valid_mask.unsqueeze(-2)
            a_update = self.gca(a.unsqueeze(-2), q_grouped, attn_mask=attn_mask).squeeze(-2)
        else:
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)

        if a is not None:
            if a.ndim == 2:
                a = a.unsqueeze(0)
            a = a + a_update
        else:
            a = a_update

        if self.process_s is not None and s is not None:
            if s.ndim == 2:
                s = s.unsqueeze(0)
            a = a + self.process_s(s)

        if squeeze:
            a = a.squeeze(0)

        return a

    def _group_atoms(
        self, q: torch.Tensor, tok_idx: torch.Tensor, I: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=q.device)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)

        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                valid_mask[t, counts[t]] = True
                counts[t] += 1

        return grouped, valid_mask


class LinearEmbedWithPool(nn.Module):
    """Linear embedding with pooling to token level."""

    def __init__(self, c_token: int):
        super().__init__()
        self.c_token = c_token
        self.linear = linearNoBias(3, c_token)

    def forward(self, r: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        B = r.shape[0]
        I = int(tok_idx.max().item()) + 1
        q = self.linear(r)

        a = torch.zeros(B, I, self.c_token, device=r.device, dtype=q.dtype)
        counts = torch.zeros(B, I, 1, device=r.device, dtype=q.dtype)

        for i in range(r.shape[1]):
            t = tok_idx[i]
            a[:, t] += q[:, i]
            counts[:, t] += 1

        return a / (counts + 1e-8)


class LinearSequenceHead(nn.Module):
    """Sequence prediction head."""

    def __init__(self, c_token: int):
        super().__init__()
        n_tok_all = 32
        mask = torch.ones(n_tok_all, dtype=torch.bool)
        self.register_buffer("valid_out_mask", mask)
        self.linear = nn.Linear(c_token, n_tok_all)

    def forward(self, a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.linear(a)
        probs = F.softmax(logits, dim=-1)
        probs = probs * self.valid_out_mask[None, None, :].to(probs.device)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
        indices = probs.argmax(dim=-1)
        return logits, indices


class LocalAtomTransformer(nn.Module):
    """Atom-level transformer encoder."""

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        atom_transformer_block: dict,
        n_blocks: int,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_s,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ])

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for block in self.blocks:
            q = block(q, c, p, **kwargs)
        return q


class LocalTokenTransformer(nn.Module):
    """Token-level transformer for diffusion."""

    def __init__(
        self,
        c_token: int,
        c_tokenpair: int,
        c_s: int,
        diffusion_transformer_block: dict,
        n_block: int,
        **kwargs,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_token,
                c_s=c_s,
                c_atompair=c_tokenpair,
                **diffusion_transformer_block,
            )
            for _ in range(n_block)
        ])

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for block in self.blocks:
            a = block(a, s, z, **kwargs)
        return a


class CompactStreamingDecoder(nn.Module):
    """Decoder with upcast, atom transformer, and downcast."""

    def __init__(
        self,
        c_atom: int,
        c_atompair: int,
        c_token: int,
        c_s: int,
        c_tokenpair: int,
        atom_transformer_block: dict,
        upcast: dict,
        downcast: dict,
        n_blocks: int,
        **kwargs,
    ):
        super().__init__()
        self.n_blocks = n_blocks

        self.upcast = nn.ModuleList([
            Upcast(c_atom=c_atom, c_token=c_token, **upcast)
            for _ in range(n_blocks)
        ])
        self.atom_transformer = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_atom,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ])
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        q: torch.Tensor,
        c: torch.Tensor,
        p: torch.Tensor,
        tok_idx: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        for i in range(self.n_blocks):
            q = self.upcast[i](q, a, tok_idx=tok_idx)
            q = self.atom_transformer[i](q, c, p, **kwargs)

        a = self.downcast(q.detach(), a.detach(), s.detach(), tok_idx=tok_idx)

        return a, q, {}


class DiffusionTokenEncoder(nn.Module):
    """Token encoder with pairformer stack for diffusion."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_token: int,
        c_atompair: int,
        n_pairformer_blocks: int,
        pairformer_block: dict,
        **kwargs,
    ):
        super().__init__()

        self.transition_1 = nn.ModuleList([
            Transition(c=c_s, n=2),
            Transition(c=c_s, n=2),
        ])

        self.process_z = nn.Sequential(
            RMSNorm(c_z),
            linearNoBias(c_z, c_z),
        )

        self.transition_2 = nn.ModuleList([
            Transition(c=c_z, n=2),
            Transition(c=c_z, n=2),
        ])

        self.pairformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])

    def forward(
        self,
        s_init: torch.Tensor,
        z_init: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B = z_init.shape[0] if z_init.ndim == 4 else 1

        s = s_init
        for b in range(2):
            s = s + self.transition_1[b](s)

        z = z_init
        if z.ndim == 3:
            z = z.unsqueeze(0).expand(B, -1, -1, -1)

        z = self.process_z(z)

        for b in range(2):
            z = z + self.transition_2[b](z)

        for block in self.pairformer_stack:
            s, z = block(s, z)

        return s, z

class RFD3DiffusionModule(nn.Module):
    """
    RFD3 Diffusion Module matching foundry checkpoint structure.

    This module structure matches `model.diffusion_module.*` keys in the checkpoint.
    """

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 2,
        n_diffusion_blocks: int = 18,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()

        self.sigma_data = sigma_data
        self.n_recycle = n_recycle

        self.process_r = linearNoBias(3, c_atom)
        self.to_r_update = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, 3))
        self.sequence_head = LinearSequenceHead(c_token)

        self.fourier_embedding = nn.ModuleList([
            FourierEmbedding(c_t_embed),
            FourierEmbedding(c_t_embed),
        ])
        self.process_n = nn.ModuleList([
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
        ])

        self.downcast_c = Downcast(c_atom=c_atom, c_token=c_s, c_s=None, method="cross_attention")
        self.downcast_q = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, method="cross_attention")
        self.process_a = LinearEmbedWithPool(c_token)
        self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))

        atom_transformer_block = {
            "n_head": 4,
            "dropout": p_drop,
            "kq_norm": True,
        }

        self.encoder = LocalAtomTransformer(
            c_atom=c_atom,
            c_s=c_atom,
            c_atompair=c_atompair,
            atom_transformer_block=atom_transformer_block,
            n_blocks=n_atom_encoder_blocks,
        )

        pairformer_block = {
            "attention_pair_bias": {"n_head": 4, "kq_norm": False},
            "n_transition": 4,
        }

        self.diffusion_token_encoder = DiffusionTokenEncoder(
            c_s=c_s,
            c_z=c_z,
            c_token=c_token,
            c_atompair=c_atompair,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block=pairformer_block,
        )

        diffusion_transformer_block = {
            "n_head": n_head,
            "dropout": p_drop,
            "kq_norm": True,
        }

        self.diffusion_transformer = LocalTokenTransformer(
            c_token=c_token,
            c_tokenpair=c_z,
            c_s=c_s,
            diffusion_transformer_block=diffusion_transformer_block,
            n_block=n_diffusion_blocks,
        )

        decoder_upcast = {"method": "cross_attention"}
        decoder_downcast = {"method": "cross_attention"}

        self.decoder = CompactStreamingDecoder(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_tokenpair=c_z,
            atom_transformer_block=atom_transformer_block,
            upcast=decoder_upcast,
            downcast=decoder_downcast,
            n_blocks=n_atom_decoder_blocks,
        )

    def scale_positions_in(self, x_noisy: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        if t.ndim == 1:
            t = t[..., None, None]
        elif t.ndim == 2:
            t = t[..., None]
        return x_noisy / torch.sqrt(t**2 + self.sigma_data**2)

    def scale_positions_out(
        self, r_update: torch.Tensor, x_noisy: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        if t.ndim == 1:
            t = t[..., None, None]
        elif t.ndim == 2:
            t = t[..., None]
        sigma2 = self.sigma_data**2
        return (sigma2 / (sigma2 + t**2)) * x_noisy + (
            self.sigma_data * t / torch.sqrt(sigma2 + t**2)
        ) * r_update
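
    # The two scaling helpers above implement EDM-style preconditioning
    # (Karras et al., 2022): the network sees c_in(t) * x_noisy with
    # c_in = 1 / sqrt(t^2 + sigma_data^2), and the denoised output is
    # recombined as c_skip * x_noisy + c_out * r_update, where
    # c_skip = sigma_data^2 / (sigma_data^2 + t^2) and
    # c_out = sigma_data * t / sqrt(sigma_data^2 + t^2).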

    def process_time(self, t: torch.Tensor, idx: int) -> torch.Tensor:
        t_clamped = torch.clamp(t, min=1e-20)
        t_log = 0.25 * torch.log(t_clamped / self.sigma_data)
        emb = self.process_n[idx](self.fourier_embedding[idx](t_log))
        emb = emb * (t > 0).float()[..., None]
        return emb

    def compute_pair_features(self, xyz: torch.Tensor, c_atompair: int) -> torch.Tensor:
        dist = torch.cdist(xyz, xyz)
        inv_dist = 1 / (1 + dist**2)
        return inv_dist.unsqueeze(-1).expand(-1, -1, -1, c_atompair)


class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
    """
    RFDiffusion3 transformer for protein structure prediction.

    This wraps the diffusion module to provide the full model interface.
    The state dict keys match the foundry checkpoint format.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_block: int = 2,
        n_diffusion_block: int = 18,
        n_atom_encoder_block: int = 3,
        n_atom_decoder_block: int = 3,
        n_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()

        self.diffusion_module = RFD3DiffusionModule(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_t_embed=c_t_embed,
            sigma_data=sigma_data,
            n_pairformer_blocks=n_pairformer_block,
            n_diffusion_blocks=n_diffusion_block,
            n_atom_encoder_blocks=n_atom_encoder_block,
            n_atom_decoder_blocks=n_atom_decoder_block,
            n_head=n_head,
            n_recycle=n_recycle,
            p_drop=p_drop,
        )

        self.s_init = nn.Parameter(torch.zeros(1, 1, c_s))
        self.z_init = nn.Parameter(torch.zeros(1, 1, 1, c_z))

    @property
    def sigma_data(self) -> float:
        return self.diffusion_module.sigma_data

    def forward(
        self,
        xyz_noisy: torch.Tensor,
        t: torch.Tensor,
        atom_to_token_map: Optional[torch.Tensor] = None,
        motif_mask: Optional[torch.Tensor] = None,
        s_init: Optional[torch.Tensor] = None,
        z_init: Optional[torch.Tensor] = None,
        n_recycle: Optional[int] = None,
        **kwargs,
    ) -> RFDiffusionTransformerOutput:
        """
        Forward pass of the diffusion module.

        Args:
            xyz_noisy: Noisy atom coordinates [B, L, 3]
            t: Noise level / timestep [B]
            atom_to_token_map: Mapping from atoms to tokens [L]
            motif_mask: Mask for fixed motif atoms [L]
            s_init: Initial single representation [I, c_s]
            z_init: Initial pair representation [I, I, c_z]
            n_recycle: Number of recycling iterations

        Returns:
            RFDiffusionTransformerOutput with denoised coordinates
        """
        dm = self.diffusion_module
        B, L, _ = xyz_noisy.shape

        if atom_to_token_map is None:
            atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
        I = int(atom_to_token_map.max().item()) + 1

        if motif_mask is None:
            motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)

        t_L = t[:, None].expand(B, L) * (~motif_mask).float()
        t_I = t[:, None].expand(B, I)

        r_scaled = dm.scale_positions_in(xyz_noisy, t)
        r_noisy = dm.scale_positions_in(xyz_noisy, t_L)

        if s_init is None:
            s_init = self.s_init.squeeze(0).expand(I, -1)
        if z_init is None:
            z_init = self.z_init.squeeze(0).expand(I, I, -1)

        p = dm.compute_pair_features(r_scaled, self.config.c_atompair)

        a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
        s_I = dm.downcast_c(
            torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
            s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init,
            tok_idx=atom_to_token_map,
        )

        q = dm.process_r(r_noisy)
        c = dm.process_time(t_L, idx=0)
        q = q + c
        s_I = s_I + dm.process_time(t_I, idx=1)
        c = c + dm.process_c(c)

        q = dm.encoder(q, c, p)
        a_I = dm.downcast_q(q, a_I, s_I, tok_idx=atom_to_token_map)

        if n_recycle is None:
            n_recycle = dm.n_recycle if not self.training else 1

        for _ in range(n_recycle):
            s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_init)
            a_I = dm.diffusion_transformer(a_I, s_I, z_II)

        a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)

        r_update = dm.to_r_update(q)
        xyz_out = dm.scale_positions_out(r_update, xyz_noisy, t_L)

        sequence_logits, sequence_indices = dm.sequence_head(a_I)

        return RFDiffusionTransformerOutput(
            xyz=xyz_out,
            single=s_I,
            pair=z_II,
            sequence_logits=sequence_logits,
            sequence_indices=sequence_indices,
        )
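

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the checkpoint): builds
# a deliberately tiny model and runs one denoising forward pass on random
# coordinates. The channel sizes below are assumptions chosen for speed; the
# released weights use the defaults defined above.
if __name__ == "__main__":
    model = RFDiffusionTransformerModel(
        c_s=32, c_z=16, c_atom=16, c_atompair=16, c_token=48, c_t_embed=32,
        n_pairformer_block=1, n_diffusion_block=1,
        n_atom_encoder_block=1, n_atom_decoder_block=1, n_head=4,
    )
    model.eval()

    B, L = 1, 28  # two tokens of 14 atoms each
    xyz_noisy = torch.randn(B, L, 3) * model.sigma_data  # noise at sigma_data scale
    t = torch.full((B,), 16.0)  # one noise level per batch element
    atom_to_token_map = torch.arange(L) // 14  # contiguous atom -> token mapping

    with torch.no_grad():
        out = model(xyz_noisy, t, atom_to_token_map=atom_to_token_map)

    # Denoised coords [B, L, 3], single [B, I, c_s], pair [B, I, I, c_z].
    print(out.xyz.shape, out.single.shape, out.pair.shape)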