| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| RFDiffusion3 Transformer model. |
| |
| This module provides a diffusers-compatible implementation of the RFD3 |
| architecture for protein structure prediction and generation. The module |
| structure matches the foundry checkpoint format for direct weight loading. |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
@dataclass
class RFDiffusionTransformerOutput:
    """Output class for RFDiffusion transformer.

    Sequence fields are optional and default to None; the tensor fields
    are always populated by the model's forward pass.
    """

    # Coordinates output (presumably predicted atom positions — confirm with the producing module).
    xyz: torch.Tensor
    # Final single (per-token) representation.
    single: torch.Tensor
    # Final pair (token-pair) representation.
    pair: torch.Tensor
    # Residue-type logits, when sequence prediction was run.
    sequence_logits: Optional[torch.Tensor] = None
    # Residue-type indices, when sequence prediction was run.
    sequence_indices: Optional[torch.Tensor] = None
|
|
|
|
# Convenience factory for nn.Linear layers without a bias term; used by
# most projections in this file.
linearNoBias = partial(nn.Linear, bias=False)
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides the last dimension by its root-mean-square (plus ``eps``)
    and applies a learnable per-channel gain, initialized to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = torch.mean(x**2, dim=-1, keepdim=True)
        normed = x / torch.sqrt(mean_sq + self.eps)
        return normed * self.weight
|
|
|
|
class FourierEmbedding(nn.Module):
    """Fourier feature embedding for timesteps with learned weights.

    Frequencies ``w`` and phases ``b`` are registered as buffers (not
    parameters): they are randomly initialized here and participate in
    the state dict, but are never updated by the optimizer.
    """

    def __init__(self, c: int):
        super().__init__()
        self.c = c
        self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
        self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Standard-normal frequencies and phases.
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # cos(2*pi*(t*w + b)), broadcast over the trailing channel dim.
        phase = t[..., None] * self.w + self.b
        return torch.cos(2 * math.pi * phase)
|
|
|
|
class LinearBiasInit(nn.Linear):
    """Linear layer whose bias is initialized to a fixed constant.

    With the default ``biasinit=-2.0`` a downstream sigmoid gate starts
    mostly closed.
    """

    def __init__(self, *args, biasinit: float = -2.0, **kwargs):
        # Must be stored before super().__init__, which invokes
        # reset_parameters() during construction.
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if self.bias is None:
            return
        nn.init.constant_(self.bias, self.biasinit)
|
|
|
|
class RMSNormNoWeight(nn.Module):
    """RMSNorm without learnable weight (elementwise_affine=False)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure normalization: divide by RMS over the last dim, no gain.
        mean_sq = torch.mean(x * x, dim=-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.eps)
|
|
|
|
class AdaLN(nn.Module):
    """Adaptive Layer Normalization.

    Normalizes the activation ``a`` without an affine term, then
    modulates it with a sigmoid gain and an additive bias, both
    predicted from the (RMS-normalized) conditioning ``s``.
    """

    def __init__(self, c_a: int, c_s: int):
        super().__init__()
        self.ln_a = RMSNormNoWeight(c_a)
        self.ln_s = RMSNorm(c_s)
        self.to_gain = nn.Sequential(nn.Linear(c_s, c_a), nn.Sigmoid())
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a_normed = self.ln_a(a)
        s_normed = self.ln_s(s)
        gain = self.to_gain(s_normed)
        bias = self.to_bias(s_normed)
        return gain * a_normed + bias
|
|
|
|
class Transition(nn.Module):
    """SwiGLU-style transition block matching foundry naming.

    Expands the channel width by a factor of ``n``, gates with SiLU,
    and projects back to the input width.
    """

    def __init__(self, c: int, n: int = 4):
        super().__init__()
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = linearNoBias(c, n * c)
        self.linear_2 = linearNoBias(c, n * c)
        self.linear_3 = linearNoBias(n * c, c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.layer_norm_1(x)
        gated = F.silu(self.linear_1(h)) * self.linear_2(h)
        return self.linear_3(gated)
|
|
|
|
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition with adaptive layer norm conditioning.

    The input is AdaLN-conditioned on ``s``, run through a SwiGLU
    expansion, and the result is gated by a sigmoid projection of ``s``
    whose bias starts at -2.0 (gate initially mostly closed).
    """

    def __init__(self, c_token: int, c_s: int, n: int = 2):
        super().__init__()
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = linearNoBias(c_token, c_token * n)
        self.linear_2 = linearNoBias(c_token, c_token * n)
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(c_token * n, c_token)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        conditioned = self.ada_ln(a, s)
        swiglu = F.silu(self.linear_1(conditioned)) * self.linear_2(conditioned)
        gate = self.linear_output_project(s)
        return gate * self.linear_3(swiglu)
|
|
|
|
class MultiDimLinear(nn.Linear):
    """Linear layer whose output is reshaped to a multi-dimensional shape.

    Args:
        in_features: Size of the input's last dimension.
        out_shape: Trailing shape of the output; the underlying linear
            layer produces ``math.prod(out_shape)`` features which are
            reshaped to ``(*batch, *out_shape)``.
        norm: If True, apply an RMSNorm over the flat output features
            before reshaping.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias=False``).
    """

    def __init__(self, in_features: int, out_shape: Tuple[int, ...], norm: bool = False, **kwargs):
        self.out_shape = out_shape
        # math.prod replaces the original manual accumulation loop
        # (and, like it, yields 1 for an empty shape).
        out_features = math.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)
        if norm:
            self.ln = RMSNorm(out_features)
            self.use_ln = True
        else:
            self.use_ln = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if self.use_ln:
            out = self.ln(out)
        # Unflatten the feature dim into the requested trailing shape.
        return out.reshape(x.shape[:-1] + self.out_shape)
|
|
|
|
class AttentionPairBias(nn.Module):
    """Attention with pairwise bias for Pairformer.

    Self-attention over the single representation where the attention
    logits receive an additive per-head bias projected from the pair
    representation ``z``, plus an optional extra bias ``beta``.

    NOTE(review): the ``c_s`` and ``kq_norm`` constructor arguments are
    accepted but never used in this class; keys/values are instead
    RMS-normalized inside ``MultiDimLinear(norm=True)``.
    """

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 8,
        kq_norm: bool = False,
    ):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        # Per-head channel width.
        self.c = c_a // n_head

        # Query projection keeps a bias; key/value projections are
        # bias-free and RMS-normalized over the flat head*channel dim.
        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        # Pair features -> one additive attention bias per head.
        self.to_b = linearNoBias(c_pair, n_head)
        # Sigmoid gate applied to the per-head attention output.
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = RMSNorm(c_pair)
        self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        beta: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply pair-biased self-attention to ``a``.

        Args:
            a: Single representation, shape (..., L, c_a).
            s: Unused; kept for interface compatibility with callers.
            z: Pair representation, shape (..., L, L, c_pair).
            beta: Optional extra bias, broadcast over the head dim.

        Returns:
            Updated single representation, shape (..., L, c_a).
        """
        a = self.ln_1(a)

        q = self.to_q(a)  # (..., L, n_head, c)
        k = self.to_k(a)
        v = self.to_v(a)
        b = self.to_b(self.ln_0(z))  # (..., L, L, n_head)
        if beta is not None:
            # Broadcast beta across the trailing head dimension.
            b = b + beta[..., None]
        g = self.to_g(a)

        q = q / math.sqrt(self.c)
        # Logits laid out as (..., i, j, h): query index i, key index j.
        attn = torch.einsum("...ihd,...jhd->...ijh", q, k) + b
        # Softmax over the key index j, which is dim=-2 in this layout.
        attn = F.softmax(attn, dim=-2)
        out = torch.einsum("...ijh,...jhc->...ihc", attn, v)
        out = g * out
        # Merge (n_head, c) back into c_a before the output projection.
        out = out.flatten(start_dim=-2)
        out = self.to_a(out)

        return out
|
|
|
|
class LocalAttentionPairBias(nn.Module):
    """Local attention with pairwise bias for diffusion transformer blocks.

    Multi-head self-attention with an additive per-head bias projected
    from the pair features ``z``. Input normalization is AdaLN on ``s``
    when a conditioning width ``c_s`` is configured, plain RMSNorm
    otherwise; in the conditioned case the output is additionally gated
    by a sigmoid projection of ``s``.

    NOTE(review): despite the name, attention here is dense over all L
    positions — no locality windowing is applied in this block.
    """

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 16,
        kq_norm: bool = True,
    ):
        super().__init__()
        self.n_head = n_head
        self.c = c_a
        self.c_head = c_a // n_head
        self.c_s = c_s
        # Set but never consulted within this block.
        self.use_checkpointing = False

        self.to_q = linearNoBias(c_a, c_a)
        self.to_k = linearNoBias(c_a, c_a)
        self.to_v = linearNoBias(c_a, c_a)
        # Pair features -> one additive attention bias per head.
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(linearNoBias(c_a, c_a), nn.Sigmoid())
        self.to_o = linearNoBias(c_a, c_a)

        # Optional RMSNorm on q/k over the full c_a width, applied after
        # projection but before the heads are split.
        self.kq_norm = kq_norm
        if kq_norm:
            self.ln_q = RMSNorm(c_a)
            self.ln_k = RMSNorm(c_a)

        if c_s is not None and c_s > 0:
            # Conditioned path: adaptive input norm plus a sigmoid output
            # gate whose bias starts at -2.0 (gate mostly closed at init).
            self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
            self.linear_output_project = nn.Sequential(
                LinearBiasInit(c_s, c_a, biasinit=-2.0),
                nn.Sigmoid(),
            )
        else:
            self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Apply pair-biased self-attention to ``a``.

        Args:
            a: Activations, shape (..., L, c_a).
            s: Conditioning features; required when c_s > 0, else unused.
            z: Pair features, shape (..., L, L, c_pair).
        """
        if self.c_s is not None and self.c_s > 0:
            a = self.ada_ln_1(a, s)
        else:
            a = self.ln_1(a)

        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        g = self.to_g(a)

        if self.kq_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)

        batch_dims = a.shape[:-2]
        L = a.shape[-2]

        # Split heads: (..., L, c_a) -> (..., n_head, L, c_head).
        q = q.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        k = k.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        v = v.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        g = g.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)

        # (..., L, L, n_head) -> (..., n_head, L, L) to align the bias
        # with the attention-logit layout.
        b = self.to_b(z).permute(*range(len(batch_dims)), -1, -3, -2)

        # Scaled dot-product plus pair bias; softmax over the key dim.
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.c_head)
        attn = attn + b
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        out = out * g
        # Merge heads back: (..., n_head, L, c_head) -> (..., L, c_a).
        out = out.transpose(-2, -3).contiguous()
        out = out.view(*batch_dims, L, self.c)
        out = self.to_o(out)

        if self.c_s is not None and self.c_s > 0:
            out = self.linear_output_project(s) * out

        return out
|
|
|
|
class PairformerBlock(nn.Module):
    """Pairformer block with attention and transitions.

    The pair track always passes through a residual transition; when
    ``c_s > 0`` the single track is additionally updated with
    pair-biased attention followed by its own residual transition.
    """

    def __init__(
        self,
        c_s: int,
        c_z: int,
        attention_pair_bias: dict,
        n_transition: int = 4,
        p_drop: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        self.z_transition = Transition(c=c_z, n=n_transition)

        if c_s > 0:
            self.s_transition = Transition(c=c_s, n=n_transition)
            self.attention_pair_bias = AttentionPairBias(
                c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
            )

    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        z = z + self.z_transition(z)

        # Guard clause: no single-track update when s is absent.
        if s is None:
            return s, z

        # Zero beta keeps the attention bias purely pair-derived.
        zero_beta = torch.tensor([0.0], device=z.device)
        s = s + self.attention_pair_bias(s, None, z, beta=zero_beta)
        s = s + self.s_transition(s)
        return s, z
|
|
|
|
class StructureLocalAtomTransformerBlock(nn.Module):
    """Single block for atom/token transformer.

    Residual pair-biased attention (with dropout) followed by a residual
    transition; the transition is conditioned on ``c`` only when a
    conditioning width ``c_s`` is configured.
    """

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        n_head: int = 4,
        dropout: float = 0.0,
        kq_norm: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, n_head=n_head, kq_norm=kq_norm
        )
        # Conditioned transition only when conditioning is available.
        self.transition_block = (
            ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
            if c_s is not None and c_s > 0
            else Transition(c=c_atom, n=4)
        )

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q + self.dropout(self.attention_pair_bias(q, c, p, **kwargs))
        if self.c_s is not None and self.c_s > 0:
            return q + self.transition_block(q, c)
        return q + self.transition_block(q)
|
|
|
|
class GatedCrossAttention(nn.Module):
    """Gated cross attention for upcast/downcast.

    Queries attend over a separate key/value sequence; the per-head
    output is gated by a sigmoid projection of the normalized query
    input. Supports both 3D inputs (B, L, C) and 4D grouped inputs
    (B, n_tok, L, C) as produced by the Upcast/Downcast grouping.
    """

    def __init__(
        self,
        c_query: int,
        c_kv: int,
        c_model: int = 128,
        n_head: int = 4,
        kq_norm: bool = True,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.n_head = n_head
        # 1/sqrt(head_dim) scaling for the attention logits.
        self.scale = 1 / math.sqrt(c_model // n_head)

        self.ln_q = RMSNorm(c_query)
        self.ln_kv = RMSNorm(c_kv)

        self.to_q = linearNoBias(c_query, c_model)
        self.to_k = linearNoBias(c_kv, c_model)
        self.to_v = linearNoBias(c_kv, c_model)
        self.to_g = nn.Sequential(linearNoBias(c_query, c_model), nn.Sigmoid())
        self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))

        # Optional RMSNorm applied to the projected q/k.
        self.kq_norm = kq_norm
        if kq_norm:
            self.k_norm = RMSNorm(c_model)
            self.q_norm = RMSNorm(c_model)

    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend ``q`` over ``kv``.

        Args:
            q: Queries, (B, L_q, c_query) or grouped (B, n_tok, L_q, c_query).
            kv: Keys/values with matching leading dims, length L_kv.
            attn_mask: Boolean mask; True marks attendable positions.

        Returns:
            Tensor with the same leading shape as ``q``.
        """
        q_in = self.ln_q(q)
        kv = self.ln_kv(kv)

        q_proj = self.to_q(q_in)
        k = self.to_k(kv)
        v = self.to_v(kv)
        # Gate computed from the normalized query input.
        g = self.to_g(q_in)

        if self.kq_norm:
            k = self.k_norm(k)
            q_proj = self.q_norm(q_proj)

        B = q.shape[0]
        # Grouped (4D) inputs carry an extra token dimension.
        n_tok = q.shape[1] if q.ndim == 4 else 1
        L_q = q.shape[-2]
        L_kv = kv.shape[-2]
        c_head = q_proj.shape[-1] // self.n_head

        # Split heads and move them in front of the sequence dims.
        if q.ndim == 4:
            q_proj = q_proj.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            k = k.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            v = v.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            g = g.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
        else:
            q_proj = q_proj.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
            k = k.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            v = v.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            g = g.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)

        attn = torch.matmul(q_proj, k.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # Pad the mask's rank up to the logits' rank, inserting a
            # head dim when the mask does not already broadcast over it.
            if q.ndim == 4:
                while attn_mask.ndim < attn.ndim:
                    attn_mask = attn_mask.unsqueeze(0)
                if attn_mask.shape[1] != self.n_head and attn_mask.shape[1] != 1:
                    attn_mask = attn_mask.unsqueeze(1)
            else:
                attn_mask = attn_mask.unsqueeze(-3)
            # NOTE(review): a query row whose keys are all masked yields
            # all -inf logits and NaNs after softmax — confirm callers
            # always leave at least one valid key per query.
            attn = attn.masked_fill(~attn_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        out = out * g

        # Merge heads back into the channel dimension.
        if q.ndim == 4:
            out = out.permute(0, 2, 3, 1, 4).contiguous()
            out = out.view(B, n_tok, L_q, -1)
        else:
            out = out.permute(0, 2, 1, 3).contiguous()
            out = out.view(B, L_q, -1)

        out = self.to_out(out)
        return out
|
|
|
|
class Upcast(nn.Module):
    """Upcast from token level to atom level.

    method="broadcast": project the token feature and add it to each of
    that token's atoms. method="cross_attention": each token's atoms
    cross-attend to ``n_split`` channel-chunks of their token's feature.
    """

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        method: str = "cross_attention",
        cross_attention_block: Optional[dict] = None,
        n_split: int = 6,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if method == "broadcast":
            self.project = nn.Sequential(RMSNorm(c_token), linearNoBias(c_token, c_atom))
        elif method == "cross_attention":
            # Keys/values are token-feature chunks of width c_token // n_split.
            self.gca = GatedCrossAttention(
                c_query=c_atom,
                c_kv=c_token // n_split,
                **(cross_attention_block or {}),
            )

    def forward(self, q: torch.Tensor, a: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        """Residually add token information ``a`` into atom features ``q``.

        Args:
            q: Atom features (B, L, c_atom).
            a: Token features (B, I, c_token).
            tok_idx: Per-atom token index, shape (L,).
        """
        if self.method == "broadcast":
            # Fancy-index each atom's token feature along the token dim.
            q = q + self.project(a)[..., tok_idx, :]
        elif self.method == "cross_attention":
            B, L, C = q.shape
            I = int(tok_idx.max().item()) + 1

            # Split each token vector into n_split key/value entries.
            a_split = a.view(B, I, self.n_split, -1)

            q_grouped = self._group_atoms(q, tok_idx, I)
            valid_mask = self._build_valid_mask(tok_idx, I, q.device)

            # Padded atom slots get fully-False query rows in the mask.
            attn_mask = torch.ones(I, q_grouped.shape[2], self.n_split, device=q.device, dtype=torch.bool)
            attn_mask[~valid_mask] = False

            q_update = self.gca(q_grouped, a_split, attn_mask=attn_mask)
            q = q + self._ungroup_atoms(q_update, valid_mask, L)

        return q

    def _group_atoms(self, q: torch.Tensor, tok_idx: torch.Tensor, I: int) -> torch.Tensor:
        """Pack atoms into (B, I, 14, C) per-token slots.

        Atoms beyond 14 per token are silently dropped.
        NOTE(review): host-side Python loop over L on every call.
        """
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)

        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                counts[t] += 1

        return grouped

    def _build_valid_mask(self, tok_idx: torch.Tensor, I: int, device: torch.device) -> torch.Tensor:
        """Boolean (I, 14) mask of which grouped slots hold a real atom.

        Mirrors the slot assignment performed in ``_group_atoms``.
        """
        max_atoms_per_token = 14
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=device)
        counts = torch.zeros(I, dtype=torch.long, device=device)

        for i in range(len(tok_idx)):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                valid_mask[t, counts[t]] = True
                counts[t] += 1

        return valid_mask

    def _ungroup_atoms(self, grouped: torch.Tensor, valid_mask: torch.Tensor, L: int) -> torch.Tensor:
        """Flatten grouped slots back to a (B, L, C) atom sequence.

        NOTE(review): slots are emitted in (token, slot) order, which
        matches the original atom order only when ``tok_idx`` is
        non-decreasing — confirm callers guarantee sorted tok_idx.
        """
        B, I, n_atoms, C = grouped.shape
        out = torch.zeros(B, L, C, device=grouped.device, dtype=grouped.dtype)

        idx = 0
        for t in range(I):
            for a in range(n_atoms):
                if valid_mask[t, a] and idx < L:
                    out[:, idx] = grouped[:, t, a]
                    idx += 1

        return out
|
|
|
|
class Downcast(nn.Module):
    """Downcast from atom level to token level.

    method="mean": project each atom and average over the atoms of each
    token. method="cross_attention": each token (one query) attends to
    its grouped atoms. Optionally mixes in a projection of the single
    representation ``s``.
    """

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        c_s: Optional[int] = None,
        method: str = "mean",
        cross_attention_block: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom

        if c_s is not None:
            self.process_s = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_token))
        else:
            self.process_s = None

        # The pooling module is named `gca` in both branches (presumably
        # so checkpoint parameter keys line up — confirm).
        if method == "mean":
            self.gca = linearNoBias(c_atom, c_token)
        elif method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                **(cross_attention_block or {}),
            )

    def forward(
        self,
        q: torch.Tensor,
        a: Optional[torch.Tensor] = None,
        s: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool atom features ``q`` into token features.

        Args:
            q: Atom features, (L, c_atom) or (B, L, c_atom).
            a: Existing token features to update residually (optional).
            s: Single representation mixed in via ``process_s`` (optional).
            tok_idx: Per-atom token index, shape (L,). Must be provided.

        Returns:
            Token features, unbatched iff ``q`` was unbatched.
        """
        # Normalize to a batched (B, L, C) layout; remember to squeeze back.
        if q.ndim == 2:
            q = q.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False

        B, L, _ = q.shape
        I = int(tok_idx.max().item()) + 1

        if self.method == "mean":
            projected = self.gca(q)
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            counts = torch.zeros(B, I, 1, device=q.device, dtype=q.dtype)
            # Sum projected atom features per token, then average.
            # NOTE(review): host-side Python loop over L on every call.
            for i in range(L):
                t = tok_idx[i]
                a_update[:, t] += projected[:, i]
                counts[:, t] += 1
            # Epsilon guards tokens that received no atoms.
            a_update = a_update / (counts + 1e-8)
        elif self.method == "cross_attention":
            if a is None:
                a = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            elif a.ndim == 2:
                a = a.unsqueeze(0)

            # One query per token over its (up to 14) grouped atoms.
            q_grouped, valid_mask = self._group_atoms(q, tok_idx, I)
            attn_mask = valid_mask.unsqueeze(-2)
            a_update = self.gca(a.unsqueeze(-2), q_grouped, attn_mask=attn_mask).squeeze(-2)
        else:
            # Unknown method: contribute nothing.
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)

        # Residual update when token features were provided.
        if a is not None:
            if a.ndim == 2:
                a = a.unsqueeze(0)
            a = a + a_update
        else:
            a = a_update

        if self.process_s is not None and s is not None:
            if s.ndim == 2:
                s = s.unsqueeze(0)
            a = a + self.process_s(s)

        if squeeze:
            a = a.squeeze(0)

        return a

    def _group_atoms(
        self, q: torch.Tensor, tok_idx: torch.Tensor, I: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pack atoms into (B, I, 14, C) slots plus an (I, 14) validity mask.

        Atoms beyond 14 per token are silently dropped.
        """
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=q.device)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)

        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                valid_mask[t, counts[t]] = True
                counts[t] += 1

        return grouped, valid_mask
|
|
|
|
class LinearEmbedWithPool(nn.Module):
    """Linear embedding with pooling to token level.

    Projects per-atom 3D inputs to token width, then averages the
    embeddings of the atoms belonging to each token.
    """

    def __init__(self, c_token: int):
        super().__init__()
        self.c_token = c_token
        self.linear = linearNoBias(3, c_token)

    def forward(self, r: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        """Mean-pool projected atom features per token.

        Args:
            r: Per-atom inputs, shape (B, L, 3).
            tok_idx: Per-atom token index, shape (L,).

        Returns:
            Token embeddings, shape (B, I, c_token).
        """
        batch = r.shape[0]
        n_tokens = int(tok_idx.max().item()) + 1
        emb = self.linear(r)

        pooled = torch.zeros(batch, n_tokens, self.c_token, device=r.device, dtype=emb.dtype)
        counts = torch.zeros(batch, n_tokens, 1, device=r.device, dtype=emb.dtype)

        for atom in range(r.shape[1]):
            token = tok_idx[atom]
            pooled[:, token] += emb[:, atom]
            counts[:, token] += 1

        # Epsilon guards tokens that (unexpectedly) received no atoms.
        return pooled / (counts + 1e-8)
|
|
|
|
class LinearSequenceHead(nn.Module):
    """Sequence prediction head.

    Projects token features to 32 residue-type logits. An all-True
    validity mask is kept as a buffer and applied to the softmax
    probabilities — a no-op unless the buffer is overwritten (e.g. when
    loading a checkpoint).
    """

    def __init__(self, c_token: int):
        super().__init__()
        n_tok_all = 32
        self.register_buffer("valid_out_mask", torch.ones(n_tok_all, dtype=torch.bool))
        self.linear = nn.Linear(c_token, n_tok_all)

    def forward(self, a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.linear(a)
        # Mask invalid classes, then renormalize the probabilities.
        probs = F.softmax(logits, dim=-1)
        probs = probs * self.valid_out_mask[None, None, :].to(probs.device)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
        # Hard assignment from the (masked) distribution.
        indices = probs.argmax(dim=-1)
        return logits, indices
|
|
|
|
class LocalAtomTransformer(nn.Module):
    """Atom-level transformer encoder: a stack of local attention blocks."""

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        atom_transformer_block: dict,
        n_blocks: int,
    ):
        super().__init__()
        layers = [
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_s,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Sequentially refine the atom features through every block.
        for layer in self.blocks:
            q = layer(q, c, p, **kwargs)
        return q
|
|
|
|
class LocalTokenTransformer(nn.Module):
    """Token-level transformer for diffusion.

    Reuses StructureLocalAtomTransformerBlock with token-sized channels.
    """

    def __init__(
        self,
        c_token: int,
        c_tokenpair: int,
        c_s: int,
        diffusion_transformer_block: dict,
        n_block: int,
        **kwargs,
    ):
        super().__init__()
        layers = [
            StructureLocalAtomTransformerBlock(
                c_atom=c_token,
                c_s=c_s,
                c_atompair=c_tokenpair,
                **diffusion_transformer_block,
            )
            for _ in range(n_block)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Sequentially refine the token features through every block.
        for layer in self.blocks:
            a = layer(a, s, z, **kwargs)
        return a
|
|
|
|
class CompactStreamingDecoder(nn.Module):
    """Decoder with upcast, atom transformer, and downcast.

    Repeats ``n_blocks`` times: inject token features into atoms
    (Upcast), then run one atom-level transformer block; finally pools
    atoms back to token level once with a single Downcast.
    """

    def __init__(
        self,
        c_atom: int,
        c_atompair: int,
        c_token: int,
        c_s: int,
        c_tokenpair: int,
        atom_transformer_block: dict,
        upcast: dict,
        downcast: dict,
        n_blocks: int,
        **kwargs,
    ):
        super().__init__()
        self.n_blocks = n_blocks

        # One upcast per block; weights are not shared across blocks.
        self.upcast = nn.ModuleList([
            Upcast(c_atom=c_atom, c_token=c_token, **upcast)
            for _ in range(n_blocks)
        ])
        # Atom blocks are conditioned on `c` (width c_atom), hence
        # c_s=c_atom here rather than the token-level c_s.
        self.atom_transformer = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_atom,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ])
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        q: torch.Tensor,
        c: torch.Tensor,
        p: torch.Tensor,
        tok_idx: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """Run the decoder.

        Args:
            a: Token features. s: Single representation. z: Accepted but
                unused in this forward. q: Atom features. c: Atom-level
                conditioning. p: Atom pair features. tok_idx: Per-atom
                token index.

        Returns:
            (updated token features, updated atom features, empty aux dict).
        """
        for i in range(self.n_blocks):
            q = self.upcast[i](q, a, tok_idx=tok_idx)
            q = self.atom_transformer[i](q, c, p, **kwargs)

        # Inputs are detached: the final pooling contributes no gradient
        # back into q / a / s through this path.
        a = self.downcast(q.detach(), a.detach(), s.detach(), tok_idx=tok_idx)

        return a, q, {}
|
|
|
|
class DiffusionTokenEncoder(nn.Module):
    """Token encoder with pairformer stack for diffusion.

    Refines the initial single/pair representations: two residual
    transitions on each track, concatenation of optional distogram
    channels into the pair track, then a stack of Pairformer blocks.
    """

    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_token: int,
        c_atompair: int,
        n_pairformer_blocks: int,
        pairformer_block: dict,
        use_distogram: bool = True,
        use_self: bool = True,
        n_bins_distogram: int = 65,
        **kwargs,
    ):
        super().__init__()

        self.use_distogram = use_distogram
        self.use_self = use_self
        self.n_bins_distogram = n_bins_distogram

        self.transition_1 = nn.ModuleList([
            Transition(c=c_s, n=2),
            Transition(c=c_s, n=2),
        ])

        # Pair-track input width grows by one distogram block per
        # enabled feature (noise distogram and/or self distogram).
        n_bins_noise = n_bins_distogram
        cat_c_z = (
            c_z
            + int(use_distogram) * n_bins_noise
            + int(use_self) * n_bins_distogram
        )

        self.process_z = nn.Sequential(
            RMSNorm(cat_c_z),
            linearNoBias(cat_c_z, c_z),
        )

        self.transition_2 = nn.ModuleList([
            Transition(c=c_z, n=2),
            Transition(c=c_z, n=2),
        ])

        self.pairformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])

    def forward(
        self,
        s_init: torch.Tensor,
        z_init: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode initial single/pair representations into refined ones.

        ``d_self`` may be supplied through kwargs when use_self is on;
        missing distogram inputs are replaced with zeros.
        """
        B = z_init.shape[0] if z_init.ndim == 4 else 1

        s = s_init
        for b in range(2):
            s = s + self.transition_1[b](s)

        z = z_init
        # Add a batch dim to an unbatched pair tensor (B is 1 in this
        # branch, so the expand is effectively an unsqueeze).
        if z.ndim == 3:
            z = z.unsqueeze(0).expand(B, -1, -1, -1)

        z_list = [z]
        if self.use_distogram:
            # NOTE(review): the noise distogram is a zeros placeholder
            # here — no caller-provided channel is consumed for it.
            d_noise = torch.zeros(
                z.shape[:-1] + (self.n_bins_distogram,),
                device=z.device, dtype=z.dtype
            )
            z_list.append(d_noise)
        if self.use_self:
            d_self = kwargs.get("d_self")
            if d_self is None:
                d_self = torch.zeros(
                    z.shape[:-1] + (self.n_bins_distogram,),
                    device=z.device, dtype=z.dtype
                )
            z_list.append(d_self)

        # Concatenate the extra channels and project back down to c_z.
        z = torch.cat(z_list, dim=-1)
        z = self.process_z(z)

        for b in range(2):
            z = z + self.transition_2[b](z)

        for block in self.pairformer_stack:
            s, z = block(s, z)

        return s, z
|
|
|
|
class EmbeddingLayer(nn.Module):
    """Embedding layer for 1D features - simple linear projection.

    The weight is zero-initialized, so a freshly constructed layer
    outputs zeros until weights are trained or loaded. No bias term.
    """

    def __init__(self, n_channels: int, output_channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Bias-free projection; equivalent to x @ self.weight.T.
        return F.linear(x, self.weight)
|
|
|
|
class OneDFeatureEmbedder(nn.Module):
    """Embeds a dict of 1D features and sums the per-feature embeddings.

    Features mapped to None in ``features`` are dropped at construction;
    features missing from the input dict are skipped at forward time.
    """

    def __init__(self, features: dict, output_channels: int):
        super().__init__()
        self.features = {k: v for k, v in features.items() if v is not None}
        self.embedders = nn.ModuleDict({
            name: EmbeddingLayer(width, output_channels)
            for name, width in self.features.items()
        })

    def forward(self, f: dict) -> torch.Tensor:
        total = None
        for name in self.features:
            value = f.get(name)
            if value is None:
                continue
            emb = self.embedders[name](value.float())
            total = emb if total is None else total + emb
        # Fallback when no configured feature was present in the input.
        # NOTE(review): this fallback is created on the default
        # device/dtype, not the module's — confirm callers never hit it.
        return total if total is not None else torch.zeros(1)
|
|
|
|
class PositionPairDistEmbedder(nn.Module):
    """Embeds pairwise position distances.

    Projects a bounded inverse-squared-distance feature plus a
    validity-mask feature into pair channels.

    NOTE(review): ``process_d`` is created when ``embed_frame=True`` but
    never used in ``forward`` — it appears to exist only so the
    parameter layout matches a checkpoint; confirm.
    """

    def __init__(self, c_atompair: int, embed_frame: bool = True):
        super().__init__()
        self.embed_frame = embed_frame
        if embed_frame:
            # Unused in forward; see the class-level note.
            self.process_d = linearNoBias(3, c_atompair)
        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, ref_pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        # Pairwise displacement vectors, shape (..., L, L, 3).
        D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        # Squared pairwise distance (norm ** 2), clamped away from zero.
        norm = torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2
        norm = torch.clamp(norm, min=1e-6)
        # Bounded feature 1 / (1 + d^2); despite the name this is an
        # inverse *squared* distance.
        inv_dist = 1 / (1 + norm)
        P_LL = self.process_inverse_dist(inv_dist) * valid_mask
        P_LL = P_LL + self.process_valid_mask(valid_mask.float()) * valid_mask
        return P_LL
|
|
|
|
class SinusoidalDistEmbed(nn.Module):
    """Sinusoidal embedding for pairwise distances.

    Encodes the Euclidean distance matrix with a geometric frequency
    ladder (sin/cos pairs) and projects to pair channels, masked by the
    validity mask.
    """

    def __init__(self, c_atompair: int, n_freqs: int = 32):
        super().__init__()
        self.n_freqs = n_freqs
        self.c_atompair = c_atompair
        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        # Pairwise displacement and Euclidean distance matrix.
        diffs = pos.unsqueeze(-2) - pos.unsqueeze(-3)
        dists = torch.linalg.norm(diffs, dim=-1)

        # Geometric frequency ladder, as in standard sinusoidal encodings.
        freq = torch.exp(
            -math.log(10000.0) * torch.arange(0, self.n_freqs, dtype=torch.float32) / self.n_freqs
        ).to(dists.device)

        angles = dists.unsqueeze(-1) * freq
        sincos = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        out = self.output_proj(sincos) * valid_mask
        out = out + self.process_valid_mask(valid_mask.float()) * valid_mask
        return out
|
|
|
|
class RelativePositionEncoding(nn.Module):
    """Relative position encoding.

    NOTE(review): ``forward`` is currently a stub — it reads only the
    token count and device from the features and returns zeros of the
    projected width. The linear layer exists so parameter shapes match
    the checkpoint; the actual bin featurization is not implemented in
    this file.
    """

    def __init__(self, r_max: int, s_max: int, c_z: int):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # Input width presumably covers token-offset bins (2*r_max+3,
        # used twice), chain bins (2*s_max+2), and one extra channel —
        # confirm against the featurizer that produces these inputs.
        num_tok_pos_bins = 2 * r_max + 3
        self.linear = linearNoBias(2 * num_tok_pos_bins + (2 * s_max + 2) + 1, c_z)

    def forward(self, f: dict) -> torch.Tensor:
        # Only shape and device are derived from the input features.
        I = f.get("residue_index", torch.zeros(1)).shape[-1]
        device = f.get("residue_index", torch.zeros(1)).device
        return torch.zeros(I, I, self.linear.out_features, device=device)
|
|
|
|
class TokenInitializer(nn.Module):
    """Token embedding module for RFD3 matching foundry checkpoint structure.

    Declares the embedders and projections that build the initial single
    (s) and pair (z) representations from atom- and token-level features.

    NOTE(review): ``forward`` is currently a stub returning zero tensors
    of the expected shapes; the submodules defined in __init__ appear to
    exist mainly so checkpoint weights can be loaded — confirm before
    relying on this module's outputs.
    """

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        r_max: int = 32,
        s_max: int = 2,
        n_pairformer_blocks: int = 2,
        atom_1d_features: Optional[dict] = None,
        token_1d_features: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()

        # Default atom-level feature widths (feature name -> channels).
        if atom_1d_features is None:
            atom_1d_features = {
                "ref_atom_name_chars": 256,
                "ref_element": 128,
                "ref_charge": 1,
                "ref_mask": 1,
                "ref_is_motif_atom_with_fixed_coord": 1,
                "ref_is_motif_atom_unindexed": 1,
                "has_zero_occupancy": 1,
                "ref_pos": 3,
                "ref_atomwise_rasa": 3,
                "active_donor": 1,
                "active_acceptor": 1,
                "is_atom_level_hotspot": 1,
            }

        # Default token-level feature widths.
        if token_1d_features is None:
            token_1d_features = {
                "ref_motif_token_type": 3,
                "restype": 32,
                "ref_plddt": 1,
                "is_non_loopy": 1,
            }

        cross_attention_block = {"n_head": 4, "c_model": c_atom, "dropout": 0.0, "kq_norm": True}

        # Feature embedders (two atom-level widths: c_s and c_atom).
        self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
        self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
        self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)

        # Pool atom-level embeddings to token level via cross-attention.
        self.downcast_atom = Downcast(
            c_atom=c_s, c_token=c_s, c_s=None,
            method="cross_attention", cross_attention_block=cross_attention_block
        )
        self.transition_post_token = Transition(c=c_s, n=2)
        self.transition_post_atom = Transition(c=c_s, n=2)
        self.process_s_init = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))

        # Pair-track initial projections from the single track.
        self.to_z_init_i = linearNoBias(c_s, c_z)
        self.to_z_init_j = linearNoBias(c_s, c_z)
        self.relative_position_encoding = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.relative_position_encoding2 = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.process_token_bonds = linearNoBias(1, c_z)

        self.process_z_init = nn.Sequential(RMSNorm(c_z * 2), linearNoBias(c_z * 2, c_z))
        self.transition_1 = nn.ModuleList([Transition(c=c_z, n=2), Transition(c=c_z, n=2)])
        self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)

        # Small pairformer stack for refining the initial tracks.
        pairformer_block = {"attention_pair_bias": {"n_head": 16, "kq_norm": True}, "n_transition": 4}
        self.transformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])

        # Atom-pair conditioning projections.
        self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
        self.process_single_l = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_single_m = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))

        self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
        self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
        self.pair_mlp = nn.Sequential(
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
        )
        self.process_pll = linearNoBias(c_atompair, c_atompair)
        self.project_pll = linearNoBias(c_atompair, c_z)

    def forward(self, f: dict) -> dict:
        """Compute initial representations from input features.

        NOTE(review): stub implementation — returns zero tensors shaped
        (I, c_s) and (I, I, c_z); none of the embedders declared above
        are invoked.
        """
        I = f.get("num_tokens", 100)
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        s_init = torch.zeros(I, self.process_s_init[1].out_features, device=device, dtype=dtype)
        z_init = torch.zeros(I, I, self.process_z_init[1].out_features, device=device, dtype=dtype)

        return {"S_I": s_init, "Z_II": z_init}
|
|
|
|
class RFD3DiffusionModule(nn.Module):
    """
    RFD3 Diffusion Module matching foundry checkpoint structure.

    This module structure matches `model.diffusion_module.*` keys in the
    checkpoint, so submodule attribute names — and their creation order,
    which fixes the parameter-init RNG stream — must be preserved.
    """

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 2,
        n_diffusion_blocks: int = 18,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()

        # EDM data sigma and default recycle count used by the wrapping model.
        self.sigma_data = sigma_data
        self.n_recycle = n_recycle

        # Coordinate in/out projections and the per-token sequence head.
        self.process_r = linearNoBias(3, c_atom)
        self.to_r_update = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, 3))
        self.sequence_head = LinearSequenceHead(c_token)

        # Two independent timestep embeddings: index 0 conditions the atom
        # stream (c_atom), index 1 the token single stream (c_s); see
        # process_time for how they are selected.
        self.fourier_embedding = nn.ModuleList([
            FourierEmbedding(c_t_embed),
            FourierEmbedding(c_t_embed),
        ])
        self.process_n = nn.ModuleList([
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
        ])

        cross_attention_block = {
            "n_head": 4,
            "c_model": c_atom,
            "dropout": p_drop,
            "kq_norm": True,
        }

        # Atom -> token pooling: downcast_c produces conditioning features,
        # downcast_q produces the diffusion-transformer input stream.
        self.downcast_c = Downcast(
            c_atom=c_atom, c_token=c_s, c_s=None,
            method="cross_attention", cross_attention_block=cross_attention_block
        )
        self.downcast_q = Downcast(
            c_atom=c_atom, c_token=c_token, c_s=c_s,
            method="cross_attention", cross_attention_block=cross_attention_block
        )
        self.process_a = LinearEmbedWithPool(c_token)
        self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))

        atom_transformer_block = {
            "n_head": 4,
            "dropout": p_drop,
            "kq_norm": True,
        }

        # Atom-level encoder operating before pooling to tokens.
        self.encoder = LocalAtomTransformer(
            c_atom=c_atom,
            c_s=c_atom,
            c_atompair=c_atompair,
            atom_transformer_block=atom_transformer_block,
            n_blocks=n_atom_encoder_blocks,
        )

        pairformer_block = {
            "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": False},
            "n_transition": 4,
        }

        self.diffusion_token_encoder = DiffusionTokenEncoder(
            c_s=c_s,
            c_z=c_z,
            c_token=c_token,
            c_atompair=c_atompair,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block=pairformer_block,
        )

        diffusion_transformer_block = {
            "n_head": n_head,
            "dropout": p_drop,
            "kq_norm": True,
        }

        # Main token-level diffusion transformer.
        self.diffusion_transformer = LocalTokenTransformer(
            c_token=c_token,
            c_tokenpair=c_z,
            c_s=c_s,
            diffusion_transformer_block=diffusion_transformer_block,
            n_block=n_diffusion_blocks,
        )

        decoder_upcast = {
            "method": "cross_attention",
            "n_split": 3,
            "cross_attention_block": cross_attention_block,
        }
        decoder_downcast = {
            "method": "cross_attention",
            "cross_attention_block": cross_attention_block,
        }

        # Token -> atom decoder producing the final atom-stream features.
        self.decoder = CompactStreamingDecoder(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_tokenpair=c_z,
            atom_transformer_block=atom_transformer_block,
            upcast=decoder_upcast,
            downcast=decoder_downcast,
            n_blocks=n_atom_decoder_blocks,
        )

    @staticmethod
    def _broadcast_t(t: torch.Tensor) -> torch.Tensor:
        """Right-pad `t` with singleton dims so it broadcasts over [..., L, 3].

        Factored out of scale_positions_in/scale_positions_out, which
        previously duplicated this reshaping inline. Tensors with ndim > 2
        pass through unchanged, matching the original behavior.
        """
        if t.ndim == 1:
            return t[..., None, None]
        if t.ndim == 2:
            return t[..., None]
        return t

    def scale_positions_in(self, x_noisy: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """EDM input preconditioning: c_in(t) * x = x / sqrt(t^2 + sigma_data^2)."""
        t = self._broadcast_t(t)
        return x_noisy / torch.sqrt(t**2 + self.sigma_data**2)

    def scale_positions_out(
        self, r_update: torch.Tensor, x_noisy: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """EDM output preconditioning: c_skip(t) * x_noisy + c_out(t) * r_update."""
        t = self._broadcast_t(t)
        sigma2 = self.sigma_data**2
        # c_skip = sigma_d^2 / (sigma_d^2 + t^2)
        # c_out  = sigma_d * t / sqrt(sigma_d^2 + t^2)
        return (sigma2 / (sigma2 + t**2)) * x_noisy + (
            self.sigma_data * t / torch.sqrt(sigma2 + t**2)
        ) * r_update

    def process_time(self, t: torch.Tensor, idx: int) -> torch.Tensor:
        """Embed noise level `t` for stream `idx` (0 = atom/c_atom, 1 = token/c_s).

        The clamp keeps log() finite at t == 0; the trailing mask then zeroes
        the embedding wherever t == 0 (e.g. atoms held fixed by a motif mask).
        """
        t_clamped = torch.clamp(t, min=1e-20)
        t_log = 0.25 * torch.log(t_clamped / self.sigma_data)
        emb = self.process_n[idx](self.fourier_embedding[idx](t_log))
        emb = emb * (t > 0).float()[..., None]
        return emb

    def compute_pair_features(self, xyz: torch.Tensor, c_atompair: int) -> torch.Tensor:
        """Inverse-square-distance pair features broadcast across channels.

        Returns an expanded (non-contiguous) view of shape [..., L, L,
        c_atompair]; every channel aliases the same memory, so callers must
        not write into the result in place.
        """
        dist = torch.cdist(xyz, xyz)
        inv_dist = 1 / (1 + dist**2)
        return inv_dist.unsqueeze(-1).expand(-1, -1, -1, c_atompair)
|
|
|
|
class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
    """
    RFDiffusion3 transformer for protein structure prediction.

    This wraps the diffusion module to provide the full model interface.
    The state dict keys match the foundry checkpoint format, so the attribute
    names `token_initializer` and `diffusion_module` must not change.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_block: int = 2,
        n_diffusion_block: int = 18,
        n_atom_encoder_block: int = 3,
        n_atom_decoder_block: int = 3,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()

        # Supplies zero-initialized single/pair token representations when the
        # caller does not provide s_init / z_init in forward().
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
        )

        self.diffusion_module = RFD3DiffusionModule(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_t_embed=c_t_embed,
            sigma_data=sigma_data,
            n_pairformer_blocks=n_pairformer_block,
            n_diffusion_blocks=n_diffusion_block,
            n_atom_encoder_blocks=n_atom_encoder_block,
            n_atom_decoder_blocks=n_atom_decoder_block,
            n_head=n_head,
            n_pairformer_head=n_pairformer_head,
            n_recycle=n_recycle,
            p_drop=p_drop,
        )

    @property
    def sigma_data(self) -> float:
        """Data sigma of the EDM preconditioning, owned by the diffusion module."""
        return self.diffusion_module.sigma_data

    def forward(
        self,
        xyz_noisy: torch.Tensor,
        t: torch.Tensor,
        atom_to_token_map: Optional[torch.Tensor] = None,
        motif_mask: Optional[torch.Tensor] = None,
        s_init: Optional[torch.Tensor] = None,
        z_init: Optional[torch.Tensor] = None,
        n_recycle: Optional[int] = None,
        **kwargs,
    ) -> RFDiffusionTransformerOutput:
        """
        Forward pass of the diffusion module.

        Args:
            xyz_noisy: Noisy atom coordinates [B, L, 3]
            t: Noise level / timestep [B]
            atom_to_token_map: Mapping from atoms to tokens [L]
            motif_mask: Mask for fixed motif atoms [L]
            s_init: Initial single representation [I, c_s]
            z_init: Initial pair representation [I, I, c_z]
            n_recycle: Number of recycling iterations

        Returns:
            RFDiffusionTransformerOutput with denoised coordinates
        """
        dm = self.diffusion_module
        B, L, _ = xyz_noisy.shape

        # Default: one token per atom.
        if atom_to_token_map is None:
            atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
        I = int(atom_to_token_map.max().item()) + 1

        if motif_mask is None:
            motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)

        # Per-atom noise level: motif atoms are treated as clean (t = 0), which
        # also zeroes their time embedding in process_time.
        t_L = t[:, None].expand(B, L) * (~motif_mask).float()
        t_I = t[:, None].expand(B, I)

        # NOTE(review): r_scaled uses the unmasked t (feeds pair features)
        # while r_noisy uses the motif-masked t_L (feeds the atom stream) —
        # asymmetry looks intentional; confirm against the reference impl.
        r_scaled = dm.scale_positions_in(xyz_noisy, t)
        r_noisy = dm.scale_positions_in(xyz_noisy, t_L)

        if s_init is None or z_init is None:
            init_output = self.token_initializer({"num_tokens": I})
            if s_init is None:
                s_init = init_output["S_I"]
            if z_init is None:
                z_init = init_output["Z_II"]

        assert s_init is not None and z_init is not None

        p = dm.compute_pair_features(r_scaled, self.config.c_atompair)

        a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
        s_init_expanded = s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init
        # Fix: the zero query tensor must match the input dtype rather than
        # defaulting to float32, or mixed-precision (fp16/bf16) runs hit a
        # dtype mismatch inside the cross-attention downcast.
        s_I = dm.downcast_c(
            torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device, dtype=xyz_noisy.dtype),
            s_init_expanded,
            tok_idx=atom_to_token_map,
        )

        # Inject time conditioning into both the atom (idx=0) and token (idx=1)
        # streams.
        q = dm.process_r(r_noisy)
        c = dm.process_time(t_L, idx=0)
        q = q + c
        s_I = s_I + dm.process_time(t_I, idx=1)
        c = c + dm.process_c(c)

        q = dm.encoder(q, c, p)
        a_I = dm.downcast_q(q, a_I, s_I, tok_idx=atom_to_token_map)

        # Training uses a single pass; inference defaults to the configured
        # recycle count. Always run at least once.
        if n_recycle is None:
            n_recycle = dm.n_recycle if not self.training else 1
        n_recycle = max(1, n_recycle)

        z_II = z_init
        for _ in range(n_recycle):
            s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_II)
        a_I = dm.diffusion_transformer(a_I, s_I, z_II)

        a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)

        # Project the atom stream back to a 3D update and apply EDM output
        # preconditioning against the original noisy coordinates.
        r_update = dm.to_r_update(q)
        xyz_out = dm.scale_positions_out(r_update, xyz_noisy, t_L)

        sequence_logits, sequence_indices = dm.sequence_head(a_I)

        return RFDiffusionTransformerOutput(
            xyz=xyz_out,
            single=s_I,
            pair=z_II,
            sequence_logits=sequence_logits,
            sequence_indices=sequence_indices,
        )
|
|