# Provenance: transformer/model.py uploaded with huggingface_hub by dn6 (HF Staff),
# commit ceca157 (verified).
# Copyright 2025 Dhruv Nair. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RFDiffusion3 Transformer model.
This module provides a diffusers-compatible implementation of the RFD3
architecture for protein structure prediction and generation. The module
structure matches the foundry checkpoint format for direct weight loading.
"""
import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
@dataclass
class RFDiffusionTransformerOutput:
    """Output class for RFDiffusion transformer.

    Bundles predicted coordinates with the final single/pair representations
    and the optional sequence-head outputs.
    """
    # Predicted atom coordinates.
    xyz: torch.Tensor
    # Final single (token-level) representation.
    single: torch.Tensor
    # Final pair (token-pair) representation.
    pair: torch.Tensor
    # Raw residue-type logits from the sequence head, if it was run.
    sequence_logits: Optional[torch.Tensor] = None
    # Argmax residue-type indices from the sequence head, if it was run.
    sequence_indices: Optional[torch.Tensor] = None
# Factory for bias-free nn.Linear layers; used throughout so state-dict keys
# match the foundry checkpoint (which stores no bias for these projections).
linearNoBias = partial(nn.Linear, bias=False)
class RMSNorm(nn.Module):
    """Root Mean Square layer normalization with a learnable gain (no bias)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Rescale by the root-mean-square over the channel dimension.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        return x / denom * self.weight
class FourierEmbedding(nn.Module):
    """Fourier feature embedding of a scalar timestep.

    Frequencies ``w`` and phases ``b`` are drawn from a standard normal and
    stored as buffers, so they travel with the state dict but get no gradient.
    """
    def __init__(self, c: int):
        super().__init__()
        self.c = c
        self.register_buffer("w", torch.zeros(c, dtype=torch.float32))
        self.register_buffer("b", torch.zeros(c, dtype=torch.float32))
        self.reset_parameters()
    def reset_parameters(self) -> None:
        # Randomly draw frequencies and phases.
        nn.init.normal_(self.w)
        nn.init.normal_(self.b)
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # cos(2*pi*(t*w + b)) per embedding channel.
        phase = t.unsqueeze(-1) * self.w + self.b
        return torch.cos(2 * math.pi * phase)
class LinearBiasInit(nn.Linear):
    """``nn.Linear`` whose bias is initialized to a fixed constant."""
    def __init__(self, *args, biasinit: float = -2.0, **kwargs):
        # Must be set before super().__init__(), which calls reset_parameters().
        self.biasinit = biasinit
        super().__init__(*args, **kwargs)
    def reset_parameters(self) -> None:
        super().reset_parameters()
        if self.bias is not None:
            with torch.no_grad():
                self.bias.fill_(self.biasinit)
class RMSNormNoWeight(nn.Module):
    """RMS normalization with no learnable parameters (pure rescaling)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / scale
class AdaLN(nn.Module):
    """Adaptive layer norm: gain and bias are predicted from a signal ``s``."""
    def __init__(self, c_a: int, c_s: int):
        super().__init__()
        self.ln_a = RMSNormNoWeight(c_a)
        self.ln_s = RMSNorm(c_s)
        self.to_gain = nn.Sequential(nn.Linear(c_s, c_a), nn.Sigmoid())
        self.to_bias = nn.Linear(c_s, c_a, bias=False)
    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        a_normed = self.ln_a(a)
        s_normed = self.ln_s(s)
        gain = self.to_gain(s_normed)
        bias = self.to_bias(s_normed)
        return gain * a_normed + bias
class Transition(nn.Module):
    """SwiGLU-style feed-forward transition (layer names match the checkpoint)."""
    def __init__(self, c: int, n: int = 4):
        super().__init__()
        hidden = n * c
        self.layer_norm_1 = RMSNorm(c)
        self.linear_1 = nn.Linear(c, hidden, bias=False)
        self.linear_2 = nn.Linear(c, hidden, bias=False)
        self.linear_3 = nn.Linear(hidden, c, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.layer_norm_1(x)
        gated = F.silu(self.linear_1(h)) * self.linear_2(h)
        return self.linear_3(gated)
class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition with AdaLN input conditioning and a gated output."""
    def __init__(self, c_token: int, c_s: int, n: int = 2):
        super().__init__()
        hidden = c_token * n
        self.ada_ln = AdaLN(c_a=c_token, c_s=c_s)
        self.linear_1 = nn.Linear(c_token, hidden, bias=False)
        self.linear_2 = nn.Linear(c_token, hidden, bias=False)
        # Output gate opens slowly: sigmoid(-2) ~ 0.12 at init.
        self.linear_output_project = nn.Sequential(
            LinearBiasInit(c_s, c_token, biasinit=-2.0),
            nn.Sigmoid(),
        )
        self.linear_3 = nn.Linear(hidden, c_token, bias=False)
    def forward(self, a: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
        conditioned = self.ada_ln(a, s)
        swiglu = F.silu(self.linear_1(conditioned)) * self.linear_2(conditioned)
        return self.linear_output_project(s) * self.linear_3(swiglu)
class MultiDimLinear(nn.Linear):
    """Linear layer whose output is reshaped to a multi-dimensional shape.

    Args:
        in_features: Size of the input feature dimension.
        out_shape: Shape the trailing dimension is reshaped to; the underlying
            linear layer has ``math.prod(out_shape)`` output features.
        norm: If True, apply an ``RMSNorm`` over the flat output before reshaping.
        **kwargs: Forwarded to ``nn.Linear`` (e.g. ``bias=False``).
    """
    def __init__(self, in_features: int, out_shape: Tuple[int, ...], norm: bool = False, **kwargs):
        self.out_shape = out_shape
        # math.prod replaces the manual accumulation loop (same result,
        # including the empty-tuple -> 1 case).
        out_features = math.prod(out_shape)
        super().__init__(in_features, out_features, **kwargs)
        if norm:
            self.ln = RMSNorm(out_features)
            self.use_ln = True
        else:
            self.use_ln = False
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = super().forward(x)
        if self.use_ln:
            out = self.ln(out)
        # Unflatten the last dimension into the requested shape.
        return out.reshape(x.shape[:-1] + self.out_shape)
class AttentionPairBias(nn.Module):
    """Multi-head attention over tokens whose logits are biased by the pair track."""
    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 8,
        kq_norm: bool = False,
    ):
        super().__init__()
        self.n_head = n_head
        self.c_a = c_a
        self.c_pair = c_pair
        self.c = c_a // n_head
        head_shape = (n_head, self.c)
        self.to_q = MultiDimLinear(c_a, head_shape)
        self.to_k = MultiDimLinear(c_a, head_shape, bias=False, norm=True)
        self.to_v = MultiDimLinear(c_a, head_shape, bias=False, norm=True)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(
            MultiDimLinear(c_a, head_shape, bias=False),
            nn.Sigmoid(),
        )
        self.to_a = linearNoBias(c_a, c_a)
        self.ln_0 = RMSNorm(c_pair)
        self.ln_1 = RMSNorm(c_a)
    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        beta: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `s` is accepted for interface parity but unused by this module.
        a = self.ln_1(a)
        pair_bias = self.to_b(self.ln_0(z))
        if beta is not None:
            pair_bias = pair_bias + beta[..., None]
        queries = self.to_q(a) / math.sqrt(self.c)
        keys = self.to_k(a)
        values = self.to_v(a)
        gates = self.to_g(a)
        logits = torch.einsum("...ihd,...jhd->...ijh", queries, keys) + pair_bias
        weights = F.softmax(logits, dim=-2)  # normalize over the key index j
        mixed = torch.einsum("...ijh,...jhc->...ihc", weights, values)
        return self.to_a((gates * mixed).flatten(start_dim=-2))
class LocalAttentionPairBias(nn.Module):
    """Local attention with pairwise bias for diffusion transformer blocks.

    Multi-head self-attention over the second-to-last axis of ``a``, with
    per-head additive bias projected from the pair tensor ``z``. When ``c_s``
    is set (> 0), input normalization is adaptive (AdaLN conditioned on ``s``)
    and the output is gated by a sigmoid projection of ``s``; otherwise a
    plain RMSNorm is used.
    """
    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_pair: int,
        n_head: int = 16,
        kq_norm: bool = True,
    ):
        super().__init__()
        self.n_head = n_head
        self.c = c_a
        self.c_head = c_a // n_head
        self.c_s = c_s
        # Flag exists but is never toggled here; checkpointing is external.
        self.use_checkpointing = False
        self.to_q = linearNoBias(c_a, c_a)
        self.to_k = linearNoBias(c_a, c_a)
        self.to_v = linearNoBias(c_a, c_a)
        self.to_b = linearNoBias(c_pair, n_head)
        self.to_g = nn.Sequential(linearNoBias(c_a, c_a), nn.Sigmoid())
        self.to_o = linearNoBias(c_a, c_a)
        self.kq_norm = kq_norm
        if kq_norm:
            # Queries/keys are normalized over the full channel dim
            # before being split into heads.
            self.ln_q = RMSNorm(c_a)
            self.ln_k = RMSNorm(c_a)
        if c_s is not None and c_s > 0:
            # Conditioned path: AdaLN in, sigmoid output gate (bias init -2).
            self.ada_ln_1 = AdaLN(c_a=c_a, c_s=c_s)
            self.linear_output_project = nn.Sequential(
                LinearBiasInit(c_s, c_a, biasinit=-2.0),
                nn.Sigmoid(),
            )
        else:
            self.ln_1 = RMSNorm(c_a)
    def forward(
        self,
        a: torch.Tensor,
        s: Optional[torch.Tensor],
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Attend over ``a`` (..., L, c_a) with pair bias from ``z`` (..., L, L, c_pair)."""
        if self.c_s is not None and self.c_s > 0:
            a = self.ada_ln_1(a, s)
        else:
            a = self.ln_1(a)
        q = self.to_q(a)
        k = self.to_k(a)
        v = self.to_v(a)
        g = self.to_g(a)
        if self.kq_norm:
            q = self.ln_q(q)
            k = self.ln_k(k)
        batch_dims = a.shape[:-2]
        L = a.shape[-2]
        # Split channels into heads: (..., L, c_a) -> (..., n_head, L, c_head).
        q = q.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        k = k.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        v = v.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        g = g.view(*batch_dims, L, self.n_head, self.c_head).transpose(-2, -3)
        # Pair bias: (..., L, L, n_head) -> (..., n_head, L, L).
        b = self.to_b(z).permute(*range(len(batch_dims)), -1, -3, -2)
        attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.c_head)
        attn = attn + b
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        # Per-head, per-channel sigmoid gating of the attention output.
        out = out * g
        # Merge heads back: (..., n_head, L, c_head) -> (..., L, c_a).
        out = out.transpose(-2, -3).contiguous()
        out = out.view(*batch_dims, L, self.c)
        out = self.to_o(out)
        if self.c_s is not None and self.c_s > 0:
            out = self.linear_output_project(s) * out
        return out
class PairformerBlock(nn.Module):
    """Pairformer block: pair transition plus (optionally) pair-biased single-track attention."""
    def __init__(
        self,
        c_s: int,
        c_z: int,
        attention_pair_bias: dict,
        n_transition: int = 4,
        p_drop: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        self.z_transition = Transition(c=c_z, n=n_transition)
        # Single-track modules exist only when a single representation is used.
        if c_s > 0:
            self.s_transition = Transition(c=c_s, n=n_transition)
            self.attention_pair_bias = AttentionPairBias(
                c_a=c_s, c_s=0, c_pair=c_z, **attention_pair_bias
            )
    def forward(
        self,
        s: torch.Tensor,
        z: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        z = z + self.z_transition(z)
        if s is not None:
            # A zero beta keeps the bias path active with no extra offset.
            zero_beta = torch.tensor([0.0], device=z.device)
            s = s + self.attention_pair_bias(s, None, z, beta=zero_beta)
            s = s + self.s_transition(s)
        return s, z
class StructureLocalAtomTransformerBlock(nn.Module):
    """One attention + transition block shared by the atom and token transformers."""
    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        n_head: int = 4,
        dropout: float = 0.0,
        kq_norm: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.c_s = c_s
        self.dropout = nn.Dropout(dropout)
        self.attention_pair_bias = LocalAttentionPairBias(
            c_a=c_atom, c_s=c_s, c_pair=c_atompair, n_head=n_head, kq_norm=kq_norm
        )
        # Conditioned transition when a conditioning signal is available.
        if c_s is not None and c_s > 0:
            self.transition_block = ConditionedTransitionBlock(c_token=c_atom, c_s=c_s)
        else:
            self.transition_block = Transition(c=c_atom, n=4)
    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        q = q + self.dropout(self.attention_pair_bias(q, c, p, **kwargs))
        if self.c_s is not None and self.c_s > 0:
            return q + self.transition_block(q, c)
        return q + self.transition_block(q)
class GatedCrossAttention(nn.Module):
    """Gated cross attention for upcast/downcast.

    Queries come from one track (atoms or tokens), keys/values from the
    other; the output is gated per-channel by a sigmoid projection of the
    query input. Handles both 3-D (B, L, C) and 4-D (B, n_tok, L, C) inputs.
    """
    def __init__(
        self,
        c_query: int,
        c_kv: int,
        c_model: int = 128,
        n_head: int = 4,
        kq_norm: bool = True,
        dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.n_head = n_head
        self.scale = 1 / math.sqrt(c_model // n_head)
        self.ln_q = RMSNorm(c_query)
        self.ln_kv = RMSNorm(c_kv)
        self.to_q = linearNoBias(c_query, c_model)
        self.to_k = linearNoBias(c_kv, c_model)
        self.to_v = linearNoBias(c_kv, c_model)
        self.to_g = nn.Sequential(linearNoBias(c_query, c_model), nn.Sigmoid())
        self.to_out = nn.Sequential(nn.Linear(c_model, c_query), nn.Dropout(dropout))
        self.kq_norm = kq_norm
        if kq_norm:
            self.k_norm = RMSNorm(c_model)
            self.q_norm = RMSNorm(c_model)
    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend ``q`` over ``kv``; ``attn_mask`` marks attendable pairs with True."""
        q_in = self.ln_q(q)
        kv = self.ln_kv(kv)
        q_proj = self.to_q(q_in)
        k = self.to_k(kv)
        v = self.to_v(kv)
        # Gate is computed from the normalized query input.
        g = self.to_g(q_in)
        if self.kq_norm:
            k = self.k_norm(k)
            q_proj = self.q_norm(q_proj)
        B = q.shape[0]
        n_tok = q.shape[1] if q.ndim == 4 else 1
        L_q = q.shape[-2]
        L_kv = kv.shape[-2]
        c_head = q_proj.shape[-1] // self.n_head
        # Split into heads; the head axis is moved in front of the token axes.
        if q.ndim == 4:
            q_proj = q_proj.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            k = k.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            v = v.view(B, n_tok, L_kv, self.n_head, c_head).permute(0, 3, 1, 2, 4)
            g = g.view(B, n_tok, L_q, self.n_head, c_head).permute(0, 3, 1, 2, 4)
        else:
            q_proj = q_proj.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
            k = k.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            v = v.view(B, L_kv, self.n_head, c_head).permute(0, 2, 1, 3)
            g = g.view(B, L_q, self.n_head, c_head).permute(0, 2, 1, 3)
        attn = torch.matmul(q_proj, k.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # Broadcast the boolean mask up to the attention-logit shape.
            if q.ndim == 4:
                while attn_mask.ndim < attn.ndim:
                    attn_mask = attn_mask.unsqueeze(0)
                if attn_mask.shape[1] != self.n_head and attn_mask.shape[1] != 1:
                    attn_mask = attn_mask.unsqueeze(1)
            else:
                attn_mask = attn_mask.unsqueeze(-3)
            # NOTE(review): a fully-masked row produces all -inf logits and a
            # NaN softmax row — callers must guarantee >=1 True per row.
            attn = attn.masked_fill(~attn_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out * g
        # Merge heads back and project to the query width.
        if q.ndim == 4:
            out = out.permute(0, 2, 3, 1, 4).contiguous()
            out = out.view(B, n_tok, L_q, -1)
        else:
            out = out.permute(0, 2, 1, 3).contiguous()
            out = out.view(B, L_q, -1)
        out = self.to_out(out)
        return out
class Upcast(nn.Module):
    """Upcast from token level to atom level.

    Either broadcasts a projected token embedding to its atoms
    (``method="broadcast"``) or lets each atom cross-attend to ``n_split``
    slices of its token embedding (``method="cross_attention"``).
    """
    def __init__(
        self,
        c_atom: int,
        c_token: int,
        method: str = "cross_attention",
        cross_attention_block: Optional[dict] = None,
        n_split: int = 6,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.n_split = n_split
        if method == "broadcast":
            self.project = nn.Sequential(RMSNorm(c_token), linearNoBias(c_token, c_atom))
        elif method == "cross_attention":
            # Each token embedding is viewed as n_split key/value slots.
            self.gca = GatedCrossAttention(
                c_query=c_atom,
                c_kv=c_token // n_split,
                **(cross_attention_block or {}),
            )
    def forward(self, q: torch.Tensor, a: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        """Add token information to atom features.

        Args:
            q: Atom features, shape (B, L, c_atom).
            a: Token features, shape (B, I, c_token).
            tok_idx: Token index per atom, shape (L,), shared across the batch.
        """
        if self.method == "broadcast":
            # Gather each atom's projected token embedding directly.
            q = q + self.project(a)[..., tok_idx, :]
        elif self.method == "cross_attention":
            B, L, C = q.shape
            I = int(tok_idx.max().item()) + 1
            a_split = a.view(B, I, self.n_split, -1)
            q_grouped = self._group_atoms(q, tok_idx, I)
            valid_mask = self._build_valid_mask(tok_idx, I, q.device)
            # Mask out padding atom slots (queries) for every kv slot.
            attn_mask = torch.ones(I, q_grouped.shape[2], self.n_split, device=q.device, dtype=torch.bool)
            attn_mask[~valid_mask] = False
            q_update = self.gca(q_grouped, a_split, attn_mask=attn_mask)
            q = q + self._ungroup_atoms(q_update, valid_mask, L)
        return q
    def _group_atoms(self, q: torch.Tensor, tok_idx: torch.Tensor, I: int) -> torch.Tensor:
        """Pack atoms into fixed-size per-token slots in first-come order."""
        B, L, C = q.shape
        # Per-token atom capacity; extra atoms are silently dropped.
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)
        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                counts[t] += 1
        return grouped
    def _build_valid_mask(self, tok_idx: torch.Tensor, I: int, device: torch.device) -> torch.Tensor:
        """Boolean mask of which (token, slot) positions hold a real atom."""
        max_atoms_per_token = 14
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=device)
        counts = torch.zeros(I, dtype=torch.long, device=device)
        for i in range(len(tok_idx)):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                valid_mask[t, counts[t]] = True
                counts[t] += 1
        return valid_mask
    def _ungroup_atoms(self, grouped: torch.Tensor, valid_mask: torch.Tensor, L: int) -> torch.Tensor:
        """Inverse of ``_group_atoms``: scatter slot contents back to the flat atom axis.

        NOTE(review): atoms are reassigned sequentially, which assumes
        ``tok_idx`` lists each token's atoms contiguously and in increasing
        token order — TODO confirm against the caller.
        """
        B, I, n_atoms, C = grouped.shape
        out = torch.zeros(B, L, C, device=grouped.device, dtype=grouped.dtype)
        idx = 0
        for t in range(I):
            for a in range(n_atoms):
                if valid_mask[t, a] and idx < L:
                    out[:, idx] = grouped[:, t, a]
                    idx += 1
        return out
class Downcast(nn.Module):
    """Downcast from atom level to token level.

    Aggregates atom features into token features either by a projected mean
    (``method="mean"``) or by letting each token cross-attend to its own
    atoms (``method="cross_attention"``); any other method yields a zero
    update. Optionally adds a projection of a single representation ``s``.
    """
    def __init__(
        self,
        c_atom: int,
        c_token: int,
        c_s: Optional[int] = None,
        method: str = "mean",
        cross_attention_block: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        self.method = method
        self.c_token = c_token
        self.c_atom = c_atom
        if c_s is not None:
            self.process_s = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_token))
        else:
            self.process_s = None
        if method == "mean":
            # Named `gca` for checkpoint-key compatibility even though it is
            # a plain projection in the mean path.
            self.gca = linearNoBias(c_atom, c_token)
        elif method == "cross_attention":
            self.gca = GatedCrossAttention(
                c_query=c_token,
                c_kv=c_atom,
                **(cross_attention_block or {}),
            )
    def forward(
        self,
        q: torch.Tensor,
        a: Optional[torch.Tensor] = None,
        s: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Aggregate atom features ``q`` into token features.

        Args:
            q: Atom features, (L, c_atom) or (B, L, c_atom).
            a: Existing token features to update, if any.
            s: Optional single representation added via ``process_s``.
            tok_idx: Token index per atom, shape (L,).
        """
        # Normalize to a batched shape; remember to squeeze on the way out.
        if q.ndim == 2:
            q = q.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False
        B, L, _ = q.shape
        I = int(tok_idx.max().item()) + 1
        if self.method == "mean":
            projected = self.gca(q)
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            counts = torch.zeros(B, I, 1, device=q.device, dtype=q.dtype)
            # Mean of the projected atom features per token.
            for i in range(L):
                t = tok_idx[i]
                a_update[:, t] += projected[:, i]
                counts[:, t] += 1
            a_update = a_update / (counts + 1e-8)
        elif self.method == "cross_attention":
            if a is None:
                a = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
            elif a.ndim == 2:
                a = a.unsqueeze(0)
            # One query per token attends over its (padded) atom slots.
            q_grouped, valid_mask = self._group_atoms(q, tok_idx, I)
            attn_mask = valid_mask.unsqueeze(-2)
            a_update = self.gca(a.unsqueeze(-2), q_grouped, attn_mask=attn_mask).squeeze(-2)
        else:
            a_update = torch.zeros(B, I, self.c_token, device=q.device, dtype=q.dtype)
        if a is not None:
            if a.ndim == 2:
                a = a.unsqueeze(0)
            a = a + a_update
        else:
            a = a_update
        if self.process_s is not None and s is not None:
            if s.ndim == 2:
                s = s.unsqueeze(0)
            a = a + self.process_s(s)
        if squeeze:
            a = a.squeeze(0)
        return a
    def _group_atoms(
        self, q: torch.Tensor, tok_idx: torch.Tensor, I: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pack atoms into fixed-size per-token slots plus a validity mask."""
        B, L, C = q.shape
        # Per-token atom capacity; extra atoms are silently dropped.
        max_atoms_per_token = 14
        grouped = torch.zeros(B, I, max_atoms_per_token, C, device=q.device, dtype=q.dtype)
        valid_mask = torch.zeros(I, max_atoms_per_token, dtype=torch.bool, device=q.device)
        counts = torch.zeros(I, dtype=torch.long, device=q.device)
        for i in range(L):
            t = tok_idx[i].item()
            if counts[t] < max_atoms_per_token:
                grouped[:, t, counts[t]] = q[:, i]
                valid_mask[t, counts[t]] = True
                counts[t] += 1
        return grouped, valid_mask
class LinearEmbedWithPool(nn.Module):
    """Project 3-D coordinates per atom, then mean-pool them per token."""
    def __init__(self, c_token: int):
        super().__init__()
        self.c_token = c_token
        self.linear = nn.Linear(3, c_token, bias=False)
    def forward(self, r: torch.Tensor, tok_idx: torch.Tensor) -> torch.Tensor:
        batch = r.shape[0]
        n_tokens = int(tok_idx.max().item()) + 1
        atom_emb = self.linear(r)
        pooled = torch.zeros(batch, n_tokens, self.c_token, device=r.device, dtype=atom_emb.dtype)
        counts = torch.zeros(batch, n_tokens, 1, device=r.device, dtype=atom_emb.dtype)
        # Accumulate per-token sums and counts, then take the mean.
        for atom in range(r.shape[1]):
            token = tok_idx[atom]
            pooled[:, token] += atom_emb[:, atom]
            counts[:, token] += 1
        return pooled / (counts + 1e-8)
class LinearSequenceHead(nn.Module):
    """Predict residue-type logits and argmax indices from token features."""
    def __init__(self, c_token: int):
        super().__init__()
        n_tok_all = 32
        # All 32 residue types are allowed by default; the buffer lets a
        # checkpoint restrict the output vocabulary.
        self.register_buffer("valid_out_mask", torch.ones(n_tok_all, dtype=torch.bool))
        self.linear = nn.Linear(c_token, n_tok_all)
    def forward(self, a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.linear(a)
        # Zero out disallowed types, renormalize, then take the argmax.
        masked = F.softmax(logits, dim=-1) * self.valid_out_mask[None, None, :].to(logits.device)
        masked = masked / (masked.sum(dim=-1, keepdim=True) + 1e-8)
        return logits, masked.argmax(dim=-1)
class LocalAtomTransformer(nn.Module):
    """Sequential stack of atom-level transformer blocks."""
    def __init__(
        self,
        c_atom: int,
        c_s: Optional[int],
        c_atompair: int,
        atom_transformer_block: dict,
        n_blocks: int,
    ):
        super().__init__()
        make_block = partial(
            StructureLocalAtomTransformerBlock,
            c_atom=c_atom,
            c_s=c_s,
            c_atompair=c_atompair,
            **atom_transformer_block,
        )
        self.blocks = nn.ModuleList([make_block() for _ in range(n_blocks)])
    def forward(
        self,
        q: torch.Tensor,
        c: Optional[torch.Tensor],
        p: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for layer in self.blocks:
            q = layer(q, c, p, **kwargs)
        return q
class LocalTokenTransformer(nn.Module):
    """Sequential stack of token-level transformer blocks for the diffusion trunk."""
    def __init__(
        self,
        c_token: int,
        c_tokenpair: int,
        c_s: int,
        diffusion_transformer_block: dict,
        n_block: int,
        **kwargs,
    ):
        super().__init__()
        layers = [
            StructureLocalAtomTransformerBlock(
                c_atom=c_token,
                c_s=c_s,
                c_atompair=c_tokenpair,
                **diffusion_transformer_block,
            )
            for _ in range(n_block)
        ]
        self.blocks = nn.ModuleList(layers)
    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        for layer in self.blocks:
            a = layer(a, s, z, **kwargs)
        return a
class CompactStreamingDecoder(nn.Module):
    """Decoder alternating token->atom upcasts with atom attention, then a final downcast."""
    def __init__(
        self,
        c_atom: int,
        c_atompair: int,
        c_token: int,
        c_s: int,
        c_tokenpair: int,
        atom_transformer_block: dict,
        upcast: dict,
        downcast: dict,
        n_blocks: int,
        **kwargs,
    ):
        super().__init__()
        self.n_blocks = n_blocks
        self.upcast = nn.ModuleList(
            [Upcast(c_atom=c_atom, c_token=c_token, **upcast) for _ in range(n_blocks)]
        )
        self.atom_transformer = nn.ModuleList(
            [
                StructureLocalAtomTransformerBlock(
                    c_atom=c_atom,
                    c_s=c_atom,
                    c_atompair=c_atompair,
                    **atom_transformer_block,
                )
                for _ in range(n_blocks)
            ]
        )
        self.downcast = Downcast(c_atom=c_atom, c_token=c_token, c_s=c_s, **downcast)
    def forward(
        self,
        a: torch.Tensor,
        s: torch.Tensor,
        z: torch.Tensor,
        q: torch.Tensor,
        c: torch.Tensor,
        p: torch.Tensor,
        tok_idx: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        for up, attend in zip(self.upcast, self.atom_transformer):
            q = up(q, a, tok_idx=tok_idx)
            q = attend(q, c, p, **kwargs)
        # NOTE(review): the final downcast sees detached inputs, so no
        # gradient flows through it into a/s/q — presumably intentional.
        a = self.downcast(q.detach(), a.detach(), s.detach(), tok_idx=tok_idx)
        return a, q, {}
class DiffusionTokenEncoder(nn.Module):
    """Token encoder with pairformer stack for diffusion.

    Refines the initial single (``s_init``) and pair (``z_init``)
    representations: two transitions per track, optional concatenation of
    distogram channels into the pair track, then a Pairformer stack.
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_token: int,
        c_atompair: int,
        n_pairformer_blocks: int,
        pairformer_block: dict,
        use_distogram: bool = True,
        use_self: bool = True,
        n_bins_distogram: int = 65,
        **kwargs,
    ):
        super().__init__()
        self.use_distogram = use_distogram
        self.use_self = use_self
        self.n_bins_distogram = n_bins_distogram
        self.transition_1 = nn.ModuleList([
            Transition(c=c_s, n=2),
            Transition(c=c_s, n=2),
        ])
        n_bins_noise = n_bins_distogram
        # Width of the concatenated pair input depends on which distogram
        # channels are enabled.
        cat_c_z = (
            c_z
            + int(use_distogram) * n_bins_noise
            + int(use_self) * n_bins_distogram
        )
        self.process_z = nn.Sequential(
            RMSNorm(cat_c_z),
            linearNoBias(cat_c_z, c_z),
        )
        self.transition_2 = nn.ModuleList([
            Transition(c=c_z, n=2),
            Transition(c=c_z, n=2),
        ])
        self.pairformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])
    def forward(
        self,
        s_init: torch.Tensor,
        z_init: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return refined ``(s, z)``; an optional ``d_self`` tensor may be passed via kwargs."""
        B = z_init.shape[0] if z_init.ndim == 4 else 1
        s = s_init
        for b in range(2):
            s = s + self.transition_1[b](s)
        z = z_init
        # Add a batch dimension to an unbatched pair tensor.
        if z.ndim == 3:
            z = z.unsqueeze(0).expand(B, -1, -1, -1)
        z_list = [z]
        if self.use_distogram:
            # Placeholder noise distogram (all zeros) is concatenated here.
            d_noise = torch.zeros(
                z.shape[:-1] + (self.n_bins_distogram,),
                device=z.device, dtype=z.dtype
            )
            z_list.append(d_noise)
        if self.use_self:
            # Self-conditioning distogram; zero-filled when not supplied.
            d_self = kwargs.get("d_self")
            if d_self is None:
                d_self = torch.zeros(
                    z.shape[:-1] + (self.n_bins_distogram,),
                    device=z.device, dtype=z.dtype
                )
            z_list.append(d_self)
        z = torch.cat(z_list, dim=-1)
        z = self.process_z(z)
        for b in range(2):
            z = z + self.transition_2[b](z)
        for block in self.pairformer_stack:
            s, z = block(s, z)
        return s, z
class EmbeddingLayer(nn.Module):
    """Bias-free linear projection used to embed 1D features."""
    def __init__(self, n_channels: int, output_channels: int):
        super().__init__()
        # Zero-initialized; real values are expected from a loaded
        # checkpoint — TODO confirm, since untrained output is all zeros.
        self.weight = nn.Parameter(torch.zeros(output_channels, n_channels))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)
class OneDFeatureEmbedder(nn.Module):
    """Sum of per-feature embeddings for a dictionary of 1D features."""
    def __init__(self, features: dict, output_channels: int):
        super().__init__()
        # Drop features whose channel count is unset.
        self.features = {name: c for name, c in features.items() if c is not None}
        self.embedders = nn.ModuleDict({
            name: EmbeddingLayer(c, output_channels)
            for name, c in self.features.items()
        })
    def forward(self, f: dict) -> torch.Tensor:
        total = None
        for name in self.features:
            value = f.get(name)
            if value is None:
                continue
            embedded = self.embedders[name](value.float())
            total = embedded if total is None else total + embedded
        # Fallback when none of the known features are present in `f`.
        return total if total is not None else torch.zeros(1)
class PositionPairDistEmbedder(nn.Module):
    """Embed pairwise inverse squared distances between reference positions."""
    def __init__(self, c_atompair: int, embed_frame: bool = True):
        super().__init__()
        self.embed_frame = embed_frame
        if embed_frame:
            # NOTE(review): created for checkpoint compatibility but not
            # used in forward().
            self.process_d = nn.Linear(3, c_atompair, bias=False)
        self.process_inverse_dist = nn.Linear(1, c_atompair, bias=False)
        self.process_valid_mask = nn.Linear(1, c_atompair, bias=False)
    def forward(self, ref_pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        # Pairwise displacements and squared distances.
        diff = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(-3)
        sq_dist = torch.linalg.norm(diff, dim=-1, keepdim=True) ** 2
        sq_dist = torch.clamp(sq_dist, min=1e-6)
        inv_dist = 1 / (1 + sq_dist)
        pair = self.process_inverse_dist(inv_dist) * valid_mask
        pair = pair + self.process_valid_mask(valid_mask.float()) * valid_mask
        return pair
class SinusoidalDistEmbed(nn.Module):
    """Sinusoidal (transformer-style) embedding of pairwise distances."""
    def __init__(self, c_atompair: int, n_freqs: int = 32):
        super().__init__()
        self.n_freqs = n_freqs
        self.c_atompair = c_atompair
        self.output_proj = nn.Linear(2 * n_freqs, c_atompair, bias=False)
        self.process_valid_mask = nn.Linear(1, c_atompair, bias=False)
    def forward(self, pos: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        delta = pos.unsqueeze(-2) - pos.unsqueeze(-3)
        dists = torch.linalg.norm(delta, dim=-1)
        # Geometric frequency ladder, as in standard sinusoidal encodings.
        freq = torch.exp(
            -math.log(10000.0) * torch.arange(0, self.n_freqs, dtype=torch.float32) / self.n_freqs
        ).to(dists.device)
        angles = dists.unsqueeze(-1) * freq
        features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        pair = self.output_proj(features) * valid_mask
        pair = pair + self.process_valid_mask(valid_mask.float()) * valid_mask
        return pair
class RelativePositionEncoding(nn.Module):
    """Relative position encoding (currently a zero-valued stub)."""
    def __init__(self, r_max: int, s_max: int, c_z: int):
        super().__init__()
        self.r_max = r_max
        self.s_max = s_max
        # Input width: position bins for both directions, chain bins, plus one.
        num_tok_pos_bins = 2 * r_max + 3
        in_width = 2 * num_tok_pos_bins + (2 * s_max + 2) + 1
        self.linear = nn.Linear(in_width, c_z, bias=False)
    def forward(self, f: dict) -> torch.Tensor:
        # NOTE(review): stub — the encoding is returned as all zeros
        # regardless of the input features.
        residue_index = f.get("residue_index", torch.zeros(1))
        n_tokens = residue_index.shape[-1]
        return torch.zeros(n_tokens, n_tokens, self.linear.out_features, device=residue_index.device)
class TokenInitializer(nn.Module):
    """Token embedding module for RFD3 matching foundry checkpoint structure.

    Builds every submodule named under ``model.token_initializer.*`` in the
    checkpoint so weights load cleanly. NOTE(review): ``forward`` is
    currently a stub returning zero-valued initial representations; the
    submodules are not yet wired into it.
    """
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        r_max: int = 32,
        s_max: int = 2,
        n_pairformer_blocks: int = 2,
        atom_1d_features: Optional[dict] = None,
        token_1d_features: Optional[dict] = None,
        **kwargs,
    ):
        super().__init__()
        # Default channel counts per atom-level 1D input feature.
        if atom_1d_features is None:
            atom_1d_features = {
                "ref_atom_name_chars": 256,
                "ref_element": 128,
                "ref_charge": 1,
                "ref_mask": 1,
                "ref_is_motif_atom_with_fixed_coord": 1,
                "ref_is_motif_atom_unindexed": 1,
                "has_zero_occupancy": 1,
                "ref_pos": 3,
                "ref_atomwise_rasa": 3,
                "active_donor": 1,
                "active_acceptor": 1,
                "is_atom_level_hotspot": 1,
            }
        # Default channel counts per token-level 1D input feature.
        if token_1d_features is None:
            token_1d_features = {
                "ref_motif_token_type": 3,
                "restype": 32,
                "ref_plddt": 1,
                "is_non_loopy": 1,
            }
        cross_attention_block = {"n_head": 4, "c_model": c_atom, "dropout": 0.0, "kq_norm": True}
        # Atom features are embedded twice: once at single width, once at atom width.
        self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
        self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
        self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
        self.downcast_atom = Downcast(
            c_atom=c_s, c_token=c_s, c_s=None,
            method="cross_attention", cross_attention_block=cross_attention_block
        )
        self.transition_post_token = Transition(c=c_s, n=2)
        self.transition_post_atom = Transition(c=c_s, n=2)
        self.process_s_init = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_s))
        # Outer-sum style projections from the single track into the pair track.
        self.to_z_init_i = linearNoBias(c_s, c_z)
        self.to_z_init_j = linearNoBias(c_s, c_z)
        self.relative_position_encoding = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.relative_position_encoding2 = RelativePositionEncoding(r_max=r_max, s_max=s_max, c_z=c_z)
        self.process_token_bonds = linearNoBias(1, c_z)
        self.process_z_init = nn.Sequential(RMSNorm(c_z * 2), linearNoBias(c_z * 2, c_z))
        self.transition_1 = nn.ModuleList([Transition(c=c_z, n=2), Transition(c=c_z, n=2)])
        self.ref_pos_embedder_tok = PositionPairDistEmbedder(c_z, embed_frame=False)
        pairformer_block = {"attention_pair_bias": {"n_head": 16, "kq_norm": True}, "n_transition": 4}
        self.transformer_stack = nn.ModuleList([
            PairformerBlock(c_s=c_s, c_z=c_z, **pairformer_block)
            for _ in range(n_pairformer_blocks)
        ])
        self.process_s_trunk = nn.Sequential(RMSNorm(c_s), linearNoBias(c_s, c_atom))
        self.process_single_l = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_single_m = nn.Sequential(nn.ReLU(), linearNoBias(c_atom, c_atompair))
        self.process_z = nn.Sequential(RMSNorm(c_z), linearNoBias(c_z, c_atompair))
        self.motif_pos_embedder = SinusoidalDistEmbed(c_atompair=c_atompair)
        self.ref_pos_embedder = PositionPairDistEmbedder(c_atompair, embed_frame=False)
        self.pair_mlp = nn.Sequential(
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
            nn.ReLU(), linearNoBias(c_atompair, c_atompair),
        )
        self.process_pll = linearNoBias(c_atompair, c_atompair)
        self.project_pll = linearNoBias(c_atompair, c_z)
    def forward(self, f: dict) -> dict:
        """Compute initial representations from input features.

        NOTE(review): stub implementation — returns zero tensors sized from
        ``f["num_tokens"]`` (default 100) without using the submodules above.
        """
        I = f.get("num_tokens", 100)
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        s_init = torch.zeros(I, self.process_s_init[1].out_features, device=device, dtype=dtype)
        z_init = torch.zeros(I, I, self.process_z_init[1].out_features, device=device, dtype=dtype)
        return {"S_I": s_init, "Z_II": z_init}
class RFD3DiffusionModule(nn.Module):
"""
RFD3 Diffusion Module matching foundry checkpoint structure.
This module structure matches `model.diffusion_module.*` keys in the checkpoint.
"""
def __init__(
self,
c_s: int = 384,
c_z: int = 128,
c_atom: int = 128,
c_atompair: int = 16,
c_token: int = 768,
c_t_embed: int = 256,
sigma_data: float = 16.0,
n_pairformer_blocks: int = 2,
n_diffusion_blocks: int = 18,
n_atom_encoder_blocks: int = 3,
n_atom_decoder_blocks: int = 3,
n_head: int = 16,
n_pairformer_head: int = 16,
n_recycle: int = 2,
p_drop: float = 0.0,
):
super().__init__()
self.sigma_data = sigma_data
self.n_recycle = n_recycle
self.process_r = linearNoBias(3, c_atom)
self.to_r_update = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, 3))
self.sequence_head = LinearSequenceHead(c_token)
self.fourier_embedding = nn.ModuleList([
FourierEmbedding(c_t_embed),
FourierEmbedding(c_t_embed),
])
self.process_n = nn.ModuleList([
nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_atom)),
nn.Sequential(RMSNorm(c_t_embed), linearNoBias(c_t_embed, c_s)),
])
cross_attention_block = {
"n_head": 4,
"c_model": c_atom,
"dropout": p_drop,
"kq_norm": True,
}
self.downcast_c = Downcast(
c_atom=c_atom, c_token=c_s, c_s=None,
method="cross_attention", cross_attention_block=cross_attention_block
)
self.downcast_q = Downcast(
c_atom=c_atom, c_token=c_token, c_s=c_s,
method="cross_attention", cross_attention_block=cross_attention_block
)
self.process_a = LinearEmbedWithPool(c_token)
self.process_c = nn.Sequential(RMSNorm(c_atom), linearNoBias(c_atom, c_atom))
atom_transformer_block = {
"n_head": 4,
"dropout": p_drop,
"kq_norm": True,
}
self.encoder = LocalAtomTransformer(
c_atom=c_atom,
c_s=c_atom,
c_atompair=c_atompair,
atom_transformer_block=atom_transformer_block,
n_blocks=n_atom_encoder_blocks,
)
pairformer_block = {
"attention_pair_bias": {"n_head": n_pairformer_head, "kq_norm": False},
"n_transition": 4,
}
self.diffusion_token_encoder = DiffusionTokenEncoder(
c_s=c_s,
c_z=c_z,
c_token=c_token,
c_atompair=c_atompair,
n_pairformer_blocks=n_pairformer_blocks,
pairformer_block=pairformer_block,
)
diffusion_transformer_block = {
"n_head": n_head,
"dropout": p_drop,
"kq_norm": True,
}
self.diffusion_transformer = LocalTokenTransformer(
c_token=c_token,
c_tokenpair=c_z,
c_s=c_s,
diffusion_transformer_block=diffusion_transformer_block,
n_block=n_diffusion_blocks,
)
decoder_upcast = {
"method": "cross_attention",
"n_split": 3,
"cross_attention_block": cross_attention_block,
}
decoder_downcast = {
"method": "cross_attention",
"cross_attention_block": cross_attention_block,
}
self.decoder = CompactStreamingDecoder(
c_atom=c_atom,
c_atompair=c_atompair,
c_token=c_token,
c_s=c_s,
c_tokenpair=c_z,
atom_transformer_block=atom_transformer_block,
upcast=decoder_upcast,
downcast=decoder_downcast,
n_blocks=n_atom_decoder_blocks,
)
def scale_positions_in(self, x_noisy: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if t.ndim == 1:
t = t[..., None, None]
elif t.ndim == 2:
t = t[..., None]
return x_noisy / torch.sqrt(t**2 + self.sigma_data**2)
def scale_positions_out(
    self, r_update: torch.Tensor, x_noisy: torch.Tensor, t: torch.Tensor
) -> torch.Tensor:
    """Blend the skip connection with the network update.

    Computes ``c_skip * x_noisy + c_out * r_update`` where
    ``c_skip = sigma_data**2 / (sigma_data**2 + t**2)`` and
    ``c_out = sigma_data * t / sqrt(sigma_data**2 + t**2)``.

    Args:
        r_update: Network-predicted update, same shape as ``x_noisy``.
        x_noisy: Noisy input coordinates.
        t: Noise level; rank-1/rank-2 tensors get trailing singleton dims.

    Returns:
        Denoised coordinates with the same shape as ``x_noisy``.
    """
    if t.ndim == 1:
        t = t.unsqueeze(-1).unsqueeze(-1)
    elif t.ndim == 2:
        t = t.unsqueeze(-1)
    s2 = self.sigma_data**2
    total_var = s2 + t**2
    skip = (s2 / total_var) * x_noisy
    update = (self.sigma_data * t / torch.sqrt(total_var)) * r_update
    return skip + update
def process_time(self, t: torch.Tensor, idx: int) -> torch.Tensor:
    """Embed a noise level via a scaled log transform and Fourier features.

    Args:
        t: Noise levels; entries equal to zero produce a zero embedding.
        idx: Selects which embedding/projection pair in
            ``self.fourier_embedding`` / ``self.process_n`` to use.

    Returns:
        Embedding tensor with a trailing channel dimension.
    """
    # Clamp before the log so t == 0 does not produce -inf.
    safe_t = torch.clamp(t, min=1e-20)
    log_t = 0.25 * torch.log(safe_t / self.sigma_data)
    emb = self.process_n[idx](self.fourier_embedding[idx](log_t))
    # Gate: positions with zero noise level get a zero embedding.
    mask = (t > 0).float().unsqueeze(-1)
    return emb * mask
def compute_pair_features(self, xyz: torch.Tensor, c_atompair: int) -> torch.Tensor:
    """Inverse-square-distance pair features, tiled over the channel axis.

    Args:
        xyz: Batched coordinates [B, L, 3].
        c_atompair: Number of atom-pair feature channels.

    Returns:
        Tensor [B, L, L, c_atompair]; each channel holds 1 / (1 + d**2).
    """
    pair_dist = torch.cdist(xyz, xyz)
    inv_sq = 1.0 / (1.0 + pair_dist**2)
    # expand() produces a broadcasted view over the channel axis (no copy).
    return inv_sq.unsqueeze(-1).expand(-1, -1, -1, c_atompair)
class RFDiffusionTransformerModel(ModelMixin, ConfigMixin):
    """
    RFDiffusion3 transformer for protein structure prediction.
    This wraps the diffusion module to provide the full model interface.
    The state dict keys match the foundry checkpoint format.
    """

    # File name used by ConfigMixin when saving/loading the model config.
    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        c_s: int = 384,
        c_z: int = 128,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 768,
        c_t_embed: int = 256,
        sigma_data: float = 16.0,
        n_pairformer_block: int = 2,
        n_diffusion_block: int = 18,
        n_atom_encoder_block: int = 3,
        n_atom_decoder_block: int = 3,
        n_head: int = 16,
        n_pairformer_head: int = 16,
        n_recycle: int = 2,
        p_drop: float = 0.0,
    ):
        super().__init__()
        # Supplies default single/pair token representations when forward()
        # is called without s_init / z_init.
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
        )
        # Core denoiser. NOTE: the registered config uses singular names
        # (n_*_block) while the module constructor takes plural (n_*_blocks).
        self.diffusion_module = RFD3DiffusionModule(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            c_token=c_token,
            c_t_embed=c_t_embed,
            sigma_data=sigma_data,
            n_pairformer_blocks=n_pairformer_block,
            n_diffusion_blocks=n_diffusion_block,
            n_atom_encoder_blocks=n_atom_encoder_block,
            n_atom_decoder_blocks=n_atom_decoder_block,
            n_head=n_head,
            n_pairformer_head=n_pairformer_head,
            n_recycle=n_recycle,
            p_drop=p_drop,
        )

    @property
    def sigma_data(self) -> float:
        # Expose the data scale of the wrapped diffusion module.
        return self.diffusion_module.sigma_data

    def forward(
        self,
        xyz_noisy: torch.Tensor,
        t: torch.Tensor,
        atom_to_token_map: Optional[torch.Tensor] = None,
        motif_mask: Optional[torch.Tensor] = None,
        s_init: Optional[torch.Tensor] = None,
        z_init: Optional[torch.Tensor] = None,
        n_recycle: Optional[int] = None,
        **kwargs,
    ) -> RFDiffusionTransformerOutput:
        """
        Forward pass of the diffusion module.
        Args:
            xyz_noisy: Noisy atom coordinates [B, L, 3]
            t: Noise level / timestep [B]
            atom_to_token_map: Mapping from atoms to tokens [L]
            motif_mask: Mask for fixed motif atoms [L]
            s_init: Initial single representation [I, c_s]
            z_init: Initial pair representation [I, I, c_z]
            n_recycle: Number of recycling iterations
        Returns:
            RFDiffusionTransformerOutput with denoised coordinates
        """
        dm = self.diffusion_module
        B, L, _ = xyz_noisy.shape
        # Default mapping: one token per atom (identity).
        if atom_to_token_map is None:
            atom_to_token_map = torch.arange(L, device=xyz_noisy.device)
        # I = number of tokens implied by the atom -> token mapping.
        I = int(atom_to_token_map.max().item()) + 1
        if motif_mask is None:
            motif_mask = torch.zeros(L, dtype=torch.bool, device=xyz_noisy.device)
        # Per-atom noise levels: motif atoms get t = 0, so they are treated
        # as noise-free (process_time also zeroes embeddings where t == 0).
        t_L = t[:, None].expand(B, L) * (~motif_mask).float()
        t_I = t[:, None].expand(B, I)
        # r_scaled uses the raw per-batch t (feeds the pair features below);
        # r_noisy uses the motif-masked per-atom t for the atom stream.
        r_scaled = dm.scale_positions_in(xyz_noisy, t)
        r_noisy = dm.scale_positions_in(xyz_noisy, t_L)
        # Fall back to the learned token initializer for any missing reps.
        if s_init is None or z_init is None:
            init_output = self.token_initializer({"num_tokens": I})
            if s_init is None:
                s_init = init_output["S_I"]
            if z_init is None:
                z_init = init_output["Z_II"]
        assert s_init is not None and z_init is not None
        # Inverse-square-distance atom-pair features.
        p = dm.compute_pair_features(r_scaled, self.config.c_atompair)
        a_I = dm.process_a(r_noisy, tok_idx=atom_to_token_map)
        # Add a batch dim to s_init if it was given unbatched ([I, c_s]).
        s_init_expanded = s_init.unsqueeze(0).expand(B, -1, -1) if s_init.ndim == 2 else s_init
        # Build token-level singles from zero atom features conditioned on
        # s_init (cross-attention downcast; tok_idx groups atoms by token).
        s_I = dm.downcast_c(torch.zeros(B, L, self.config.c_atom, device=xyz_noisy.device),
                            s_init_expanded,
                            tok_idx=atom_to_token_map)
        q = dm.process_r(r_noisy)
        # Noise-level (time) embeddings: idx=0 for the atom stream,
        # idx=1 for the token stream.
        c = dm.process_time(t_L, idx=0)
        q = q + c
        s_I = s_I + dm.process_time(t_I, idx=1)
        c = c + dm.process_c(c)
        # Local atom-transformer encoder over the atom stream.
        q = dm.encoder(q, c, p)
        # Fold the encoded atom stream into the token activations.
        a_I = dm.downcast_q(q, a_I, s_I, tok_idx=atom_to_token_map)
        # Default recycle count: configured value at eval, 1 while training.
        if n_recycle is None:
            n_recycle = dm.n_recycle if not self.training else 1
        n_recycle = max(1, n_recycle)
        z_II = z_init
        for _ in range(n_recycle):
            s_I, z_II = dm.diffusion_token_encoder(s_init=s_I, z_init=z_II)
        # Token-level transformer, then decode back to the atom stream.
        a_I = dm.diffusion_transformer(a_I, s_I, z_II)
        a_I, q, _ = dm.decoder(a_I, s_I, z_II, q, c, p, tok_idx=atom_to_token_map)
        r_update = dm.to_r_update(q)
        # Skip/update blend back to coordinate space (see scale_positions_out).
        xyz_out = dm.scale_positions_out(r_update, xyz_noisy, t_L)
        sequence_logits, sequence_indices = dm.sequence_head(a_I)
        return RFDiffusionTransformerOutput(
            xyz=xyz_out,
            single=s_I,
            pair=z_II,
            sequence_logits=sequence_logits,
            sequence_indices=sequence_indices,
        )