# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RFDiffusion3 Transformer model.

This module provides a diffusers-compatible implementation of the RFD3
architecture for protein structure prediction and generation.

The module structure matches the foundry checkpoint format for direct
weight loading.
"""

import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


@dataclass
class RFDiffusionTransformerOutput:
    """Output class for RFDiffusion transformer."""

    # xyz: denoised atom coordinates; single/pair: final representations.
    xyz: torch.Tensor
    single: torch.Tensor
    pair: torch.Tensor
    sequence_logits: Optional[torch.Tensor] = None
    sequence_indices: Optional[torch.Tensor] = None


# Convenience alias: nn.Linear without a bias term.
linearNoBias = partial(nn.Linear, bias=False)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the RMS over the last dim, then apply learned scale.
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class FourierEmbedding(nn.Module):
    """Fourier feature embedding for timesteps with learned weights."""

    def __init__(self, c: int):
        super().__init__()
        self.c = c
        # Frequencies `w` and phases `b` are buffers (not trained), but are
        # drawn from a normal distribution in reset_parameters().
        self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
        self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Random Fourier features: cos(2*pi*(t*w + b)), one feature per channel.
        return torch.cos(2 * math.pi * (t[..., None] * self.w + self.b))


class LinearBiasInit(nn.Linear):
    """Linear layer with custom bias initialization."""

    def __init__(self, *args, biasinit: float = -2.0, **kwargs):
        # Must be set before super().__init__, which calls reset_parameters().
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)

    def reset_parameters(self) -> None:
        super().reset_parameters()
        if self.bias is not None:
            self.bias.data.fill_(self.biasinit)


class RMSNormNoWeight(nn.Module):
    """RMSNorm without learnable weight (elementwise_affine=False)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return x / rms


class AdaLN(nn.Module):
    """Adaptive Layer Normalization.

    Normalizes `a` without an affine term, then applies a gain and bias
    that are predicted from the conditioning signal `s`.
    """

    def __init__(self, c_a: int, c_s: int):
        super().__init__()
        self.ln_a = RMSNormNoWeight(c_a)
        self.ln_s = RMSNorm(c_s)
        self.to_gain = nn.Sequential(nn.Linear(c_s, c_a), nn.Sigmoid())
        self.to_bias = linearNoBias(c_s, c_a)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a = self.ln_a(a)
        s = self.ln_s(s)
        return self.to_gain(s) * a + self.to_bias(s)


class Transition(nn.Module):
    """SwiGLU-style transition block matching foundry naming."""

    def __init__(self, c: int, n: int = 4):
        super().__init__()
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = linearNoBias(c, n * c)
        self.linear_2 = linearNoBias(c, n * c)
        self.linear_3 = linearNoBias(n * c, c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm_1(x)
        # SwiGLU: silu(W1 x) * (W2 x), projected back down by W3.
        return self.linear_3(F.silu(self.linear_1(x)) * self.linear_2(x))


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition with adaptive layer norm conditioning."""

    def __init__(self, c_token: int, c_s: int, n: int = 2):
        super().__init__()
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = linearNoBias(c_token, c_token * n)
        self.linear_2 = linearNoBias(c_token, c_token * n)
        # Output gate; bias init of -2 makes the sigmoid gate start mostly closed.
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = linearNoBias(c_token * n, c_token)

    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a = self.ada_ln(a, s)
        b = F.silu(self.linear_1(a)) * self.linear_2(a)
        # Gate predicted from the conditioning signal `s`.
        return self.linear_output_project(s) * self.linear_3(b)


class MultiDimLinear(nn.Linear):
    """Linear layer that reshapes output to multi-dimensional shape."""

    def __init__(self, in_features: int, out_shape: Tuple[int, ...], norm: bool = False, **kwargs):
        self.out_shape = out_shape
        # Flattened output size = product of out_shape dims.
        out_features = 1
        for d in out_shape:
            out_features *= d
        super().__init__(in_features, out_features, **kwargs)
        if norm:
            self.ln = RMSNorm(out_features)
            self.use_ln = True
        else:
            self.use_ln = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if self.use_ln:
            # Norm is applied over the flattened output before reshaping.
            out = self.ln(out)
        return out.reshape(x.shape[:-1] + self.out_shape)


class AttentionPairBias(nn.Module):
    """Attention with pairwise bias for Pairformer."""

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 8,
        kq_norm: bool = False,
    ):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head  # per-head channel dim
        self.to_q = MultiDimLinear(c_a, (n_head, self.c))
        self.to_k = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        self.to_v = MultiDimLinear(c_a, (n_head, self.c), bias=False, norm=True)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, (n_head, self.c), bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = RMSNorm(c_pair)
        self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        beta: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): `s` is accepted but unused in this forward — presumably
        # kept for interface parity with conditioned attention; confirm.
        a = self.ln_1(a)
        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        # Pairwise bias per head, derived from the pair representation z.
        b = self.to_b(self.ln_0(z))
        if beta is not None:
            b = b + beta[..., None]
        g = self.to_g(a)
        q = q / math.sqrt(self.c)
        # attn has dims (..., i, j, h); softmax over the key axis j (dim=-2).
        attn = torch.einsum("...ihd,...jhd->...ijh", q, k) + b
        attn = F.softmax(attn, dim=-2)
        out = torch.einsum("...ijh,...jhc->...ihc", attn, v)
        out = g * out
        out = out.flatten(start_dim=-2)  # merge (head, channel) back to c_a
        out = self.to_a(out)
        return out


class LocalAttentionPairBias(nn.Module):
    """Local attention with pairwise bias for diffusion transformer blocks."""

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 16,
        kq_norm: bool = True,
    ):
        super().__init__()
        self.n_head = n_head
        self.c = c_a
        self.c_head = c_a // n_head
        self.c_s = c_s
        self.use_checkpointing = False
        self.to_q = linearNoBias(c_a, c_a)
        self.to_k = linearNoBias(c_a, c_a)
        self.to_v = linearNoBias(c_a, c_a)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(linearNoBias(c_a, c_a), nn.Sigmoid())
        self.to_o = linearNoBias(c_a, c_a)
        self.kq_norm = kq_norm
        if kq_norm:
            self.ln_q = RMSNorm(c_a)
            self.ln_k = RMSNorm(c_a)
        # With a conditioning channel, use AdaLN + gated output projection;
        # otherwise plain RMSNorm on the input.
        if c_s is not None and c_s > 0:
            self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
            self.linear_output_project = nn.Sequential(
                LinearBiasInit(c_s, c_a, biasinit=-2.0),
                nn.Sigmoid(),
            )
        else:
            self.ln_1 = RMSNorm(c_a)

    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        if self.c_s is not None and self.c_s > 0:
            a = self.ada_ln_1(a, s)
        else:
            a = self.ln_1(a)
        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        g = self.to_g(a)
        if self.kq_norm:
            # QK norm applied on the full c_a dim, before the head split.
            q = self.ln_q(q)
            k = self.ln_k(k)
        batch_dims = a.shape[:-2]
        L = a.shape[-2]
        # Split heads and move the head axis in front of the sequence axis.
        q = q.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        k = k.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        v = v.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        g = g.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        # Pair bias: (..., L, L, n_head) -> (..., n_head, L, L).
        b = self.to_b(z).permute(*range(len(batch_dims)), -1, -3, -2)
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.c_head)
        attn = attn + b
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out * g
        out = out.transpose(-2, -3).contiguous()
        out = out.view(*batch_dims, L, self.c)
        out = self.to_o(out)
        if self.c_s is not None and self.c_s > 0:
            # Conditioned output gate.
            out = self.linear_output_project(s) * out
        return out


class PairformerBlock(nn.Module):
    """Pairformer block with attention and transitions."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        attention_pair_bias: dict,
        n_transition: int = 4,
        p_drop: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        self.z_transition = Transition(c=c_z, n=n_transition)
        # NOTE(review): s_transition only exists when c_s > 0, but forward
        # uses it whenever `s is not None` — a c_s == 0 block called with a
        # non-None `s` would raise AttributeError; confirm intended usage.
        if c_s > 0:
            self.s_transition = Transition(c=c_s, n=n_transition)
        self.attention_pair_bias = AttentionPairBias(
            c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
        )

    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        z = z + self.z_transition(z)
        if s is not None:
            # beta=0 leaves the pair bias unchanged (no extra additive bias).
            beta = torch.tensor([0.0], device=z.device)
            s = s + self.attention_pair_bias(s, None, z, beta=beta)
            s = s + self.s_transition(s)
        return s, z


class StructureLocalAtomTransformerBlock(nn.Module):
    """Single block for atom/token transformer."""

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        n_head: int = 4,
        dropout: float = 0.0,
        kq_norm: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, n_head=n_head, kq_norm=kq_norm
        )
        # Conditioned transition when a conditioning channel exists.
        if c_s is not None and c_s > 0:
            self.transition_block = ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
        else:
            self.transition_block = Transition(c=c_atom, n=4)

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        # Residual attention (pair-biased by p) followed by residual transition.
        q = q + self.dropout(self.attention_pair_bias(q, c, p, **kwargs))
        if self.c_s is not None and self.c_s > 0:
            q = q + self.transition_block(q, c)
        else:
            q = q + self.transition_block(q)
        return q


class GatedCrossAttention(nn.Module):
    """Gated cross attention for upcast/downcast."""
    def __init__(
        self,
        c_query: int,
        c_kv: int,
        c_model: int = 128,
        n_head: int = 4,
        kq_norm: bool = True,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.n_head = n_head
        self.scale = 1 / math.sqrt(c_model // n_head)
        self.ln_q = RMSNorm(c_query)
        self.ln_kv = RMSNorm(c_kv)
        self.to_q = linearNoBias(c_query, c_model)
        self.to_k = linearNoBias(c_kv, c_model)
        self.to_v = linearNoBias(c_kv, c_model)
        self.to_g = nn.Sequential(linearNoBias(c_query, c_model), nn.Sigmoid())
        self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))
        self.kq_norm = kq_norm
        if kq_norm:
            self.k_norm = RMSNorm(c_model)
            self.q_norm = RMSNorm(c_model)

    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Supports both 3D (B, L, C) and 4D (B, n_tok, L, C) inputs.
        q_in = self.ln_q(q)
        kv = self.ln_kv(kv)
        q_proj = self.to_q(q_in)
        k = self.to_k(kv)
        v = self.to_v(kv)
        g = self.to_g(q_in)
        if self.kq_norm:
            k = self.k_norm(k)
            q_proj = self.q_norm(q_proj)
        B = q.shape[0]
        n_tok = q.shape[1] if q.ndim == 4 else 1
        L_q = q.shape[-2]
        L_kv = kv.shape[-2]
        c_head = q_proj.shape[-1] // self.n_head
        # Split heads; head axis moves to position 1 in both layouts.
        if q.ndim == 4:
            q_proj = q_proj.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            k = k.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            v = v.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            g = g.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
        else:
            q_proj = q_proj.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
            k = k.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            v = v.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            g = g.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
        attn = torch.matmul(q_proj, k.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # Broadcast the boolean mask up to attn's rank; True = keep.
            if q.ndim == 4:
                while attn_mask.ndim < attn.ndim:
                    attn_mask = attn_mask.unsqueeze(0)
                if attn_mask.shape[1] != self.n_head and attn_mask.shape[1] != 1:
                    attn_mask = attn_mask.unsqueeze(1)
            else:
                attn_mask = attn_mask.unsqueeze(-3)
            attn = attn.masked_fill(~attn_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out * g
        # Merge heads back and project to the query channel size.
        if q.ndim == 4:
            out = out.permute(0, 2, 3, 1, 4).contiguous()
            out = out.view(B, n_tok, L_q, -1)
        else:
            out = out.permute(0, 2, 1, 3).contiguous()
            out = out.view(B, L_q, -1)
        out = self.to_out(out)
        return out


class Upcast(nn.Module):
    """Upcast from token level to atom level."""

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        method: str = "cross_attention",
        cross_attention_block: Optional[dict] = None,
        n_split: int = 6,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if method == "broadcast":
            self.project = nn.Sequential(RMSNorm(c_token), linearNoBias(c_token, c_atom))
        elif method == "cross_attention":
            # Token channels are split into n_split key/value chunks.
            self.gca = GatedCrossAttention(
                c_query=c_atom,
                c_kv=c_token // n_split,
                **(cross_attention_block or {}),
            )

    def forward(self, q: torch.Tensor, a: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        """Update atom features `q` from token features `a` via tok_idx mapping."""
        if self.method == "broadcast":
            # Each atom receives the projection of its own token's feature.
            q = q + self.project(a)[..., tok_idx, :]
        elif self.method == "cross_attention":
            B, L, C = q.shape
            I = int(tok_idx.max().item()) + 1
            a_split = a.view(B, I, self.n_split, -1)
            q_grouped = self._group_atoms(q, tok_idx, I)
            valid_mask = self._build_valid_mask(tok_idx, I, q.device)
            attn_mask = torch.ones(I, q_grouped.shape[2], self.n_split, device=q.device, dtype=torch.bool)
            attn_mask[~valid_mask] = False  # padded atom slots attend to nothing
            q_update = self.gca(q_grouped, a_split, attn_mask=attn_mask)
            q = q + self._ungroup_atoms(q_update, valid_mask, L)
        return q

    def _group_atoms(self, q: torch.Tensor, tok_idx: torch.Tensor, I: int) -> torch.Tensor:
        # Pack atoms into per-token slots, capped at 14 atoms per token;
        # atoms past the cap are dropped.
        # NOTE(review): the matching _ungroup_atoms assumes tok_idx lists
        # atoms in contiguous token order — confirm against callers.
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)
        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                counts[t] += 1
        return grouped

    def _build_valid_mask(self, tok_idx: torch.Tensor, I: int, device: torch.device) -> torch.Tensor:
        # Boolean (I, 14) mask marking which packed slots hold a real atom.
        max_atoms_per_token = 14
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=device)
        counts = torch.zeros(I, dtype=torch.long, device=device)
        for i in range(len(tok_idx)):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                valid_mask[t, counts[t]] = True
                counts[t] += 1
        return valid_mask

    def _ungroup_atoms(self, grouped: torch.Tensor, valid_mask: torch.Tensor, L: int) -> torch.Tensor:
        # Inverse of _group_atoms: scatter valid slots back to a flat atom axis,
        # relying on atoms having been packed in order.
        B, I, n_atoms, C = grouped.shape
        out = torch.zeros(B, L, C, device=grouped.device, dtype=grouped.dtype)
        idx = 0
        for t in range(I):
            for a in range(n_atoms):
                if valid_mask[t, a] and idx < L:
                    out[:, idx] = grouped[:, t, a]
                    idx += 1
        return out


class Downcast(nn.Module):
    """Downcast from atom level to token level."""

    def __init__(
        self,
        c_atom: int,
        c_token: int,
        c_s: Optional[int] = None,
        method: str = "mean",
        cross_attention_block: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom
        if c_s is not None:
            self.process_s = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_token))
        else:
            self.process_s = None
        if method == "mean":
            # In "mean" mode `gca` is (despite the name) a plain projection.
            self.gca = linearNoBias(c_atom, c_token)
        elif method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                **(cross_attention_block or {}),
            )

    def forward(
        self,
        q: torch.Tensor,
        a: Optional[torch.Tensor] = None,
        s: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool atom features `q` to token features, residually updating `a`."""
        if q.ndim == 2:
            q = q.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False
        B, L, _ = q.shape
        I = int(tok_idx.max().item()) + 1
        if self.method == "mean":
            # Mean of projected atom features per token.
            projected = self.gca(q)
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            counts = torch.zeros(B, I, 1, device=q.device, dtype=q.dtype)
            for i in range(L):
                t = tok_idx[i]
                a_update[:, t] += projected[:, i]
                counts[:, t] += 1
            a_update = a_update / (counts + 1e-8)
        elif self.method == "cross_attention":
            if a is None:
                a = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            elif a.ndim == 2:
                a = a.unsqueeze(0)
            # One token query attends over its (up to 14) grouped atoms.
            q_grouped, valid_mask = self._group_atoms(q, tok_idx, I)
            attn_mask = valid_mask.unsqueeze(-2)
            a_update = self.gca(a.unsqueeze(-2), q_grouped, attn_mask=attn_mask).squeeze(-2)
        else:
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
        if a is not None:
            if a.ndim == 2:
                a = a.unsqueeze(0)
            a = a + a_update
        else:
            a = a_update
        if self.process_s is not None and s is not None:
            if s.ndim == 2:
                s = s.unsqueeze(0)
            a = a + self.process_s(s)
        if squeeze:
            a = a.squeeze(0)
        return a

    def _group_atoms(
        self, q: torch.Tensor, tok_idx: torch.Tensor, I: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Same packing scheme as Upcast._group_atoms (cap 14 atoms/token),
        # additionally returning the valid-slot mask.
        B, L, C = q.shape
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=q.device)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)
        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                valid_mask[t, counts[t]] = True
                counts[t] += 1
        return grouped, valid_mask


class LinearEmbedWithPool(nn.Module):
    """Linear embedding with pooling to token level."""

    def __init__(self, c_token: int):
        super().__init__()
        self.c_token = c_token
        self.linear = linearNoBias(3, c_token)  # embeds 3D coordinates

    def forward(self, r: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        # Embed per-atom coordinates, then mean-pool within each token.
        B = r.shape[0]
        I = int(tok_idx.max().item()) + 1
        q = self.linear(r)
        a = torch.zeros(B, I, self.c_token, device=r.device, dtype=q.dtype)
        counts = torch.zeros(B, I, 1, device=r.device, dtype=q.dtype)
        for i in range(r.shape[1]):
            t = tok_idx[i]
            a[:, t] += q[:, i]
            counts[:, t] += 1
        return a / (counts + 1e-8)


class LinearSequenceHead(nn.Module):
    """Sequence prediction head."""

    def __init__(self, c_token: int):
        super().__init__()
        n_tok_all = 32  # size of the output token vocabulary
        # All-ones mask as registered here; masking logic is a no-op unless
        # the buffer is overwritten externally (e.g. by a loaded checkpoint).
        mask = torch.ones(n_tok_all, dtype=torch.bool)
        self.register_buffer("valid_out_mask", mask)
        self.linear = nn.Linear(c_token, n_tok_all)

    def forward(self, a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.linear(a)
        probs = F.softmax(logits, dim=-1)
        # Zero out invalid output classes, then renormalize before argmax.
        probs = probs * self.valid_out_mask[None, None, :].to(probs.device)
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)
        indices = probs.argmax(dim=-1)
        return logits, indices


class LocalAtomTransformer(nn.Module):
    """Atom-level transformer encoder."""

    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        atom_transformer_block: dict,
        n_blocks: int,
    ):
        super().__init__()
        self.blocks = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_s,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ])

    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for block in self.blocks:
            q = block(q, c, p, **kwargs)
        return q


class LocalTokenTransformer(nn.Module):
    """Token-level transformer for diffusion."""

    def __init__(
        self,
        c_token: int,
        c_tokenpair: int,
        c_s: int,
        diffusion_transformer_block: dict,
        n_block: int,
        **kwargs,
    ):
        super().__init__()
        # Reuses the atom-transformer block type at token granularity.
        self.blocks = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_token,
                c_s=c_s,
                c_atompair=c_tokenpair,
                **diffusion_transformer_block,
            )
            for _ in range(n_block)
        ])

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for block in self.blocks:
            a = block(a, s, z, **kwargs)
        return a


class CompactStreamingDecoder(nn.Module):
    """Decoder with upcast, atom transformer, and downcast."""

    def __init__(
        self,
        c_atom: int,
        c_atompair: int,
        c_token: int,
        c_s: int,
        c_tokenpair: int,
        atom_transformer_block: dict,
        upcast: dict,
        downcast: dict,
        n_blocks: int,
        **kwargs,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.upcast = nn.ModuleList([
            Upcast(c_atom=c_atom, c_token=c_token, **upcast)
            for _ in range(n_blocks)
        ])
        self.atom_transformer = nn.ModuleList([
            StructureLocalAtomTransformerBlock(
                c_atom=c_atom,
                c_s=c_atom,
                c_atompair=c_atompair,
                **atom_transformer_block,
            )
            for _ in range(n_blocks)
        ])
        # A single shared downcast applied once after all blocks.
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)

    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        q: torch.Tensor,
        c: torch.Tensor,
        p: torch.Tensor,
        tok_idx: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        for i in range(self.n_blocks):
            q = self.upcast[i](q, a, tok_idx=tok_idx)
            q = self.atom_transformer[i](q, c, p, **kwargs)
        # Inputs are detached: the final token update does not backprop
        # into the decoder stack through this path.
        a = self.downcast(q.detach(), a.detach(), s.detach(), tok_idx=tok_idx)
        return a, q, {}


class DiffusionTokenEncoder(nn.Module):
    """Token encoder with pairformer stack for diffusion."""

    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_token: int,
        c_atompair: int,
        n_pairformer_blocks: int,
        pairformer_block: dict,
        use_distogram: bool = True,
        use_self: bool = True,
        n_bins_distogram: int = 65,
        **kwargs,
    ):
        super().__init__()
        self.use_distogram = use_distogram
        self.use_self = use_self
        self.n_bins_distogram = n_bins_distogram
        self.transition_1 = nn.ModuleList([
            Transition(c=c_s, n=2),
            Transition(c=c_s, n=2),
        ])
        # Input to process_z is z concatenated with optional distogram channels.
        n_bins_noise = n_bins_distogram
        cat_c_z = (
            c_z
            + int(use_distogram) * n_bins_noise
            + int(use_self) * n_bins_distogram
        )
        self.process_z = nn.Sequential(
            RMSNorm(cat_c_z),
            linearNoBias(cat_c_z, c_z),
        )
        self.transition_2 = nn.ModuleList([
            Transition(c=c_z, n=2),
            Transition(c=c_z, n=2),
        ])
        self.pairformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])

    def forward(
        self,
        s_init: torch.Tensor,
        z_init: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B = z_init.shape[0] if z_init.ndim == 4 else 1
        s = s_init
        for b in range(2):
            s = s + self.transition_1[b](s)
        z = z_init
        if z.ndim == 3:
            z = z.unsqueeze(0).expand(B, -1, -1, -1)
        # Distogram features default to zeros when not supplied.
        z_list = [z]
        if self.use_distogram:
            d_noise = torch.zeros(
                z.shape[:-1] + (self.n_bins_distogram,), device=z.device, dtype=z.dtype
            )
            z_list.append(d_noise)
        if self.use_self:
            d_self = kwargs.get("d_self")
            if d_self is None:
                d_self = torch.zeros(
                    z.shape[:-1] + (self.n_bins_distogram,), device=z.device, dtype=z.dtype
                )
            z_list.append(d_self)
        z = torch.cat(z_list, dim=-1)
        z = self.process_z(z)
        for b in range(2):
            z = z + self.transition_2[b](z)
        for block in self.pairformer_stack:
            s, z = block(s, z)
        return s, z


class EmbeddingLayer(nn.Module):
    """Embedding layer for 1D features - simple linear projection."""

    def __init__(self, n_channels: int, output_channels: int):
        super().__init__()
        # Zero-initialized bias-free projection (weights come from checkpoint).
        self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)


class OneDFeatureEmbedder(nn.Module):
    """Embeds 1D features into a single vector."""

    def __init__(self, features: dict, output_channels: int):
        super().__init__()
        self.features = {k: v for k, v in features.items() if v is not None}
        self.embedders = nn.ModuleDict({
            feature: EmbeddingLayer(n_channels, output_channels)
            for feature, n_channels in self.features.items()
        })

    def forward(self, f: dict) -> torch.Tensor:
        # Sum the embeddings of every feature present in `f`; features
        # missing from the dict are silently skipped.
        result = None
        for feature in self.features:
            x = f.get(feature)
            if x is not None:
                emb = self.embedders[feature](x.float())
                result = emb if result is None else result + emb
        return result if result is not None else torch.zeros(1)


class PositionPairDistEmbedder(nn.Module):
    """Embeds pairwise position distances."""

    def __init__(self, c_atompair: int, embed_frame: bool = True):
        super().__init__()
        self.embed_frame = embed_frame
        # NOTE(review): process_d is created when embed_frame=True but is not
        # used in this forward — presumably consumed elsewhere or vestigial.
        if embed_frame:
            self.process_d = linearNoBias(3, c_atompair)
        self.process_inverse_dist = linearNoBias(1, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, ref_pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        D_LL = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        # Squared pairwise distance, clamped away from zero.
        norm = torch.linalg.norm(D_LL, dim=-1, keepdim=True) ** 2
        norm = torch.clamp(norm, min=1e-6)
        inv_dist = 1 / (1 + norm)
        P_LL = self.process_inverse_dist(inv_dist) * valid_mask
        P_LL = P_LL + self.process_valid_mask(valid_mask.float()) * valid_mask
        return P_LL


class SinusoidalDistEmbed(nn.Module):
    """Sinusoidal embedding for pairwise distances."""

    def __init__(self, c_atompair: int, n_freqs: int = 32):
        super().__init__()
        self.n_freqs = n_freqs
        self.c_atompair = c_atompair
        self.output_proj = linearNoBias(2 * n_freqs, c_atompair)
        self.process_valid_mask = linearNoBias(1, c_atompair)

    def forward(self, pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        D_LL = pos.unsqueeze(-2) - pos.unsqueeze(-3)
        dist_matrix = torch.linalg.norm(D_LL, dim=-1)
        # Geometric frequency ladder, as in standard sinusoidal encodings.
        freq = torch.exp(
            -math.log(10000.0)
            * torch.arange(0, self.n_freqs, dtype=torch.float32)
            / self.n_freqs
        ).to(dist_matrix.device)
        angles = dist_matrix.unsqueeze(-1) * freq
        sincos_embed = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        P_LL = self.output_proj(sincos_embed) * valid_mask
        P_LL = P_LL + self.process_valid_mask(valid_mask.float()) * valid_mask
        return P_LL


class RelativePositionEncoding(nn.Module):
    """Relative position encoding."""

    def __init__(self, r_max: int, s_max: int, c_z: int):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        num_tok_pos_bins = 2 * r_max + 3
        self.linear = linearNoBias(2 * num_tok_pos_bins + (2 * s_max + 2) + 1, c_z)

    def forward(self, f: dict) -> torch.Tensor:
        # NOTE(review): stub — returns zeros of the right shape; the binned
        # relative-position featurization is not implemented here.
        I = f.get("residue_index", torch.zeros(1)).shape[-1]
        device = f.get("residue_index", torch.zeros(1)).device
        return torch.zeros(I, I, self.linear.out_features, device=device)


class TokenInitializer(nn.Module):
    """Token embedding module for RFD3 matching foundry checkpoint structure."""

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        r_max: int = 32,
        s_max: int = 2,
        n_pairformer_blocks: int = 2,
        atom_1d_features: Optional[dict] = None,
        token_1d_features: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        # Default feature schemas: name -> channel count.
        if atom_1d_features is None:
            atom_1d_features = {
                "ref_atom_name_chars": 256,
                "ref_element": 128,
                "ref_charge": 1,
                "ref_mask": 1,
                "ref_is_motif_atom_with_fixed_coord": 1,
                "ref_is_motif_atom_unindexed": 1,
                "has_zero_occupancy": 1,
                "ref_pos": 3,
                "ref_atomwise_rasa": 3,
                "active_donor": 1,
                "active_acceptor": 1,
                "is_atom_level_hotspot": 1,
            }
        if token_1d_features is None:
            token_1d_features = {
                "ref_motif_token_type": 3,
                "restype": 32,
                "ref_plddt": 1,
                "is_non_loopy": 1,
            }
        cross_attention_block = {"n_head": 4, "c_model": c_atom, "dropout": 0.0, "kq_norm": True}
        self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
        self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
        self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
        self.downcast_atom = Downcast(
            c_atom=c_s, c_token=c_s, c_s=None, method="cross_attention",
            cross_attention_block=cross_attention_block
        )
        self.transition_post_token = Transition(c=c_s, n=2)
        self.transition_post_atom = Transition(c=c_s, n=2)
        self.process_s_init = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))
        self.to_z_init_i = linearNoBias(c_s, c_z)
        self.to_z_init_j = linearNoBias(c_s, c_z)
        self.relative_position_encoding = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.relative_position_encoding2 = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.process_token_bonds = linearNoBias(1, c_z)
        self.process_z_init = nn.Sequential(RMSNorm(c_z * 2), linearNoBias(c_z * 2, c_z))
        self.transition_1 = nn.ModuleList([Transition(c=c_z, n=2), Transition(c=c_z, n=2)])
        self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)
        pairformer_block = {"attention_pair_bias": {"n_head": 16, "kq_norm": True}, "n_transition": 4}
        self.transformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])
        self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
        self.process_single_l = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_single_m = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))
        self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
        self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
        self.pair_mlp = nn.Sequential(
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
            nn.ReLU(),
            linearNoBias(c_atompair, c_atompair),
        )
        self.process_pll = linearNoBias(c_atompair, c_atompair)
        self.project_pll = linearNoBias(c_atompair, c_z)

    def forward(self, f: dict) -> dict:
        """Compute initial representations from input features."""
        # NOTE(review): stub — returns zero-valued single/pair representations
        # of the configured sizes; the feature embedders above are not invoked
        # here. Confirm whether the real initialization lives elsewhere.
        I = f.get("num_tokens", 100)
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        s_init = torch.zeros(I, self.process_s_init[1].out_features, device=device, dtype=dtype)
        z_init = torch.zeros(I, I, self.process_z_init[1].out_features, device=device, dtype=dtype)
        return {"S_I": s_init, "Z_II": z_init}


class RFD3DiffusionModule(nn.Module):
    """
    RFD3 Diffusion Module matching foundry checkpoint structure.

    This module structure matches `model.diffusion_module.*` keys in the checkpoint.
    """

    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_blocks: int = 2,
        n_diffusion_blocks: int = 18,
        n_atom_encoder_blocks: int = 3,
        n_atom_decoder_blocks: int = 3,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()
        self.sigma_data = sigma_data
        self.n_recycle = n_recycle
        self.process_r = linearNoBias(3, c_atom)
        self.to_r_update = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, 3))
        self.sequence_head = LinearSequenceHead(c_token)
        # Two independent timestep embeddings: index 0 -> atom channel,
        # index 1 -> single/token channel (see process_time).
        self.fourier_embedding = nn.ModuleList([
            FourierEmbedding(c_t_embed),
            FourierEmbedding(c_t_embed),
        ])
        self.process_n = nn.ModuleList([
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
            nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
        ])
        cross_attention_block = {
            "n_head": 4,
            "c_model": c_atom,
            "dropout": p_drop,
            "kq_norm": True,
        }
        self.downcast_c = Downcast(
            c_atom=c_atom, c_token=c_s, c_s=None, method="cross_attention",
            cross_attention_block=cross_attention_block
        )
        self.downcast_q = Downcast(
            c_atom=c_atom, c_token=c_token, c_s=c_s, method="cross_attention",
            cross_attention_block=cross_attention_block
        )
        self.process_a = LinearEmbedWithPool(c_token)
        self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
        atom_transformer_block = {
            "n_head": 4,
            "dropout": p_drop,
            "kq_norm": True,
        }
        self.encoder = LocalAtomTransformer(
            c_atom=c_atom,
            c_s=c_atom,
            c_atompair=c_atompair,
            atom_transformer_block=atom_transformer_block,
            n_blocks=n_atom_encoder_blocks,
        )
        pairformer_block = {
            "attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": False},
            "n_transition": 4,
        }
        self.diffusion_token_encoder = DiffusionTokenEncoder(
            c_s=c_s,
            c_z=c_z,
            c_token=c_token,
            c_atompair=c_atompair,
            n_pairformer_blocks=n_pairformer_blocks,
            pairformer_block=pairformer_block,
        )
        diffusion_transformer_block = {
            "n_head": n_head,
            "dropout": p_drop,
            "kq_norm": True,
        }
        self.diffusion_transformer = LocalTokenTransformer(
            c_token=c_token,
            c_tokenpair=c_z,
            c_s=c_s,
            diffusion_transformer_block=diffusion_transformer_block,
            n_block=n_diffusion_blocks,
        )
        decoder_upcast = {
            "method": "cross_attention",
            "n_split": 3,
            "cross_attention_block": cross_attention_block,
        }
        decoder_downcast = {
            "method": "cross_attention",
            "cross_attention_block": cross_attention_block,
        }
        self.decoder = CompactStreamingDecoder(
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_s=c_s,
            c_tokenpair=c_z,
            atom_transformer_block=atom_transformer_block,
            upcast=decoder_upcast,
            downcast=decoder_downcast,
            n_blocks=n_atom_decoder_blocks,
        )

    def scale_positions_in(self, x_noisy: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Scale noisy coordinates to unit variance: x / sqrt(t^2 + sigma_data^2)."""
        if t.ndim == 1:
            t = t[..., None, None]
        elif t.ndim == 2:
            t = t[..., None]
        return x_noisy / torch.sqrt(t**2 + self.sigma_data**2)

    def scale_positions_out(
        self, r_update: torch.Tensor, x_noisy: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        """Combine network output with the noisy input (EDM-style skip/out scaling)."""
        if t.ndim == 1:
            t = t[..., None, None]
        elif t.ndim == 2:
            t = t[..., None]
        sigma2 = self.sigma_data**2
        # c_skip * x_noisy + c_out * r_update with
        # c_skip = sigma_d^2/(sigma_d^2+t^2), c_out = sigma_d*t/sqrt(sigma_d^2+t^2).
        return (sigma2 / (sigma2 + t**2)) * x_noisy + (
            self.sigma_data * t / torch.sqrt(sigma2 + t**2)
        ) * r_update

    def process_time(self, t: torch.Tensor, idx: int) -> torch.Tensor:
        """Embed noise level t via Fourier features; zeroed where t == 0."""
        # Clamp avoids log(0); embedding input is 0.25*log(t/sigma_data).
        t_clamped = torch.clamp(t, min=1e-20)
        t_log = 0.25 * torch.log(t_clamped / self.sigma_data)
        emb = self.process_n[idx](self.fourier_embedding[idx](t_log))
        emb = emb * (t > 0).float()[..., None]
        return emb

    def compute_pair_features(self, xyz: torch.Tensor, c_atompair: int) -> torch.Tensor:
        """Inverse-squared-distance pair features, broadcast over c_atompair channels."""
        dist = torch.cdist(xyz, xyz)
        inv_dist = 1 / (1 + dist**2)
        return inv_dist.unsqueeze(-1).expand(-1, -1, -1, c_atompair)


class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
    """
    RFDiffusion3 transformer for protein structure prediction.

    This wraps the diffusion module to provide the full model interface.
    The state dict keys match the foundry checkpoint format.
""" config_name = "config.json" _supports_gradient_checkpointing = True @register_to_config def __init__( self, c_s: int = 384, c_z: int = 128, c_atom: int = 128, c_atompair: int = 16, c_token: int = 768, c_t_embed: int = 256, sigma_data: float = 16.0, n_pairformer_block: int = 2, n_diffusion_block: int = 18, n_atom_encoder_block: int = 3, n_atom_decoder_block: int = 3, n_head: int = 16, n_pairformer_head: int = 16, n_recycle: int = 2, p_drop: float = 0.0, ): super().__init__() self.token_initializer = TokenInitializer( c_s=c_s, c_z=c_z, c_atom=c_atom, c_atompair=c_atompair, ) self.diffusion_module = RFD3DiffusionModule( c_s=c_s, c_z=c_z, c_atom=c_atom, c_atompair=c_atompair, c_token=c_token, c_t_embed=c_t_embed, sigma_data=sigma_data, n_pairformer_blocks=n_pairformer_block, n_diffusion_blocks=n_diffusion_block, n_atom_encoder_blocks=n_atom_encoder_block, n_atom_decoder_blocks=n_atom_decoder_block, n_head=n_head, n_pairformer_head=n_pairformer_head, n_recycle=n_recycle, p_drop=p_drop, ) @property def sigma_data(self) -> float: return self.diffusion_module.sigma_data def forward( self, xyz_noisy: torch.Tensor, t: torch.Tensor, atom_to_token_map: Optional[torch.Tensor] = None, motif_mask: Optional[torch.Tensor] = None, s_init: Optional[torch.Tensor] = None, z_init: Optional[torch.Tensor] = None, n_recycle: Optional[int] = None, **kwargs, ) -> RFDiffusionTransformerOutput: """ Forward pass of the diffusion module. 
Args: xyz_noisy: Noisy atom coordinates [B, L, 3] t: Noise level / timestep [B] atom_to_token_map: Mapping from atoms to tokens [L] motif_mask: Mask for fixed motif atoms [L] s_init: Initial single representation [I, c_s] z_init: Initial pair representation [I, I, c_z] n_recycle: Number of recycling iterations Returns: RFDiffusionTransformerOutput with denoised coordinates """ dm = self.diffusion_module B, L, _ = xyz_noisy.shape if atom_to_token_map is None: atom_to_token_map = torch.arange(L, device=xyz_noisy.device) I = int(atom_to_token_map.max().item()) + 1 if motif_mask is None: motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device) t_L = t[:, None].expand(B, L) * (~motif_mask).float() t_I = t[:, None].expand(B, I) r_scaled = dm.scale_positions_in(xyz_noisy, t) r_noisy = dm.scale_positions_in(xyz_noisy, t_L) if s_init is None or z_init is None: init_output = self.token_initializer({"num_tokens": I}) if s_init is None: s_init = init_output["S_I"] if z_init is None: z_init = init_output["Z_II"] assert s_init is not None and z_init is not None p = dm.compute_pair_features(r_scaled, self.config.c_atompair) a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map) s_init_expanded = s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device), s_init_expanded, tok_idx=atom_to_token_map) q = dm.process_r(r_noisy) c = dm.process_time(t_L, idx=0) q = q + c s_I = s_I + dm.process_time(t_I, idx=1) c = c + dm.process_c(c) q = dm.encoder(q, c, p) a_I = dm.downcast_q(q, a_I, s_I, tok_idx=atom_to_token_map) if n_recycle is None: n_recycle = dm.n_recycle if not self.training else 1 n_recycle = max(1, n_recycle) z_II = z_init for _ in range(n_recycle): s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_II) a_I = dm.diffusion_transformer(a_I, s_I, z_II) a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map) r_update = dm.to_r_update(q) xyz_out = 
dm.scale_positions_out(r_update, xyz_noisy, t_L) sequence_logits, sequence_indices = dm.sequence_head(a_I) return RFDiffusionTransformerOutput( xyz=xyz_out, single=s_I, pair=z_II, sequence_logits=sequence_logits, sequence_indices=sequence_indices, )